python3kgae created this revision.
python3kgae added reviewers: pow2clk, beanz, bogner.
Herald added a subscriber: Anastasia.
Herald added a project: All.
python3kgae requested review of this revision.
Herald added a project: clang.
Herald added a subscriber: cfe-commits.

Only RWByteAddressBuffer::Load is supported in this PR.
Load is implemented by first casting the handle to a uint8_t pointer, then
adding the byte address to get the target address, and finally casting that
address to a pointer to the target type and dereferencing it:

struct RWByteAddressBuffer {
  void *h; // The resource handle.
  // Loads a T from byte offset i of the buffer.
  template <typename T> T Load(unsigned int i) {
    return *((T *)(((uint8_t *)h) + i));
  }
};

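Skip the HLSL checks that reject the address-of, dereference and '->'
operators when the operator's source location is invalid, so that the implicit
code built by HLSLExternalSemaSource (which has no source locations) can use
pointers.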

Also add HLSLResourceHelper, which shares the forward-declaration and
lazy-completion code between RWBuffer and RWByteAddressBuffer.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D136913

Files:
  clang/include/clang/Sema/HLSLExternalSemaSource.h
  clang/lib/Sema/HLSLExternalSemaSource.cpp
  clang/lib/Sema/SemaExpr.cpp
  clang/lib/Sema/SemaExprMember.cpp
  clang/test/AST/HLSL/RWByteAddressBuffer.hlsl
  clang/test/SemaHLSL/BuiltIns/RWByteAddressBuffer.hlsl

Index: clang/test/SemaHLSL/BuiltIns/RWByteAddressBuffer.hlsl
===================================================================
--- /dev/null
+++ clang/test/SemaHLSL/BuiltIns/RWByteAddressBuffer.hlsl
@@ -0,0 +1,16 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -fsyntax-only -verify %s
+
+// Sema checks for the implicitly declared hlsl::RWByteAddressBuffer type.
+
+typedef vector<float, 3> float3;
+
+RWByteAddressBuffer Buffer;
+
+[numthreads(1,1,1)]
+void main() {
+  // The handle member 'h' is private.
+  (void)Buffer.h; // expected-error {{'h' is a private member of 'hlsl::RWByteAddressBuffer'}}
+  // expected-note@* {{implicitly declared private here}}
+  // Load<T> is accessible; -verify fails on any unexpected diagnostic here.
+  Buffer.Load<float3>(0);
+}
Index: clang/test/AST/HLSL/RWByteAddressBuffer.hlsl
===================================================================
--- /dev/null
+++ clang/test/AST/HLSL/RWByteAddressBuffer.hlsl
@@ -0,0 +1,107 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -fsyntax-only -ast-dump %s | FileCheck %s
+
+// Checks the AST synthesized for the implicit RWByteAddressBuffer class and
+// for instantiations of its Load<T> member template (scalar, vector, struct).
+
+RWByteAddressBuffer U;
+typedef unsigned int uint;
+typedef vector<float, 2> float2;
+
+// CHECK:CXXRecordDecl 0x{{[0-9a-f]+}} <<invalid sloc>> <invalid sloc> implicit referenced <undeserialized declarations> class RWByteAddressBuffer definition
+// CHECK:FinalAttr 0x{{[0-9a-f]+}} <<invalid sloc>> Implicit final
+// CHECK-NEXT:HLSLResourceAttr 0x{{[0-9a-f]+}} <<invalid sloc>> Implicit UAV RawBuffer
+// CHECK-NEXT:-FieldDecl 0x[[HDL:[0-9a-f]+]] <<invalid sloc>> <invalid sloc> implicit referenced h 'void *'
+// CHECK-NEXT:CXXConstructorDecl 0x{{[0-9a-f]+}} <<invalid sloc>> <invalid sloc> used RWByteAddressBuffer 'void ()' inline
+// CHECK-NEXT:CompoundStmt 0x{{[0-9a-f]+}} <<invalid sloc>>
+// CHECK-NEXT:BinaryOperator 0x{{[0-9a-f]+}} <<invalid sloc>> 'void *' lvalue '='
+// CHECK-NEXT:MemberExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'void *' lvalue ->h 0x[[HDL]]
+// CHECK-NEXT:-CXXThisExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'hlsl::RWByteAddressBuffer *' implicit this
+// CHECK-NEXT:CallExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'void *'
+// CHECK-NEXT:-DeclRefExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'void *(unsigned char) throw()' Function 0x{{[0-9a-f]+}} '__builtin_hlsl_create_handle' 'void *(unsigned char) throw()'
+// CHECK-NEXT:-IntegerLiteral 0x{{[0-9a-f]+}} <<invalid sloc>> 'unsigned char' 1
+// CHECK-NEXT:FunctionTemplateDecl 0x{{[0-9a-f]+}} <<invalid sloc>> <invalid sloc> Load
+// CHECK-NEXT:TemplateTypeParmDecl 0x{{[0-9a-f]+}} <<invalid sloc>> <invalid sloc> class depth 0 index 0 T
+// CHECK-NEXT:CXXMethodDecl 0x{{[0-9a-f]+}} <<invalid sloc>> <invalid sloc> Load 'T (unsigned int)'
+// CHECK-NEXT:ParmVarDecl 0x[[LOAD_PARAM:[0-9a-f]+]] <<invalid sloc>> <invalid sloc> ByteAddress 'unsigned int'
+// CHECK-NEXT:CompoundStmt 0x{{[0-9a-f]+}} <<invalid sloc>>
+// CHECK-NEXT:ReturnStmt 0x{{[0-9a-f]+}} <<invalid sloc>>
+// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'T' <LValueToRValue>
+// CHECK-NEXT:UnaryOperator 0x{{[0-9a-f]+}} <<invalid sloc>> 'T' lvalue prefix '*' cannot overflow
+// CHECK-NEXT:CStyleCastExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'T *' lvalue <BitCast>
+// CHECK-NEXT:BinaryOperator 0x{{[0-9a-f]+}} <<invalid sloc>> 'unsigned char *' '+'
+// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'unsigned char *' <LValueToRValue>
+// CHECK-NEXT:CStyleCastExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'unsigned char *' lvalue <BitCast>
+// CHECK-NEXT:MemberExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'void *' lvalue ->h 0x[[HDL]]
+// CHECK-NEXT:CXXThisExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'hlsl::RWByteAddressBuffer *' implicit this
+// CHECK-NEXT:DeclRefExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'unsigned int' ParmVar 0x[[LOAD_PARAM]] 'ByteAddress' 'unsigned int'
+
+// CHECK:CXXMethodDecl 0x[[LOAD_F:[0-9a-f]+]] <<invalid sloc>> <invalid sloc> used Load 'float (unsigned int)'
+// CHECK-NEXT:TemplateArgument type 'float'
+// CHECK-NEXT:BuiltinType 0x{{[0-9a-f]+}} 'float'
+
+// CHECK:CXXMethodDecl 0x[[LOAD_F2:[0-9a-f]+]] <<invalid sloc>> <invalid sloc> used Load 'float (unsigned int) __attribute__((ext_vector_type(2)))'
+// CHECK-NEXT:TemplateArgument type 'float __attribute__((ext_vector_type(2)))'
+// CHECK-NEXT:ExtVectorType 0x{{[0-9a-f]+}} 'float __attribute__((ext_vector_type(2)))' 2
+// CHECK-NEXT:BuiltinType 0x{{[0-9a-f]+}} 'float'
+
+// CHECK:CXXMethodDecl 0x[[LOAD_S:[0-9a-f]+]] <<invalid sloc>> <invalid sloc> used Load 'S (unsigned int)'
+// CHECK-NEXT:TemplateArgument type 'S'
+// CHECK-NEXT:-RecordType 0x{{[0-9a-f]+}} 'S'
+// CHECK-NEXT:CXXRecord 0x[[RECORD:[0-9a-f]+]] 'S'
+
+// CHECK:VarDecl 0x[[BUF:[0-9a-f]+]] <{{.+}}:1, col:21> col:21 used U 'RWByteAddressBuffer':'hlsl::RWByteAddressBuffer'
+
+// CHECK:FunctionDecl 0x{{[0-9a-f]+}} <line:{{[0-9]+}}:1, line:{{[0-9]+}}:1> line:{{[0-9]+}}:7 foo 'float (uint)'
+// CHECK-NEXT:ParmVarDecl 0x[[FOO_PARAM:[0-9a-f]+]] <col:11, col:16> col:16 used i 'uint':'unsigned int'
+// CHECK-NEXT:CompoundStmt 0x{{[0-9a-f]+}} <col:19, line:{{[0-9]+}}:1>
+// CHECK-NEXT:ReturnStmt 0x{{[0-9a-f]+}} <line:{{[0-9]+}}:5, col:27>
+// CHECK-NEXT:CXXMemberCallExpr 0x{{[0-9a-f]+}} <col:12, col:27> 'float':'float'
+// CHECK-NEXT:MemberExpr 0x{{[0-9a-f]+}} <col:12, col:24> '<bound member function type>' .Load 0x[[LOAD_F]]
+// CHECK-NEXT:DeclRefExpr 0x{{[0-9a-f]+}} <col:12> 'RWByteAddressBuffer':'hlsl::RWByteAddressBuffer' lvalue Var 0x[[BUF]] 'U' 'RWByteAddressBuffer':'hlsl::RWByteAddressBuffer'
+// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} <col:26> 'uint':'unsigned int' <LValueToRValue>
+// CHECK-NEXT:DeclRefExpr 0x{{[0-9a-f]+}} <col:26> 'uint':'unsigned int' lvalue ParmVar 0x[[FOO_PARAM]] 'i' 'uint':'unsigned int'
+
+// CHECK:FunctionDecl 0x{{[0-9a-f]+}} <line:{{[0-9]+}}:1, line:{{[0-9]+}}:1> line:{{[0-9]+}}:8 foo2 'float2 (uint)'
+// CHECK-NEXT:ParmVarDecl 0x[[FOO2_PARAM:[0-9a-f]+]] <col:13, col:18> col:18 used i 'uint':'unsigned int'
+// CHECK-NEXT:CompoundStmt 0x{{[0-9a-f]+}} <col:21, line:{{[0-9]+}}:1>
+// CHECK-NEXT:ReturnStmt 0x{{[0-9a-f]+}} <line:{{[0-9]+}}:5, col:28>
+// CHECK-NEXT:CXXMemberCallExpr 0x{{[0-9a-f]+}} <col:12, col:28> 'float __attribute__((ext_vector_type(2)))':'float __attribute__((ext_vector_type(2)))'
+// CHECK-NEXT:MemberExpr 0x{{[0-9a-f]+}} <col:12, col:25> '<bound member function type>' .Load 0x[[LOAD_F2]]
+// CHECK-NEXT:-DeclRefExpr 0x{{[0-9a-f]+}} <col:12> 'RWByteAddressBuffer':'hlsl::RWByteAddressBuffer' lvalue Var 0x[[BUF]] 'U' 'RWByteAddressBuffer':'hlsl::RWByteAddressBuffer'
+// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} <col:27> 'uint':'unsigned int' <LValueToRValue>
+// CHECK-NEXT:DeclRefExpr 0x{{[0-9a-f]+}} <col:27> 'uint':'unsigned int' lvalue ParmVar 0x[[FOO2_PARAM]] 'i' 'uint':'unsigned int'
+
+// CHECK:CXXRecordDecl 0x[[RECORD]] <line:{{[0-9]+}}:1, line:{{[0-9]+}}:1> line:{{[0-9]+}}:8 referenced struct S definition
+
+// CHECK:CXXRecordDecl 0x{{[0-9a-f]+}} <col:1, col:8> col:8 implicit struct S
+// CHECK-NEXT:FieldDecl 0x{{[0-9a-f]+}} <line:{{[0-9]+}}:5, col:11> col:11 referenced a 'float'
+
+// CHECK:FunctionDecl 0x{{[0-9a-f]+}} <line:{{[0-9]+}}:1, line:{{[0-9]+}}:1> line:{{[0-9]+}}:3 foo3 'S (uint)'
+// CHECK-NEXT:ParmVarDecl 0x[[FOO3_PARAM:[0-9a-f]+]] <col:8, col:13> col:13 used i 'uint':'unsigned int'
+// CHECK-NEXT:CompoundStmt 0x{{[0-9a-f]+}} <col:16, line:{{[0-9]+}}:1>
+// CHECK-NEXT:ReturnStmt 0x{{[0-9a-f]+}} <line:{{[0-9]+}}:5, col:23>
+// CHECK-NEXT:ExprWithCleanups 0x{{[0-9a-f]+}} <col:12, col:23> 'S':'S'
+// CHECK-NEXT:CXXConstructExpr 0x{{[0-9a-f]+}} <col:12, col:23> 'S':'S' 'void (const S &) throw()' elidable
+// CHECK-NEXT:MaterializeTemporaryExpr 0x{{[0-9a-f]+}} <col:12, col:23> 'const S':'const S' lvalue
+// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} <col:12, col:23> 'const S':'const S' <NoOp>
+// CHECK-NEXT:CXXMemberCallExpr 0x{{[0-9a-f]+}} <col:12, col:23> 'S':'S'
+// CHECK-NEXT:MemberExpr 0x{{[0-9a-f]+}} <col:12, col:20> '<bound member function type>' .Load 0x[[LOAD_S]]
+// CHECK-NEXT:DeclRefExpr 0x{{[0-9a-f]+}} <col:12> 'RWByteAddressBuffer':'hlsl::RWByteAddressBuffer' lvalue Var 0x[[BUF]] 'U' 'RWByteAddressBuffer':'hlsl::RWByteAddressBuffer'
+// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} <col:22> 'uint':'unsigned int' <LValueToRValue>
+// CHECK-NEXT:DeclRefExpr 0x{{[0-9a-f]+}} <col:22> 'uint':'unsigned int' lvalue ParmVar 0x[[FOO3_PARAM]] 'i' 'uint':'unsigned int'
+
+float foo(uint i) {
+    return U.Load<float>(i);
+}
+
+float2 foo2(uint i) {
+    return U.Load<float2>(i);
+}
+
+struct S {
+    float a;
+};
+
+S foo3(uint i) {
+    return U.Load<S>(i);
+}
Index: clang/lib/Sema/SemaExprMember.cpp
===================================================================
--- clang/lib/Sema/SemaExprMember.cpp
+++ clang/lib/Sema/SemaExprMember.cpp
@@ -1736,7 +1736,10 @@
   DeclarationName Name = NameInfo.getName();
   bool IsArrow = (OpKind == tok::arrow);
 
-  if (getLangOpts().HLSL && IsArrow)
+  // Only diagnose '->' with a valid source location. Implicit code built by
+  // HLSLExternalSemaSource has none (presumably reached here when its Load<T>
+  // template body is instantiated).
+  if (getLangOpts().HLSL && IsArrow && OpLoc.isValid())
     return ExprError(Diag(OpLoc, diag::err_hlsl_operator_unsupported) << 2);
 
   NamedDecl *FirstQualifierInScope
Index: clang/lib/Sema/SemaExpr.cpp
===================================================================
--- clang/lib/Sema/SemaExpr.cpp
+++ clang/lib/Sema/SemaExpr.cpp
@@ -15587,7 +15587,10 @@
     }
   }
 
-  if (getLangOpts().HLSL) {
+  // Only apply the HLSL operator restrictions below to operators with a valid
+  // source location. Implicit code built by HLSLExternalSemaSource has none
+  // (presumably reached here when its Load<T> template body is instantiated).
+  if (getLangOpts().HLSL && OpLoc.isValid()) {
     if (Opc == UO_AddrOf)
       return ExprError(Diag(OpLoc, diag::err_hlsl_operator_unsupported) << 0);
     if (Opc == UO_Deref)
Index: clang/lib/Sema/HLSLExternalSemaSource.cpp
===================================================================
--- clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -34,13 +34,16 @@
   ClassTemplateDecl *PrevTemplate = nullptr;
   NamespaceDecl *HLSLNamespace = nullptr;
   llvm::StringMap<FieldDecl *> Fields;
 
   BuiltinTypeDeclBuilder(CXXRecordDecl *R) : Record(R) {
     Record->startDefinition();
     Template = Record->getDescribedClassTemplate();
   }
 
-  BuiltinTypeDeclBuilder(Sema &S, NamespaceDecl *Namespace, StringRef Name)
+  /// \p DelayTypeCreation is forwarded to CXXRecordDecl::Create when a new
+  /// record is created; if true, the record's type is not created there.
+  BuiltinTypeDeclBuilder(Sema &S, NamespaceDecl *Namespace, StringRef Name,
+                         bool DelayTypeCreation)
       : HLSLNamespace(Namespace) {
     ASTContext &AST = S.getASTContext();
     IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
@@ -62,9 +64,9 @@
       return;
     }
 
-    Record = CXXRecordDecl::Create(AST, TagDecl::TagKind::TTK_Class,
-                                   HLSLNamespace, SourceLocation(),
-                                   SourceLocation(), &II, PrevDecl, true);
+    Record = CXXRecordDecl::Create(
+        AST, TagDecl::TagKind::TTK_Class, HLSLNamespace, SourceLocation(),
+        SourceLocation(), &II, PrevDecl, DelayTypeCreation);
     Record->setImplicit(true);
     Record->setLexicalDeclContext(HLSLNamespace);
     Record->setHasExternalLexicalStorage();
@@ -289,6 +291,152 @@
     return *this;
   }
 
+  /// Add a public member function template `T Load<T>(unsigned int)` to the
+  /// record. Its body offsets the void * handle member "h" by the given byte
+  /// address and loads a T from the resulting address. Does nothing if the
+  /// record is already complete; the handle member must be added first.
+  /// The generated method is marked always_inline.
+  BuiltinTypeDeclBuilder &addByteAddressBufferLoad() {
+    if (Record->isCompleteDefinition())
+      return *this;
+    assert(Fields.count("h") > 0 && "Load must be added after the handle.");
+
+    // Implement Load by casting the handle to a uint8_t pointer, adding the
+    // byte address to get the target address, then casting that address to
+    // a T pointer and dereferencing it:
+    // struct RWByteAddressBuffer {
+    //   void *h;
+    //   template<typename T> T Load(uint i) {
+    //      return *((T*)(((uint8_t*)h) + i));
+    //   }
+    // };
+
+    FieldDecl *Handle = Fields["h"];
+    ASTContext &AST = Record->getASTContext();
+
+    assert(Handle->getType().getCanonicalType() == AST.VoidPtrTy &&
+           "ByteAddressBuffer use void pointer handles.");
+
+    // Create the declaration for the template parameter.
+    TemplateTypeParmDecl *TypeParmDecl = TemplateTypeParmDecl::Create(
+        AST, Record->getDeclContext(), SourceLocation(), SourceLocation(),
+        /* TemplateDepth */ 0, /*Position*/ 0,
+        &AST.Idents.get("T", tok::TokenKind::identifier),
+        /* Typename */ false,
+        /* ParameterPack */ false);
+    // Create the type that the parameter represents.
+    QualType ReturnTy = AST.getTemplateTypeParmType(
+        /* TemplateDepth */ 0, /*Position*/ 0,
+        /* ParameterPack */ false, TypeParmDecl);
+
+    // Method type: T (unsigned int).
+    FunctionProtoType::ExtProtoInfo ExtInfo;
+    QualType MethodTy =
+        AST.getFunctionType(ReturnTy, {AST.UnsignedIntTy}, ExtInfo);
+    auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
+    auto *MethodDecl = CXXMethodDecl::Create(
+        AST, Record, SourceLocation(),
+        DeclarationNameInfo(DeclarationName(&AST.Idents.get(
+                                "Load", tok::TokenKind::identifier)),
+                            SourceLocation()),
+        MethodTy, TSInfo, SC_None, false, false, ConstexprSpecKind::Unspecified,
+        SourceLocation());
+
+    IdentifierInfo &II =
+        AST.Idents.get("ByteAddress", tok::TokenKind::identifier);
+    auto *IdxParam = ParmVarDecl::Create(
+        AST, MethodDecl, SourceLocation(), SourceLocation(),
+        &II, AST.UnsignedIntTy,
+        AST.getTrivialTypeSourceInfo(AST.UnsignedIntTy, SourceLocation()),
+        SC_None, nullptr);
+    MethodDecl->setParams({IdxParam});
+
+    // Also add the parameter to the function prototype.
+    auto FnProtoLoc = TSInfo->getTypeLoc().getAs<FunctionProtoTypeLoc>();
+    FnProtoLoc.setParam(0, IdxParam);
+
+    // Build the body: return *(T *)((uint8_t *)this->h + ByteAddress);
+    auto *This = new (AST)
+        CXXThisExpr(SourceLocation(), MethodDecl->getThisType(), true);
+    auto *HandleAccess = MemberExpr::CreateImplicit(
+        AST, This, true, Handle, Handle->getType(), VK_LValue, OK_Ordinary);
+
+    QualType U8Ty = AST.getIntTypeForBitwidth(8, false);
+    TypeSourceInfo *U8PtrTy =
+        AST.getTrivialTypeSourceInfo(
+            AST.getPointerType(U8Ty));
+    
+    // NOTE(review): a cast to a non-reference type is a prvalue, and its
+    // operand would normally be HandleAccess after an LValueToRValue cast.
+    // Here the cast is marked VK_LValue and wraps the lvalue MemberExpr
+    // directly; the LValueToRValue cast is applied to its result instead.
+    auto *ByteBasePtr = CStyleCastExpr::Create(AST, U8PtrTy->getType(), ExprValueKind::VK_LValue,
+                           CastKind::CK_BitCast, HandleAccess, nullptr,
+                           FPOptionsOverride(), U8PtrTy, SourceLocation(),
+                           SourceLocation());
+    // NOTE(review): a DeclRefExpr naming a parameter is normally an lvalue
+    // followed by an LValueToRValue cast; it is built as a prvalue here.
+    auto *IndexExpr = DeclRefExpr::Create(
+        AST, NestedNameSpecifierLoc(), SourceLocation(), IdxParam, false,
+        DeclarationNameInfo(IdxParam->getDeclName(), SourceLocation()),
+        AST.UnsignedIntTy, VK_PRValue);
+
+    auto *ByteBasePtrRV = ImplicitCastExpr::Create(
+        AST, U8PtrTy->getType(), CastKind::CK_LValueToRValue,
+                             ByteBasePtr,
+                             nullptr, ExprValueKind::VK_PRValue,
+                             FPOptionsOverride());
+    auto *BytePtr = BinaryOperator::Create(
+        AST, ByteBasePtrRV, IndexExpr, BinaryOperator::Opcode::BO_Add,
+        U8PtrTy->getType(), ExprValueKind::VK_PRValue,
+        ExprObjectKind::OK_Ordinary, SourceLocation(), FPOptionsOverride());
+
+    TypeSourceInfo *ResultPtrTy =
+        AST.getTrivialTypeSourceInfo(
+            AST.getPointerType(ReturnTy));
+
+    // NOTE(review): as with ByteBasePtr, this pointer cast is marked VK_LValue
+    // although it yields a prvalue. Correcting the value categories changes
+    // the AST checked by clang/test/AST/HLSL/RWByteAddressBuffer.hlsl.
+    auto *ResultPtr = CStyleCastExpr::Create(
+        AST, ResultPtrTy->getType(), ExprValueKind::VK_LValue,
+        CastKind::CK_BitCast, BytePtr, nullptr, FPOptionsOverride(),
+        ResultPtrTy, SourceLocation(), SourceLocation());
+    auto *Deref = UnaryOperator::Create(AST, ResultPtr, UnaryOperator::Opcode::UO_Deref,
+                          ReturnTy, ExprValueKind::VK_LValue,
+                          ExprObjectKind::OK_Ordinary, SourceLocation(),
+                          /*CanOverflow*/ false, FPOptionsOverride());
+    auto *Result = ImplicitCastExpr::Create(
+        AST, ReturnTy, CastKind::CK_LValueToRValue, Deref, nullptr,
+        ExprValueKind::VK_PRValue, FPOptionsOverride());
+
+    auto *Return = ReturnStmt::Create(AST, SourceLocation(), Result, nullptr);
+
+    MethodDecl->setBody(CompoundStmt::Create(AST, {Return}, FPOptionsOverride(),
+                                             SourceLocation(),
+                                             SourceLocation()));
+    MethodDecl->setLexicalDeclContext(Record);
+    MethodDecl->setAccess(AccessSpecifier::AS_public);
+    // NOTE(review): AS_Keyword is paired with the CXX11_clang_always_inline
+    // spelling here; confirm the syntax/spelling combination is intended.
+    MethodDecl->addAttr(AlwaysInlineAttr::CreateImplicit(
+        AST, SourceRange(), AttributeCommonInfo::AS_Keyword,
+        AlwaysInlineAttr::CXX11_clang_always_inline));
+    // Wrap the method in a FunctionTemplateDecl over T; each Load<T> call is
+    // then instantiated for its own T.
+    TemplateParameterList *ParamList = TemplateParameterList::Create(
+        AST, SourceLocation(), SourceLocation(), {TypeParmDecl},
+        SourceLocation(), nullptr);
+    FunctionTemplateDecl *FunctionTemplate = FunctionTemplateDecl::Create(
+        AST, Record, SourceLocation(), MethodDecl->getDeclName(),
+        ParamList, MethodDecl);
+    FunctionTemplate->setAccess(AccessSpecifier::AS_public);
+    FunctionTemplate->setLexicalDeclContext(Record);
+    MethodDecl->setDescribedFunctionTemplate(FunctionTemplate);
+    Record->addDecl(FunctionTemplate);
+    return *this;
+  }
+
   BuiltinTypeDeclBuilder &startDefinition() {
     if (Record->isCompleteDefinition())
       return *this;
@@ -367,6 +495,108 @@
 }
 } // namespace
 
+namespace clang {
+namespace hlsl {
+/// Forward-declares the implicit HLSL resource types (RWBuffer and
+/// RWByteAddressBuffer) and completes their definitions on demand.
+///
+/// A completion callback is recorded for every forward declaration that is
+/// not yet complete; completeType() looks it up and runs it.
+class HLSLResourceHelper {
+  Sema *SemaPtr = nullptr;
+  NamespaceDecl *HLSLNamespace = nullptr;
+  using CompletionFunction = std::function<void(CXXRecordDecl *)>;
+  /// Completion callbacks keyed by canonical declaration.
+  llvm::DenseMap<CXXRecordDecl *, CompletionFunction> Completions;
+
+  /// Record \p Completion for \p Decl unless \p Decl is already complete.
+  void declType(CXXRecordDecl *Decl, CompletionFunction Completion) {
+    if (!Decl->isCompleteDefinition())
+      Completions.try_emplace(Decl->getCanonicalDecl(), std::move(Completion));
+  }
+
+public:
+  HLSLResourceHelper(Sema *SemaPtr, NamespaceDecl *HLSLNamespace)
+      : SemaPtr(SemaPtr), HLSLNamespace(HLSLNamespace) {}
+
+  /// Run the completion recorded for \p Record, if any. Class template
+  /// specializations are mapped to their templated declaration first.
+  void completeType(CXXRecordDecl *Record) {
+    if (auto *TDecl = dyn_cast<ClassTemplateSpecializationDecl>(Record))
+      Record = TDecl->getSpecializedTemplate()->getTemplatedDecl();
+    Record = Record->getCanonicalDecl();
+    auto It = Completions.find(Record);
+    if (It == Completions.end())
+      return;
+    It->second(Record);
+  }
+
+  /// Forward-declare every supported resource type.
+  void declResourceTypes() {
+    declRWBuffer();
+    declRWByteAddressBuffer();
+  }
+
+private:
+  void declRWBuffer();
+  void declRWByteAddressBuffer();
+  CXXRecordDecl *forwardDeclRWBuffer();
+  void completeRWBuffer(CXXRecordDecl *Record);
+  CXXRecordDecl *forwardDeclRWByteAddressBuffer();
+  void completeRWByteAddressBuffer(CXXRecordDecl *Record);
+};
+
+void HLSLResourceHelper::declRWBuffer() {
+  declType(forwardDeclRWBuffer(),
+           [this](CXXRecordDecl *Record) { completeRWBuffer(Record); });
+}
+
+CXXRecordDecl *HLSLResourceHelper::forwardDeclRWBuffer() {
+  return BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWBuffer",
+                                /*DelayTypeCreation=*/true)
+      .addTemplateArgumentList()
+      .addTypeParameter("element_type", SemaPtr->getASTContext().FloatTy)
+      .finalizeTemplateArgs()
+      .annotateResourceClass(HLSLResourceAttr::UAV,
+                             HLSLResourceAttr::TypedBuffer)
+      .Record;
+}
+
+void HLSLResourceHelper::completeRWBuffer(CXXRecordDecl *Record) {
+  BuiltinTypeDeclBuilder(Record)
+      .addHandleMember()
+      .addDefaultHandleConstructor(*SemaPtr, ResourceClass::UAV)
+      .addArraySubscriptOperators()
+      .completeDefinition();
+}
+
+void HLSLResourceHelper::declRWByteAddressBuffer() {
+  declType(forwardDeclRWByteAddressBuffer(), [this](CXXRecordDecl *Record) {
+    completeRWByteAddressBuffer(Record);
+  });
+}
+
+CXXRecordDecl *HLSLResourceHelper::forwardDeclRWByteAddressBuffer() {
+  return BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWByteAddressBuffer",
+                                /*DelayTypeCreation=*/false)
+      .annotateResourceClass(HLSLResourceAttr::UAV,
+                             HLSLResourceAttr::RawBuffer)
+      .Record;
+}
+
+void HLSLResourceHelper::completeRWByteAddressBuffer(CXXRecordDecl *Record) {
+  BuiltinTypeDeclBuilder(Record)
+      .addHandleMember()
+      .addDefaultHandleConstructor(*SemaPtr, ResourceClass::UAV)
+      .addByteAddressBufferLoad()
+      .completeDefinition();
+}
+} // namespace hlsl
+} // namespace clang
+
+// Defined out of line because ResourceHelper is a std::unique_ptr to
+// HLSLResourceHelper, which is only a complete type in this file.
+HLSLExternalSemaSource::HLSLExternalSemaSource() = default;
 HLSLExternalSemaSource::~HLSLExternalSemaSource() {}
 
 void HLSLExternalSemaSource::InitializeSema(Sema &S) {
@@ -388,6 +614,11 @@
   HLSLNamespace->setHasExternalLexicalStorage();
   AST.getTranslationUnitDecl()->addDecl(HLSLNamespace);
 
+  // Create the helper that declares and lazily completes the HLSL resource
+  // types inside HLSLNamespace.
+  ResourceHelper =
+      std::make_unique<hlsl::HLSLResourceHelper>(SemaPtr, HLSLNamespace);
+
   // Force external decls in the HLSL namespace to load from the PCH.
   (void)HLSLNamespace->getCanonicalDecl()->decls_begin();
   defineTrivialHLSLTypes();
@@ -461,7 +691,8 @@
 void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
   defineHLSLVectorAlias();
 
-  ResourceDecl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "Resource")
+  ResourceDecl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "Resource",
+                                        /*DelayTypeCreation=*/true)
                      .startDefinition()
                      .addHandleMember(AccessSpecifier::AS_public)
                      .completeDefinition()
@@ -469,41 +699,14 @@
 }
 
 void HLSLExternalSemaSource::forwardDeclareHLSLTypes() {
-  CXXRecordDecl *Decl;
-  Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWBuffer")
-             .addTemplateArgumentList()
-             .addTypeParameter("element_type", SemaPtr->getASTContext().FloatTy)
-             .finalizeTemplateArgs()
-             .Record;
-  if (!Decl->isCompleteDefinition())
-    Completions.insert(
-        std::make_pair(Decl->getCanonicalDecl(),
-                       std::bind(&HLSLExternalSemaSource::completeBufferType,
-                                 this, std::placeholders::_1)));
+  // Resource types are declared, and later completed, by ResourceHelper.
+  ResourceHelper->declResourceTypes();
 }
 
 void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) {
   if (!isa<CXXRecordDecl>(Tag))
     return;
   auto Record = cast<CXXRecordDecl>(Tag);
-
-  // If this is a specialization, we need to get the underlying templated
-  // declaration and complete that.
-  if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(Record))
-    Record = TDecl->getSpecializedTemplate()->getTemplatedDecl();
-  Record = Record->getCanonicalDecl();
-  auto It = Completions.find(Record);
-  if (It == Completions.end())
-    return;
-  It->second(Record);
-}
-
-void HLSLExternalSemaSource::completeBufferType(CXXRecordDecl *Record) {
-  BuiltinTypeDeclBuilder(Record)
-      .addHandleMember()
-      .addDefaultHandleConstructor(*SemaPtr, ResourceClass::UAV)
-      .addArraySubscriptOperators()
-      .annotateResourceClass(HLSLResourceAttr::UAV,
-                             HLSLResourceAttr::TypedBuffer)
-      .completeDefinition();
+  // Records without a recorded completion are left untouched.
+  ResourceHelper->completeType(Record);
 }
Index: clang/include/clang/Sema/HLSLExternalSemaSource.h
===================================================================
--- clang/include/clang/Sema/HLSLExternalSemaSource.h
+++ clang/include/clang/Sema/HLSLExternalSemaSource.h
@@ -15,26 +15,29 @@
 #include "llvm/ADT/DenseMap.h"
 
 #include "clang/Sema/ExternalSemaSource.h"
+#include <memory>
 
 namespace clang {
 class NamespaceDecl;
 class Sema;
+namespace hlsl {
+class HLSLResourceHelper;
+} // namespace hlsl
 
 class HLSLExternalSemaSource : public ExternalSemaSource {
   Sema *SemaPtr = nullptr;
   NamespaceDecl *HLSLNamespace = nullptr;
   CXXRecordDecl *ResourceDecl;
-
-  using CompletionFunction = std::function<void(CXXRecordDecl *)>;
-  llvm::DenseMap<CXXRecordDecl *, CompletionFunction> Completions;
+  /// Declares the HLSL resource types and completes them on demand. Held by
+  /// std::unique_ptr so HLSLResourceHelper can stay incomplete in this header.
+  std::unique_ptr<hlsl::HLSLResourceHelper> ResourceHelper;
 
   void defineHLSLVectorAlias();
   void defineTrivialHLSLTypes();
   void forwardDeclareHLSLTypes();
 
-  void completeBufferType(CXXRecordDecl *Record);
-
 public:
+  HLSLExternalSemaSource();
   ~HLSLExternalSemaSource() override;
 
   /// Initialize the semantic source with the Sema instance
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to