llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Helena Kotas (hekota)

<details>
<summary>Changes</summary>

Adjust the register binding diagnostic flags code in a few ways:
- Store the resource class in the Flags struct to avoid scanning for
`HLSLResourceClassAttr` a second time
- Avoid unnecessary indirection when converting a resource class to a register type
- Remove recursion and reduce duplicated code

Also fixes a case where a struct containing an array was incorrectly diagnosed
as unfit for `c` register binding.

This also simplifies the work that needs to be done in this area for
llvm/llvm-project#<!-- -->104861.

---
Full diff: https://github.com/llvm/llvm-project/pull/106657.diff


2 Files Affected:

- (modified) clang/lib/Sema/SemaHLSL.cpp (+68-113) 
- (modified) clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl (+7) 


``````````diff
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 714e8f5cfa9926..1e484f754b931d 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -480,6 +480,9 @@ struct RegisterBindingFlags {
 
   bool ContainsNumeric = false;
   bool DefaultGlobals = false;
+
+  // used only when Resource == true
+  llvm::dxil::ResourceClass ResourceClass = llvm::dxil::ResourceClass::UAV;
 };
 
 static bool isDeclaredWithinCOrTBuffer(const Decl *TheDecl) {
@@ -545,65 +548,38 @@ static const T *getSpecifiedHLSLAttrFromVarDecl(VarDecl 
*VD) {
   return getSpecifiedHLSLAttrFromRecordDecl<T>(TheRecordDecl);
 }
 
-static void updateFlagsFromType(QualType TheQualTy,
-                                RegisterBindingFlags &Flags);
-
-static void updateResourceClassFlagsFromRecordDecl(RegisterBindingFlags &Flags,
-                                                   const RecordDecl *RD) {
-  if (!RD)
-    return;
-
-  if (RD->isCompleteDefinition()) {
-    for (auto Field : RD->fields()) {
-      QualType T = Field->getType();
-      updateFlagsFromType(T, Flags);
+static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
+                                                   const RecordType *RT) {
+  llvm::SmallVector<const Type *> TypesToScan;
+  TypesToScan.emplace_back(RT);
+
+  while (!TypesToScan.empty()) {
+    const Type *T = TypesToScan.pop_back_val();
+    while (T->isArrayType())
+      T = T->getArrayElementTypeNoTypeQual();
+    if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
+      Flags.ContainsNumeric = true;
+      continue;
     }
-  }
-}
-
-static void updateFlagsFromType(QualType TheQualTy,
-                                RegisterBindingFlags &Flags) {
-  // if the member's type is a numeric type, set the ContainsNumeric flag
-  if (TheQualTy->isIntegralOrEnumerationType() || TheQualTy->isFloatingType()) 
{
-    Flags.ContainsNumeric = true;
-    return;
-  }
-
-  const clang::Type *TheBaseType = TheQualTy.getTypePtr();
-  while (TheBaseType->isArrayType())
-    TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
-  // otherwise, if the member's base type is not a record type, return
-  const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
-  if (!TheRecordTy)
-    return;
-
-  RecordDecl *SubRecordDecl = TheRecordTy->getDecl();
-  const HLSLResourceClassAttr *Attr =
-      getSpecifiedHLSLAttrFromRecordDecl<HLSLResourceClassAttr>(SubRecordDecl);
-  // find the attr if it's on the member, or on any of the member's fields
-  if (Attr) {
-    llvm::hlsl::ResourceClass DeclResourceClass = Attr->getResourceClass();
-    updateResourceClassFlagsFromDeclResourceClass(Flags, DeclResourceClass);
-  }
+    const RecordType *RT = T->getAs<RecordType>();
+    if (!RT)
+      continue;
 
-  // otherwise, dig deeper and recurse into the member
-  else {
-    updateResourceClassFlagsFromRecordDecl(Flags, SubRecordDecl);
+    const RecordDecl *RD = RT->getDecl();
+    for (FieldDecl *FD : RD->fields()) {
+      if (HLSLResourceClassAttr *RCAttr =
+              FD->getAttr<HLSLResourceClassAttr>()) {
+        updateResourceClassFlagsFromDeclResourceClass(
+            Flags, RCAttr->getResourceClass());
+        continue;
+      }
+      TypesToScan.emplace_back(FD->getType().getTypePtr());
+    }
   }
 }
 
 static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
                                                          Decl *TheDecl) {
-
-  // Cbuffers and Tbuffers are HLSLBufferDecl types
-  HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl);
-  // Samplers, UAVs, and SRVs are VarDecl types
-  VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl);
-
-  assert(((TheVarDecl && !CBufferOrTBuffer) ||
-          (!TheVarDecl && CBufferOrTBuffer)) &&
-         "either TheVarDecl or CBufferOrTBuffer should be set");
-
   RegisterBindingFlags Flags;
 
   // check if the decl type is groupshared
@@ -612,57 +588,61 @@ static RegisterBindingFlags 
HLSLFillRegisterBindingFlags(Sema &S,
     return Flags;
   }
 
-  if (!isDeclaredWithinCOrTBuffer(TheDecl)) {
-    // make sure the type is a basic / numeric type
-    if (TheVarDecl) {
-      QualType TheQualTy = TheVarDecl->getType();
-      // a numeric variable or an array of numeric variables
-      // will inevitably end up in $Globals buffer
-      const clang::Type *TheBaseType = TheQualTy.getTypePtr();
-      while (TheBaseType->isArrayType())
-        TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
-      if (TheBaseType->isIntegralType(S.getASTContext()) ||
-          TheBaseType->isFloatingType())
-        Flags.DefaultGlobals = true;
-    }
-  }
-
-  if (CBufferOrTBuffer) {
+  // Cbuffers and Tbuffers are HLSLBufferDecl types
+  if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl)) {
     Flags.Resource = true;
-    if (CBufferOrTBuffer->isCBuffer())
-      Flags.CBV = true;
-    else
-      Flags.SRV = true;
-  } else if (TheVarDecl) {
+    Flags.ResourceClass = CBufferOrTBuffer->isCBuffer()
+                              ? llvm::dxil::ResourceClass::CBuffer
+                              : llvm::dxil::ResourceClass::SRV;
+  }
+  // Samplers, UAVs, and SRVs are VarDecl types
+  else if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl)) {
     const HLSLResourceClassAttr *resClassAttr =
         getSpecifiedHLSLAttrFromVarDecl<HLSLResourceClassAttr>(TheVarDecl);
-
     if (resClassAttr) {
-      llvm::hlsl::ResourceClass DeclResourceClass =
-          resClassAttr->getResourceClass();
       Flags.Resource = true;
-      updateResourceClassFlagsFromDeclResourceClass(Flags, DeclResourceClass);
+      Flags.ResourceClass = resClassAttr->getResourceClass();
     } else {
       const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr();
       while (TheBaseType->isArrayType())
         TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
-      if (TheBaseType->isArithmeticType())
+
+      if (TheBaseType->isArithmeticType()) {
         Flags.Basic = true;
-      else if (TheBaseType->isRecordType()) {
+        if (!isDeclaredWithinCOrTBuffer(TheDecl) &&
+            (TheBaseType->isIntegralType(S.getASTContext()) ||
+             TheBaseType->isFloatingType()))
+          Flags.DefaultGlobals = true;
+      } else if (TheBaseType->isRecordType()) {
         Flags.UDT = true;
         const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
-        assert(TheRecordTy && "The Qual Type should be Record Type");
-        const RecordDecl *TheRecordDecl = TheRecordTy->getDecl();
-        // recurse through members, set appropriate resource class flags.
-        updateResourceClassFlagsFromRecordDecl(Flags, TheRecordDecl);
+        updateResourceClassFlagsFromRecordType(Flags, TheRecordTy);
       } else
         Flags.Other = true;
     }
+  } else {
+    llvm_unreachable("expected be VarDecl or HLSLBufferDecl");
   }
   return Flags;
 }
 
-enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
+enum class RegisterType {
+  SRV = static_cast<int>(llvm::dxil::ResourceClass::SRV),
+  UAV = static_cast<int>(llvm::dxil::ResourceClass::UAV),
+  CBuffer = static_cast<int>(llvm::dxil::ResourceClass::CBuffer),
+  Sampler = static_cast<int>(llvm::dxil::ResourceClass::Sampler),
+  C,
+  I,
+  Invalid
+};
+
+static RegisterType
+convertResourceClassToRegisterType(llvm::dxil::ResourceClass RC) {
+  assert(RC >= llvm::dxil::ResourceClass::SRV &&
+         RC <= llvm::dxil::ResourceClass::Sampler &&
+         "unexpected resource class value");
+  return static_cast<RegisterType>(RC);
+}
 
 static RegisterType getRegisterType(StringRef Slot) {
   switch (Slot[0]) {
@@ -754,34 +734,9 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, 
SourceLocation &ArgLoc,
   // next, if resource is set, make sure the register type in the register
   // annotation is compatible with the variable's resource type.
   if (Flags.Resource) {
-    const HLSLResourceClassAttr *resClassAttr = nullptr;
-    if (CBufferOrTBuffer) {
-      resClassAttr = CBufferOrTBuffer->getAttr<HLSLResourceClassAttr>();
-    } else if (TheVarDecl) {
-      resClassAttr =
-          getSpecifiedHLSLAttrFromVarDecl<HLSLResourceClassAttr>(TheVarDecl);
-    }
-
-    assert(resClassAttr &&
-           "any decl that set the resource flag on analysis should "
-           "have a resource class attribute attached.");
-    const llvm::hlsl::ResourceClass DeclResourceClass =
-        resClassAttr->getResourceClass();
-
-    // confirm that the register type is bound to its expected resource class
-    static RegisterType ExpectedRegisterTypesForResourceClass[] = {
-        RegisterType::SRV,
-        RegisterType::UAV,
-        RegisterType::CBuffer,
-        RegisterType::Sampler,
-    };
-    assert((size_t)DeclResourceClass <
-               std::size(ExpectedRegisterTypesForResourceClass) &&
-           "DeclResourceClass has unexpected value");
-
-    RegisterType ExpectedRegisterType =
-        ExpectedRegisterTypesForResourceClass[(int)DeclResourceClass];
-    if (regType != ExpectedRegisterType) {
+    RegisterType expRegType =
+        convertResourceClassToRegisterType(Flags.ResourceClass);
+    if (regType != expRegType) {
       S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch)
           << regTypeNum;
     }
@@ -823,7 +778,7 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, 
SourceLocation &ArgLoc,
 }
 
 void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
-  if (dyn_cast<VarDecl>(TheDecl)) {
+  if (isa<VarDecl>(TheDecl)) {
     if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(),
                                     cast<ValueDecl>(TheDecl)->getType(),
                                     diag::err_incomplete_type))
diff --git a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl 
b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
index f8e38b6d2851d9..edb3f30739cdfd 100644
--- a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
+++ b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
@@ -126,3 +126,10 @@ struct Eg14{
 };
 // expected-warning@+1{{binding type 't' only applies to types containing SRV 
resources}}
 Eg14 e14 : register(t9);
+
+struct Eg15 {
+  float f[4];
+}; 
+// expected no error
+Eg15 e15 : register(c0);
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/106657
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to