python3kgae created this revision. python3kgae added reviewers: pow2clk, beanz, bogner. Herald added a subscriber: Anastasia. Herald added a project: All. python3kgae requested review of this revision. Herald added a project: clang. Herald added a subscriber: cfe-commits.
Only RWByteAddressBuffer::Load is supported in this patch. Load is implemented by casting the handle to a uint8_t pointer, adding the byte address to get the target address, and then casting the target address to a pointer to the target type and dereferencing it: struct RWByteAddressBuffer { void *handle; template<typename T> T Load(uint i) { return *((T*)(((uint8_t*)handle) + i)); } }; To let HLSLExternalSemaSource build these pointer expressions, the HLSL checks that reject address-of, dereference and '->' are now skipped when the operator location is invalid, as it is for the implicitly generated resource methods. Also add HLSLResourceHelper to share code between RWByteAddressBuffer and RWBuffer. Repository: rG LLVM Github Monorepo https://reviews.llvm.org/D136913 Files: clang/include/clang/Sema/HLSLExternalSemaSource.h clang/lib/Sema/HLSLExternalSemaSource.cpp clang/lib/Sema/SemaExpr.cpp clang/lib/Sema/SemaExprMember.cpp clang/test/AST/HLSL/RWByteAddressBuffer.hlsl clang/test/SemaHLSL/BuiltIns/RWByteAddressBuffer.hlsl
Index: clang/test/SemaHLSL/BuiltIns/RWByteAddressBuffer.hlsl =================================================================== --- /dev/null +++ clang/test/SemaHLSL/BuiltIns/RWByteAddressBuffer.hlsl @@ -0,0 +1,12 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -fsyntax-only -verify %s + +typedef vector<float, 3> float3; + +RWByteAddressBuffer Buffer; + +[numthreads(1,1,1)] +void main() { + (void)Buffer.h; // expected-error {{'h' is a private member of 'hlsl::RWByteAddressBuffer'}} + // expected-note@* {{implicitly declared private here}} + Buffer.Load<float3>(0); +} Index: clang/test/AST/HLSL/RWByteAddressBuffer.hlsl =================================================================== --- /dev/null +++ clang/test/AST/HLSL/RWByteAddressBuffer.hlsl @@ -0,0 +1,104 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -fsyntax-only -ast-dump %s | FileCheck %s + +RWByteAddressBuffer U; +typedef unsigned int uint; +typedef vector<float, 2> float2; + +// CHECK:CXXRecordDecl 0x{{[0-9a-f]+}} <<invalid sloc>> <invalid sloc> implicit referenced <undeserialized declarations> class RWByteAddressBuffer definition +// CHECK:FinalAttr 0x{{[0-9a-f]+}} <<invalid sloc>> Implicit final +// CHECK-NEXT:HLSLResourceAttr 0x{{[0-9a-f]+}} <<invalid sloc>> Implicit UAV RawBuffer +// CHECK-NEXT:-FieldDecl 0x[[HDL:[0-9a-f]+]] <<invalid sloc>> <invalid sloc> implicit referenced h 'void *' +// CHECK-NEXT:CXXConstructorDecl 0x{{[0-9a-f]+}} <<invalid sloc>> <invalid sloc> used RWByteAddressBuffer 'void ()' inline +// CHECK-NEXT:CompoundStmt 0x{{[0-9a-f]+}} <<invalid sloc>> +// CHECK-NEXT:BinaryOperator 0x{{[0-9a-f]+}} <<invalid sloc>> 'void *' lvalue '=' +// CHECK-NEXT:MemberExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'void *' lvalue ->h 0x[[HDL]] +// CHECK-NEXT:-CXXThisExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'hlsl::RWByteAddressBuffer *' implicit this +// CHECK-NEXT:CallExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'void *' +// CHECK-NEXT:-DeclRefExpr 0x{{[0-9a-f]+}} 
<<invalid sloc>> 'void *(unsigned char) throw()' Function 0x{{[0-9a-f]+}} '__builtin_hlsl_create_handle' 'void *(unsigned char) throw()' +// CHECK-NEXT:-IntegerLiteral 0x{{[0-9a-f]+}} <<invalid sloc>> 'unsigned char' 1 +// CHECK-NEXT:FunctionTemplateDecl 0x{{[0-9a-f]+}} <<invalid sloc>> <invalid sloc> Load +// CHECK-NEXT:TemplateTypeParmDecl 0x{{[0-9a-f]+}} <<invalid sloc>> <invalid sloc> class depth 0 index 0 T +// CHECK-NEXT:CXXMethodDecl 0x{{[0-9a-f]+}} <<invalid sloc>> <invalid sloc> Load 'T (unsigned int)' +// CHECK-NEXT:ParmVarDecl 0x[[LOAD_PARAM:[0-9a-f]+]] <<invalid sloc>> <invalid sloc> ByteAddress 'unsigned int' +// CHECK-NEXT:CompoundStmt 0x{{[0-9a-f]+}} <<invalid sloc>> +// CHECK-NEXT:ReturnStmt 0x{{[0-9a-f]+}} <<invalid sloc>> +// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'T' <LValueToRValue> +// CHECK-NEXT:UnaryOperator 0x{{[0-9a-f]+}} <<invalid sloc>> 'T' lvalue prefix '*' cannot overflow +// CHECK-NEXT:CStyleCastExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'T *' lvalue <BitCast> +// CHECK-NEXT:BinaryOperator 0x{{[0-9a-f]+}} <<invalid sloc>> 'unsigned char *' '+' +// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'unsigned char *' <LValueToRValue> +// CHECK-NEXT:CStyleCastExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'unsigned char *' lvalue <BitCast> +// CHECK-NEXT:MemberExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'void *' lvalue ->h 0x[[HDL]] +// CHECK-NEXT:CXXThisExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'hlsl::RWByteAddressBuffer *' implicit this +// CHECK-NEXT:DeclRefExpr 0x{{[0-9a-f]+}} <<invalid sloc>> 'unsigned int' ParmVar 0x[[LOAD_PARAM]] 'ByteAddress' 'unsigned int' + +// CHECK:CXXMethodDecl 0x[[LOAD_F:[0-9a-f]+]] <<invalid sloc>> <invalid sloc> used Load 'float (unsigned int)' +// CHECK-NEXT:TemplateArgument type 'float' +// CHECK-NEXT:BuiltinType 0x{{[0-9a-f]+}} 'float' + +// CHECK:CXXMethodDecl 0x[[LOAD_F2:[0-9a-f]+]] <<invalid sloc>> <invalid sloc> used Load 'float (unsigned int) __attribute__((ext_vector_type(2)))' +// 
CHECK-NEXT:TemplateArgument type 'float __attribute__((ext_vector_type(2)))' +// CHECK-NEXT:ExtVectorType 0x{{[0-9a-f]+}} 'float __attribute__((ext_vector_type(2)))' 2 +// CHECK-NEXT:BuiltinType 0x{{[0-9a-f]+}} 'float' + +// CHECK:CXXMethodDecl 0x[[LOAD_S:[0-9a-f]+]] <<invalid sloc>> <invalid sloc> used Load 'S (unsigned int)' +// CHECK-NEXT:TemplateArgument type 'S' +// CHECK-NEXT:-RecordType 0x{{[0-9a-f]+}} 'S' +// CHECK-NEXT:CXXRecord 0x[[RECORD:[0-9a-f]+]] 'S' + +// CHECK:VarDecl 0x[[BUF:[0-9a-f]+]] <{{.+}}:1, col:21> col:21 used U 'RWByteAddressBuffer':'hlsl::RWByteAddressBuffer' + +// CHECK:FunctionDecl 0x{{[0-9a-f]+}} <line:{{[0-9]+}}:1, line:{{[0-9]+}}:1> line:{{[0-9]+}}:7 foo 'float (uint)' +// CHECK-NEXT:ParmVarDecl 0x[[FOO_PARAM:[0-9a-f]+]] <col:11, col:16> col:16 used i 'uint':'unsigned int' +// CHECK-NEXT:CompoundStmt 0x{{[0-9a-f]+}} <col:19, line:{{[0-9]+}}:1> +// CHECK-NEXT:ReturnStmt 0x{{[0-9a-f]+}} <line:{{[0-9]+}}:5, col:27> +// CHECK-NEXT:CXXMemberCallExpr 0x{{[0-9a-f]+}} <col:12, col:27> 'float':'float' +// CHECK-NEXT:MemberExpr 0x{{[0-9a-f]+}} <col:12, col:24> '<bound member function type>' .Load 0x[[LOAD_F]] +// CHECK-NEXT:DeclRefExpr 0x{{[0-9a-f]+}} <col:12> 'RWByteAddressBuffer':'hlsl::RWByteAddressBuffer' lvalue Var 0x[[BUF]] 'U' 'RWByteAddressBuffer':'hlsl::RWByteAddressBuffer' +// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} <col:26> 'uint':'unsigned int' <LValueToRValue> +// CHECK-NEXT:DeclRefExpr 0x{{[0-9a-f]+}} <col:26> 'uint':'unsigned int' lvalue ParmVar 0x[[FOO_PARAM]] 'i' 'uint':'unsigned int' + +// CHECK:FunctionDecl 0x{{[0-9a-f]+}} <line:{{[0-9]+}}:1, line:{{[0-9]+}}:1> line:{{[0-9]+}}:8 foo2 'float2 (uint)' +// CHECK-NEXT:ParmVarDecl 0x[[FOO2_PARAM:[0-9a-f]+]] <col:13, col:18> col:18 used i 'uint':'unsigned int' +// CHECK-NEXT:CompoundStmt 0x{{[0-9a-f]+}} <col:21, line:{{[0-9]+}}:1> +// CHECK-NEXT:ReturnStmt 0x{{[0-9a-f]+}} <line:{{[0-9]+}}:5, col:28> +// CHECK-NEXT:CXXMemberCallExpr 0x{{[0-9a-f]+}} <col:12, col:28> 'float 
__attribute__((ext_vector_type(2)))':'float __attribute__((ext_vector_type(2)))' +// CHECK-NEXT:MemberExpr 0x{{[0-9a-f]+}} <col:12, col:25> '<bound member function type>' .Load 0x[[LOAD_F2]] +// CHECK-NEXT:-DeclRefExpr 0x{{[0-9a-f]+}} <col:12> 'RWByteAddressBuffer':'hlsl::RWByteAddressBuffer' lvalue Var 0x[[BUF]] 'U' 'RWByteAddressBuffer':'hlsl::RWByteAddressBuffer' +// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} <col:27> 'uint':'unsigned int' <LValueToRValue> +// CHECK-NEXT:DeclRefExpr 0x{{[0-9a-f]+}} <col:27> 'uint':'unsigned int' lvalue ParmVar 0x[[FOO2_PARAM]] 'i' 'uint':'unsigned int' + +// CHECK:CXXRecordDecl 0x[[RECORD]] <line:{{[0-9]+}}:1, line:{{[0-9]+}}:1> line:{{[0-9]+}}:8 referenced struct S definition + +// CHECK:CXXRecordDecl 0x{{[0-9a-f]+}} <col:1, col:8> col:8 implicit struct S +// CHECK-NEXT:FieldDecl 0x{{[0-9a-f]+}} <line:{{[0-9]+}}:5, col:11> col:11 referenced a 'float' + +// CHECK:FunctionDecl 0x{{[0-9a-f]+}} <line:{{[0-9]+}}:1, line:{{[0-9]+}}:1> line:{{[0-9]+}}:3 foo3 'S (uint)' +// CHECK-NEXT:ParmVarDecl 0x[[FOO3_PARAM:[0-9a-f]+]] <col:8, col:13> col:13 used i 'uint':'unsigned int' +// CHECK-NEXT:CompoundStmt 0x{{[0-9a-f]+}} <col:16, line:{{[0-9]+}}:1> +// CHECK-NEXT:ReturnStmt 0x{{[0-9a-f]+}} <line:{{[0-9]+}}:5, col:23> +// CHECK-NEXT:ExprWithCleanups 0x{{[0-9a-f]+}} <col:12, col:23> 'S':'S' +// CHECK-NEXT:CXXConstructExpr 0x{{[0-9a-f]+}} <col:12, col:23> 'S':'S' 'void (const S &) throw()' elidable +// CHECK-NEXT:MaterializeTemporaryExpr 0x{{[0-9a-f]+}} <col:12, col:23> 'const S':'const S' lvalue +// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} <col:12, col:23> 'const S':'const S' <NoOp> +// CHECK-NEXT:CXXMemberCallExpr 0x{{[0-9a-f]+}} <col:12, col:23> 'S':'S' +// CHECK-NEXT:MemberExpr 0x{{[0-9a-f]+}} <col:12, col:20> '<bound member function type>' .Load 0x[[LOAD_S]] +// CHECK-NEXT:DeclRefExpr 0x{{[0-9a-f]+}} <col:12> 'RWByteAddressBuffer':'hlsl::RWByteAddressBuffer' lvalue Var 0x[[BUF]] 'U' 'RWByteAddressBuffer':'hlsl::RWByteAddressBuffer' 
+// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} <col:22> 'uint':'unsigned int' <LValueToRValue> +// CHECK-NEXT:DeclRefExpr 0x{{[0-9a-f]+}} <col:22> 'uint':'unsigned int' lvalue ParmVar 0x[[FOO3_PARAM]] 'i' 'uint':'unsigned int' + +float foo(uint i) { + return U.Load<float>(i); +} + +float2 foo2(uint i) { + return U.Load<float2>(i); +} + +struct S { + float a; +}; + +S foo3(uint i) { + return U.Load<S>(i); +} Index: clang/lib/Sema/SemaExprMember.cpp =================================================================== --- clang/lib/Sema/SemaExprMember.cpp +++ clang/lib/Sema/SemaExprMember.cpp @@ -1736,7 +1736,7 @@ DeclarationName Name = NameInfo.getName(); bool IsArrow = (OpKind == tok::arrow); - if (getLangOpts().HLSL && IsArrow) + if (getLangOpts().HLSL && IsArrow && OpLoc.isValid()) return ExprError(Diag(OpLoc, diag::err_hlsl_operator_unsupported) << 2); NamedDecl *FirstQualifierInScope Index: clang/lib/Sema/SemaExpr.cpp =================================================================== --- clang/lib/Sema/SemaExpr.cpp +++ clang/lib/Sema/SemaExpr.cpp @@ -15587,7 +15587,7 @@ } } - if (getLangOpts().HLSL) { + if (getLangOpts().HLSL && OpLoc.isValid()) { if (Opc == UO_AddrOf) return ExprError(Diag(OpLoc, diag::err_hlsl_operator_unsupported) << 0); if (Opc == UO_Deref) Index: clang/lib/Sema/HLSLExternalSemaSource.cpp =================================================================== --- clang/lib/Sema/HLSLExternalSemaSource.cpp +++ clang/lib/Sema/HLSLExternalSemaSource.cpp @@ -34,13 +34,15 @@ ClassTemplateDecl *PrevTemplate = nullptr; NamespaceDecl *HLSLNamespace = nullptr; llvm::StringMap<FieldDecl *> Fields; + llvm::SmallVector<CXXBaseSpecifier, 4> Bases; BuiltinTypeDeclBuilder(CXXRecordDecl *R) : Record(R) { Record->startDefinition(); Template = Record->getDescribedClassTemplate(); } - BuiltinTypeDeclBuilder(Sema &S, NamespaceDecl *Namespace, StringRef Name) + BuiltinTypeDeclBuilder(Sema &S, NamespaceDecl *Namespace, StringRef Name, + bool DelayTypeCreation) : 
HLSLNamespace(Namespace) { ASTContext &AST = S.getASTContext(); IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier); @@ -62,9 +64,9 @@ return; } - Record = CXXRecordDecl::Create(AST, TagDecl::TagKind::TTK_Class, - HLSLNamespace, SourceLocation(), - SourceLocation(), &II, PrevDecl, true); + Record = CXXRecordDecl::Create( + AST, TagDecl::TagKind::TTK_Class, HLSLNamespace, SourceLocation(), + SourceLocation(), &II, PrevDecl, DelayTypeCreation); Record->setImplicit(true); Record->setLexicalDeclContext(HLSLNamespace); Record->setHasExternalLexicalStorage(); @@ -289,6 +291,132 @@ return *this; } + BuiltinTypeDeclBuilder &addByteAddressBufferLoad() { + if (Record->isCompleteDefinition()) + return *this; + assert(Fields.count("h") > 0 && "Load must be added after the handle."); + + // Load is implemented by casting the handle to a uint8_t pointer, + // adding the byte address to get the target address, and then + // casting that address to a T pointer and dereferencing it: + // struct RWByteAddressBuffer { + // void *h; + // template<typename T> T Load(uint ByteAddress) { + // return *((T*)(((uint8_t*)h) + ByteAddress)); + // } + // }; + + FieldDecl *Handle = Fields["h"]; + ASTContext &AST = Record->getASTContext(); + + assert(Handle->getType().getCanonicalType() == AST.VoidPtrTy && + "ByteAddressBuffer use void pointer handles."); + + // Create the declaration for the method template parameter T. + TemplateTypeParmDecl *TypeParmDecl = TemplateTypeParmDecl::Create( + AST, Record->getDeclContext(), SourceLocation(), SourceLocation(), + /* TemplateDepth */ 0, /*Position*/ 0, + &AST.Idents.get("T", tok::TokenKind::identifier), + /* Typename */ false, + /* ParameterPack */ false); + // Create the type that the parameter represents (also Load's return type). 
+ QualType ReturnTy = AST.getTemplateTypeParmType( + /* TemplateDepth */ 0, /*Position*/ 0, + /* ParameterPack */ false, TypeParmDecl); + + FunctionProtoType::ExtProtoInfo ExtInfo; + QualType MethodTy = + AST.getFunctionType(ReturnTy, {AST.UnsignedIntTy}, ExtInfo); + auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation()); + auto *MethodDecl = CXXMethodDecl::Create( + AST, Record, SourceLocation(), + DeclarationNameInfo(DeclarationName(&AST.Idents.get( + "Load", tok::TokenKind::identifier)), + SourceLocation()), + MethodTy, TSInfo, SC_None, false, false, ConstexprSpecKind::Unspecified, + SourceLocation()); + + IdentifierInfo &II = + AST.Idents.get("ByteAddress", tok::TokenKind::identifier); + auto *IdxParam = ParmVarDecl::Create( + AST, MethodDecl, SourceLocation(), SourceLocation(), + &II, AST.UnsignedIntTy, + AST.getTrivialTypeSourceInfo(AST.UnsignedIntTy, SourceLocation()), + SC_None, nullptr); + MethodDecl->setParams({IdxParam}); + + // Also attach the parameter to the prototype's type-source info. 
+ auto FnProtoLoc = TSInfo->getTypeLoc().getAs<FunctionProtoTypeLoc>(); + FnProtoLoc.setParam(0, IdxParam); + + auto *This = new (AST) + CXXThisExpr(SourceLocation(), MethodDecl->getThisType(), true); + auto *HandleAccess = MemberExpr::CreateImplicit( + AST, This, true, Handle, Handle->getType(), VK_LValue, OK_Ordinary); + + QualType U8Ty = AST.getIntTypeForBitwidth(8, false); + TypeSourceInfo *U8PtrTy = + AST.getTrivialTypeSourceInfo( + AST.getPointerType(U8Ty)); + + auto *ByteBasePtr = CStyleCastExpr::Create(AST, U8PtrTy->getType(), ExprValueKind::VK_LValue, // NOTE(review): in C++ a cast to a non-reference type yields a prvalue; confirm building this cast as an lvalue (then applying LValueToRValue below) is intended. + CastKind::CK_BitCast, HandleAccess, nullptr, + FPOptionsOverride(), U8PtrTy, SourceLocation(), + SourceLocation()); + auto *IndexExpr = DeclRefExpr::Create( + AST, NestedNameSpecifierLoc(), SourceLocation(), IdxParam, false, + DeclarationNameInfo(IdxParam->getDeclName(), SourceLocation()), + AST.UnsignedIntTy, VK_PRValue); // NOTE(review): the user-code CHECKs in the AST test show parameter DeclRefExprs as lvalues under an LValueToRValue cast; confirm a prvalue DeclRefExpr is acceptable here. + + auto *ByteBasePtrRV = ImplicitCastExpr::Create( + AST, U8PtrTy->getType(), CastKind::CK_LValueToRValue, + ByteBasePtr, + nullptr, ExprValueKind::VK_PRValue, + FPOptionsOverride()); + auto *BytePtr = BinaryOperator::Create( + AST, ByteBasePtrRV, IndexExpr, BinaryOperator::Opcode::BO_Add, + U8PtrTy->getType(), ExprValueKind::VK_PRValue, + ExprObjectKind::OK_Ordinary, SourceLocation(), FPOptionsOverride()); + + TypeSourceInfo *ResultPtrTy = + AST.getTrivialTypeSourceInfo( + AST.getPointerType(ReturnTy)); + + auto *ResultPtr = CStyleCastExpr::Create( + AST, ResultPtrTy->getType(), ExprValueKind::VK_LValue, // NOTE(review): same value-category question as ByteBasePtr above. 
+ CastKind::CK_BitCast, BytePtr, nullptr, FPOptionsOverride(), + ResultPtrTy, SourceLocation(), SourceLocation()); + auto *Deref = UnaryOperator::Create(AST, ResultPtr, UnaryOperator::Opcode::UO_Deref, + ReturnTy, ExprValueKind::VK_LValue, + ExprObjectKind::OK_Ordinary, SourceLocation(), + /*CanOverflow*/ false, FPOptionsOverride()); + auto *Result = ImplicitCastExpr::Create( + AST, ReturnTy, CastKind::CK_LValueToRValue, Deref, nullptr, + ExprValueKind::VK_PRValue, FPOptionsOverride()); 
+ + auto *Return = ReturnStmt::Create(AST, SourceLocation(), Result, nullptr); + + MethodDecl->setBody(CompoundStmt::Create(AST, {Return}, FPOptionsOverride(), + SourceLocation(), + SourceLocation())); + MethodDecl->setLexicalDeclContext(Record); + MethodDecl->setAccess(AccessSpecifier::AS_public); + MethodDecl->addAttr(AlwaysInlineAttr::CreateImplicit( + AST, SourceRange(), AttributeCommonInfo::AS_Keyword, + AlwaysInlineAttr::CXX11_clang_always_inline)); + TemplateParameterList *ParamList = TemplateParameterList::Create( + AST, SourceLocation(), SourceLocation(), {TypeParmDecl}, + SourceLocation(), nullptr); + FunctionTemplateDecl *FunctionTemplate = FunctionTemplateDecl::Create( + AST, Record, SourceLocation(), MethodDecl->getDeclName(), + ParamList, MethodDecl); + FunctionTemplate->setAccess(AccessSpecifier::AS_public); + FunctionTemplate->setLexicalDeclContext(Record); + MethodDecl->setDescribedFunctionTemplate(FunctionTemplate); + Record->addDecl(FunctionTemplate); + return *this; + } + BuiltinTypeDeclBuilder &startDefinition() { if (Record->isCompleteDefinition()) return *this; @@ -367,6 +495,104 @@ } } // namespace +// Helper class that forward-declares HLSL resource types and completes them on demand.
+namespace clang { +namespace hlsl { +class HLSLResourceHelper { + Sema *SemaPtr = nullptr; + NamespaceDecl *HLSLNamespace = nullptr; + using ForwardDeclFunction = std::function<CXXRecordDecl *(void)>; + using CompletionFunction = std::function<void(CXXRecordDecl *)>; + llvm::DenseMap<CXXRecordDecl *, CompletionFunction> Completions; + void declType(ForwardDeclFunction ForwardDecl, + CompletionFunction Completion) { + CXXRecordDecl *Decl = ForwardDecl(); + if (!Decl->isCompleteDefinition()) + Completions.insert(std::make_pair(Decl->getCanonicalDecl(), Completion)); + } + +public: + HLSLResourceHelper(Sema *SemaPtr, NamespaceDecl *HLSLNamespace) + : SemaPtr(SemaPtr), HLSLNamespace(HLSLNamespace) {} + void completeType(CXXRecordDecl *Record) { + // If this is a specialization, we need to get the underlying templated + // declaration and complete that. + if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(Record)) + Record = TDecl->getSpecializedTemplate()->getTemplatedDecl(); + Record = Record->getCanonicalDecl(); + auto It = Completions.find(Record); + if (It == Completions.end()) + return; + It->second(Record); + } + + void declResourceTypes() { + declRWBuffer(); + declRWByteAddressBuffer(); + } + +private: + void declRWBuffer(); + void declRWByteAddressBuffer(); + CXXRecordDecl *forwardDeclRWBuffer(); + void completeRWBuffer(CXXRecordDecl *Record); + CXXRecordDecl *forwardDeclRWByteAddressBuffer(); + void completeRWByteAddressBuffer(CXXRecordDecl *Record); +}; +void HLSLResourceHelper::declRWBuffer() { + ForwardDeclFunction ForwardDecl = + [this] { return forwardDeclRWBuffer(); }; + CompletionFunction Completion = + [this](CXXRecordDecl *Record) { completeRWBuffer(Record); }; + declType(ForwardDecl, Completion); +} + +CXXRecordDecl *HLSLResourceHelper::forwardDeclRWBuffer() { + return BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWBuffer", /*DelayTypeCreation*/ true) + .addTemplateArgumentList() + 
.addTypeParameter("element_type", SemaPtr->getASTContext().FloatTy) + .finalizeTemplateArgs() + .annotateResourceClass(HLSLResourceAttr::UAV, + HLSLResourceAttr::TypedBuffer) + .Record; +} +void HLSLResourceHelper::completeRWBuffer(CXXRecordDecl *Record) { + BuiltinTypeDeclBuilder(Record) + .addHandleMember() + .addDefaultHandleConstructor(*SemaPtr, ResourceClass::UAV) + .addArraySubscriptOperators() + .completeDefinition(); +} + +void HLSLResourceHelper::declRWByteAddressBuffer() { + ForwardDeclFunction ForwardDecl = + [this] { return forwardDeclRWByteAddressBuffer(); }; + CompletionFunction Completion = [this](CXXRecordDecl *Record) { + completeRWByteAddressBuffer(Record); + }; + declType(ForwardDecl, Completion); +} + +CXXRecordDecl *HLSLResourceHelper::forwardDeclRWByteAddressBuffer() { + return BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWByteAddressBuffer", /*DelayTypeCreation*/ false) + .annotateResourceClass(HLSLResourceAttr::UAV, + HLSLResourceAttr::RawBuffer) + .Record; +} +void HLSLResourceHelper::completeRWByteAddressBuffer(CXXRecordDecl *Record) { + BuiltinTypeDeclBuilder Builder(Record); + + Builder.addHandleMember() + .addDefaultHandleConstructor(*SemaPtr, ResourceClass::UAV); + + Builder.addByteAddressBufferLoad(); + + Builder.completeDefinition(); +} +} // namespace hlsl +} // namespace clang + +HLSLExternalSemaSource::HLSLExternalSemaSource() : ResourceHelper(nullptr) {} HLSLExternalSemaSource::~HLSLExternalSemaSource() {} void HLSLExternalSemaSource::InitializeSema(Sema &S) { @@ -388,6 +614,10 @@ HLSLNamespace->setHasExternalLexicalStorage(); AST.getTranslationUnitDecl()->addDecl(HLSLNamespace); + // Create the helper that forward-declares and lazily completes resource types. + ResourceHelper = + std::make_unique<hlsl::HLSLResourceHelper>(SemaPtr, HLSLNamespace); + // Force external decls in the HLSL namespace to load from the PCH. 
(void)HLSLNamespace->getCanonicalDecl()->decls_begin(); defineTrivialHLSLTypes(); @@ -461,7 +691,7 @@ void HLSLExternalSemaSource::defineTrivialHLSLTypes() { defineHLSLVectorAlias(); - ResourceDecl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "Resource") + ResourceDecl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "Resource", /*DelayTypeCreation*/ true) .startDefinition() .addHandleMember(AccessSpecifier::AS_public) .completeDefinition() @@ -469,41 +699,12 @@ } void HLSLExternalSemaSource::forwardDeclareHLSLTypes() { - CXXRecordDecl *Decl; - Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWBuffer") - .addTemplateArgumentList() - .addTypeParameter("element_type", SemaPtr->getASTContext().FloatTy) - .finalizeTemplateArgs() - .Record; - if (!Decl->isCompleteDefinition()) - Completions.insert( - std::make_pair(Decl->getCanonicalDecl(), - std::bind(&HLSLExternalSemaSource::completeBufferType, - this, std::placeholders::_1))); + ResourceHelper->declResourceTypes(); } void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) { if (!isa<CXXRecordDecl>(Tag)) return; auto Record = cast<CXXRecordDecl>(Tag); - - // If this is a specialization, we need to get the underlying templated - // declaration and complete that. 
- if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(Record)) - Record = TDecl->getSpecializedTemplate()->getTemplatedDecl(); - Record = Record->getCanonicalDecl(); - auto It = Completions.find(Record); - if (It == Completions.end()) - return; - It->second(Record); -} - -void HLSLExternalSemaSource::completeBufferType(CXXRecordDecl *Record) { - BuiltinTypeDeclBuilder(Record) - .addHandleMember() - .addDefaultHandleConstructor(*SemaPtr, ResourceClass::UAV) - .addArraySubscriptOperators() - .annotateResourceClass(HLSLResourceAttr::UAV, - HLSLResourceAttr::TypedBuffer) - .completeDefinition(); + ResourceHelper->completeType(Record); } Index: clang/include/clang/Sema/HLSLExternalSemaSource.h =================================================================== --- clang/include/clang/Sema/HLSLExternalSemaSource.h +++ clang/include/clang/Sema/HLSLExternalSemaSource.h @@ -15,26 +15,27 @@ #include "llvm/ADT/DenseMap.h" #include "clang/Sema/ExternalSemaSource.h" +#include <memory> namespace clang { class NamespaceDecl; class Sema; +namespace hlsl { +class HLSLResourceHelper; +} // namespace hlsl class HLSLExternalSemaSource : public ExternalSemaSource { Sema *SemaPtr = nullptr; NamespaceDecl *HLSLNamespace = nullptr; CXXRecordDecl *ResourceDecl; - - using CompletionFunction = std::function<void(CXXRecordDecl *)>; - llvm::DenseMap<CXXRecordDecl *, CompletionFunction> Completions; + std::unique_ptr<hlsl::HLSLResourceHelper> ResourceHelper; void defineHLSLVectorAlias(); void defineTrivialHLSLTypes(); void forwardDeclareHLSLTypes(); - void completeBufferType(CXXRecordDecl *Record); - public: + HLSLExternalSemaSource(); ~HLSLExternalSemaSource() override; /// Initialize the semantic source with the Sema instance
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits