Author: Helena Kotas
Date: 2024-09-18T10:51:30-07:00
New Revision: f2128267c26e548bef59209e7a351ff94d343bf3

URL: 
https://github.com/llvm/llvm-project/commit/f2128267c26e548bef59209e7a351ff94d343bf3
DIFF: 
https://github.com/llvm/llvm-project/commit/f2128267c26e548bef59209e7a351ff94d343bf3.diff

LOG: [HLSL][NFC] Remove RegisterBindingFlags struct (#108924)

When diagnosing register bindings, we only need to make sure there is a
resource that matches the provided register type. We can emit the
diagnostics right away instead of collecting flags in the
RegisterBindingFlags struct. This also enables an early exit when scanning
user-defined types, because we can return as soon as we find a matching
resource for the given register type.

Added: 
    

Modified: 
    clang/lib/Sema/SemaHLSL.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index a303f211501348..03b7c2edb605fe 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -40,6 +40,48 @@
 #include <utility>
 
 using namespace clang;
+using llvm::dxil::ResourceClass;
+
+enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
+
+static RegisterType getRegisterType(ResourceClass RC) {
+  switch (RC) {
+  case ResourceClass::SRV:
+    return RegisterType::SRV;
+  case ResourceClass::UAV:
+    return RegisterType::UAV;
+  case ResourceClass::CBuffer:
+    return RegisterType::CBuffer;
+  case ResourceClass::Sampler:
+    return RegisterType::Sampler;
+  }
+  llvm_unreachable("unexpected ResourceClass value");
+}
+
+static RegisterType getRegisterType(StringRef Slot) {
+  switch (Slot[0]) {
+  case 't':
+  case 'T':
+    return RegisterType::SRV;
+  case 'u':
+  case 'U':
+    return RegisterType::UAV;
+  case 'b':
+  case 'B':
+    return RegisterType::CBuffer;
+  case 's':
+  case 'S':
+    return RegisterType::Sampler;
+  case 'c':
+  case 'C':
+    return RegisterType::C;
+  case 'i':
+  case 'I':
+    return RegisterType::I;
+  default:
+    return RegisterType::Invalid;
+  }
+}
 
 SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
 
@@ -586,8 +628,7 @@ bool clang::CreateHLSLAttributedResourceType(
     LocEnd = A->getRange().getEnd();
     switch (A->getKind()) {
     case attr::HLSLResourceClass: {
-      llvm::dxil::ResourceClass RC =
-          cast<HLSLResourceClassAttr>(A)->getResourceClass();
+      ResourceClass RC = cast<HLSLResourceClassAttr>(A)->getResourceClass();
       if (HasResourceClass) {
         S.Diag(A->getLocation(), ResAttrs.ResourceClass == RC
                                      ? diag::warn_duplicate_attribute_exact
@@ -672,7 +713,7 @@ bool SemaHLSL::handleResourceTypeAttr(const ParsedAttr &AL) {
     SourceLocation ArgLoc = Loc->Loc;
 
     // Validate resource class value
-    llvm::dxil::ResourceClass RC;
+    ResourceClass RC;
     if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Identifier, RC)) {
       Diag(ArgLoc, diag::warn_attribute_type_not_supported)
           << "ResourceClass" << Identifier;
@@ -750,28 +791,6 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
   return LocInfo;
 }
 
-struct RegisterBindingFlags {
-  bool Resource = false;
-  bool UDT = false;
-  bool Other = false;
-  bool Basic = false;
-
-  bool SRV = false;
-  bool UAV = false;
-  bool CBV = false;
-  bool Sampler = false;
-
-  bool ContainsNumeric = false;
-  bool DefaultGlobals = false;
-
-  // used only when Resource == true
-  std::optional<llvm::dxil::ResourceClass> ResourceClass;
-};
-
-static bool isDeclaredWithinCOrTBuffer(const Decl *TheDecl) {
-  return TheDecl && isa<HLSLBufferDecl>(TheDecl->getDeclContext());
-}
-
 // get the record decl from a var decl that we expect
 // represents a resource
 static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
@@ -786,24 +805,6 @@ static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
   return TheRecordDecl;
 }
 
-static void updateResourceClassFlagsFromDeclResourceClass(
-    RegisterBindingFlags &Flags, llvm::hlsl::ResourceClass DeclResourceClass) {
-  switch (DeclResourceClass) {
-  case llvm::hlsl::ResourceClass::SRV:
-    Flags.SRV = true;
-    break;
-  case llvm::hlsl::ResourceClass::UAV:
-    Flags.UAV = true;
-    break;
-  case llvm::hlsl::ResourceClass::CBuffer:
-    Flags.CBV = true;
-    break;
-  case llvm::hlsl::ResourceClass::Sampler:
-    Flags.Sampler = true;
-    break;
-  }
-}
-
 const HLSLAttributedResourceType *
 findAttributedResourceTypeOnField(VarDecl *VD) {
   assert(VD != nullptr && "expected VarDecl");
@@ -817,8 +818,10 @@ findAttributedResourceTypeOnField(VarDecl *VD) {
   return nullptr;
 }
 
-static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
-                                                   const RecordType *RT) {
+// Iterate over RecordType fields and return true if any of them matched the
+// register type
+static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
+                                            RegisterType RegType) {
   llvm::SmallVector<const Type *> TypesToScan;
   TypesToScan.emplace_back(RT);
 
@@ -827,8 +830,8 @@ static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
     while (T->isArrayType())
       T = T->getArrayElementTypeNoTypeQual();
     if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
-      Flags.ContainsNumeric = true;
-      continue;
+      if (RegType == RegisterType::C)
+        return true;
     }
     const RecordType *RT = T->getAs<RecordType>();
     if (!RT)
@@ -839,100 +842,84 @@ static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
       const Type *FieldTy = FD->getType().getTypePtr();
       if (const HLSLAttributedResourceType *AttrResType =
               dyn_cast<HLSLAttributedResourceType>(FieldTy)) {
-        updateResourceClassFlagsFromDeclResourceClass(
-            Flags, AttrResType->getAttrs().ResourceClass);
-        continue;
+        ResourceClass RC = AttrResType->getAttrs().ResourceClass;
+        if (getRegisterType(RC) == RegType)
+          return true;
+      } else {
+        TypesToScan.emplace_back(FD->getType().getTypePtr());
       }
-      TypesToScan.emplace_back(FD->getType().getTypePtr());
     }
   }
+  return false;
 }
 
-static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
-                                                         Decl *TheDecl) {
-  RegisterBindingFlags Flags;
+static void CheckContainsResourceForRegisterType(Sema &S,
+                                                 SourceLocation &ArgLoc,
+                                                 Decl *D, RegisterType RegType,
+                                                 bool SpecifiedSpace) {
+  int RegTypeNum = static_cast<int>(RegType);
 
   // check if the decl type is groupshared
-  if (TheDecl->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
-    Flags.Other = true;
-    return Flags;
+  if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
+    S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+    return;
   }
 
   // Cbuffers and Tbuffers are HLSLBufferDecl types
-  if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl)) {
-    Flags.Resource = true;
-    Flags.ResourceClass = CBufferOrTBuffer->isCBuffer()
-                              ? llvm::dxil::ResourceClass::CBuffer
-                              : llvm::dxil::ResourceClass::SRV;
+  if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
+    ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
+                                                     : ResourceClass::SRV;
+    if (RegType != getRegisterType(RC))
+      S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+          << RegTypeNum;
+    return;
   }
+
   // Samplers, UAVs, and SRVs are VarDecl types
-  else if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl)) {
-    if (const HLSLAttributedResourceType *AttrResType =
-            findAttributedResourceTypeOnField(TheVarDecl)) {
-      Flags.Resource = true;
-      Flags.ResourceClass = AttrResType->getAttrs().ResourceClass;
-    } else {
-      const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr();
-      while (TheBaseType->isArrayType())
-        TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
-
-      if (TheBaseType->isArithmeticType()) {
-        Flags.Basic = true;
-        if (!isDeclaredWithinCOrTBuffer(TheDecl) &&
-            (TheBaseType->isIntegralType(S.getASTContext()) ||
-             TheBaseType->isFloatingType()))
-          Flags.DefaultGlobals = true;
-      } else if (TheBaseType->isRecordType()) {
-        Flags.UDT = true;
-        const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
-        updateResourceClassFlagsFromRecordType(Flags, TheRecordTy);
-      } else
-        Flags.Other = true;
-    }
-  } else {
-    llvm_unreachable("expected be VarDecl or HLSLBufferDecl");
+  assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
+  VarDecl *VD = cast<VarDecl>(D);
+
+  // Resource
+  if (const HLSLAttributedResourceType *AttrResType =
+          findAttributedResourceTypeOnField(VD)) {
+    if (RegType != getRegisterType(AttrResType->getAttrs().ResourceClass))
+      S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+          << RegTypeNum;
+    return;
   }
-  return Flags;
-}
 
-enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
+  const clang::Type *Ty = VD->getType().getTypePtr();
+  while (Ty->isArrayType())
+    Ty = Ty->getArrayElementTypeNoTypeQual();
 
-static RegisterType getRegisterType(llvm::dxil::ResourceClass RC) {
-  switch (RC) {
-  case llvm::dxil::ResourceClass::SRV:
-    return RegisterType::SRV;
-  case llvm::dxil::ResourceClass::UAV:
-    return RegisterType::UAV;
-  case llvm::dxil::ResourceClass::CBuffer:
-    return RegisterType::CBuffer;
-  case llvm::dxil::ResourceClass::Sampler:
-    return RegisterType::Sampler;
-  }
-  llvm_unreachable("unexpected ResourceClass value");
-}
+  // Basic types
+  if (Ty->isArithmeticType()) {
+    bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext());
+    if (SpecifiedSpace && !DeclaredInCOrTBuffer)
+      S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
 
-static RegisterType getRegisterType(StringRef Slot) {
-  switch (Slot[0]) {
-  case 't':
-  case 'T':
-    return RegisterType::SRV;
-  case 'u':
-  case 'U':
-    return RegisterType::UAV;
-  case 'b':
-  case 'B':
-    return RegisterType::CBuffer;
-  case 's':
-  case 'S':
-    return RegisterType::Sampler;
-  case 'c':
-  case 'C':
-    return RegisterType::C;
-  case 'i':
-  case 'I':
-    return RegisterType::I;
-  default:
-    return RegisterType::Invalid;
+    if (!DeclaredInCOrTBuffer &&
+        (Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) {
+      // Default Globals
+      if (RegType == RegisterType::CBuffer)
+        S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
+      else if (RegType != RegisterType::C)
+        S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+    } else {
+      if (RegType == RegisterType::C)
+        S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
+      else
+        S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+    }
+  } else if (Ty->isRecordType()) {
+    // Class/struct types - walk the declaration and check each field and
+    // subclass
+    if (!ContainsResourceForRegisterType(S, Ty->getAs<RecordType>(), RegType))
+      S.Diag(D->getLocation(), diag::warn_hlsl_user_defined_type_missing_member)
+          << RegTypeNum;
+  } else {
+    // Anything else is an error
+    S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
   }
 }
 
@@ -969,76 +956,19 @@ static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
 }
 
 static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
-                                          Decl *TheDecl, RegisterType RegType,
-                                          const bool SpecifiedSpace) {
+                                          Decl *D, RegisterType RegType,
+                                          bool SpecifiedSpace) {
 
   // exactly one of these two types should be set
-  assert(((isa<VarDecl>(TheDecl) && !isa<HLSLBufferDecl>(TheDecl)) ||
-          (!isa<VarDecl>(TheDecl) && isa<HLSLBufferDecl>(TheDecl))) &&
+  assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
+          (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
          "expecting VarDecl or HLSLBufferDecl");
 
-  RegisterBindingFlags Flags = HLSLFillRegisterBindingFlags(S, TheDecl);
-  assert((int)Flags.Other + (int)Flags.Resource + (int)Flags.Basic +
-                 (int)Flags.UDT ==
-             1 &&
-         "only one resource analysis result should be expected");
-
-  int RegTypeNum = static_cast<int>(RegType);
-
-  // first, if "other" is set, emit an error
-  if (Flags.Other) {
-    S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
-    return;
-  }
+  // check if the declaration contains resource matching the register type
+  CheckContainsResourceForRegisterType(S, ArgLoc, D, RegType, SpecifiedSpace);
 
   // next, if multiple register annotations exist, check that none conflict.
-  ValidateMultipleRegisterAnnotations(S, TheDecl, RegType);
-
-  // next, if resource is set, make sure the register type in the register
-  // annotation is compatible with the variable's resource type.
-  if (Flags.Resource) {
-    RegisterType ExpRegType = getRegisterType(Flags.ResourceClass.value());
-    if (RegType != ExpRegType) {
-      S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch)
-          << RegTypeNum;
-    }
-
-    return;
-  }
-
-  // next, handle diagnostics for when the "basic" flag is set
-  if (Flags.Basic) {
-    if (SpecifiedSpace && !isDeclaredWithinCOrTBuffer(TheDecl))
-      S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);
-
-    if (Flags.DefaultGlobals) {
-      if (RegType == RegisterType::CBuffer)
-        S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
-      else if (RegType != RegisterType::C)
-        S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
-      return;
-    }
-
-    if (RegType == RegisterType::C)
-      S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
-    else
-      S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
-
-    return;
-  }
-
-  // finally, we handle the udt case
-  if (Flags.UDT) {
-    const bool ExpectedRegisterTypesForUDT[] = {
-        Flags.SRV, Flags.UAV, Flags.CBV, Flags.Sampler, Flags.ContainsNumeric};
-    assert((size_t)RegTypeNum < std::size(ExpectedRegisterTypesForUDT) &&
-           "regType has unexpected value");
-
-    if (!ExpectedRegisterTypesForUDT[RegTypeNum])
-      S.Diag(TheDecl->getLocation(),
-             diag::warn_hlsl_user_defined_type_missing_member)
-          << RegTypeNum;
-  }
+  ValidateMultipleRegisterAnnotations(S, D, RegType);
 }
 
 void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {


        
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to