Author: Helena Kotas Date: 2024-12-09T15:15:37-08:00 New Revision: 274637d7e5960e37a33f29521905eef3d0fee13d
URL: https://github.com/llvm/llvm-project/commit/274637d7e5960e37a33f29521905eef3d0fee13d DIFF: https://github.com/llvm/llvm-project/commit/274637d7e5960e37a33f29521905eef3d0fee13d.diff LOG: [HLSL] Implement `Append` and `Consume` methods on `Append`/`ConsumeStructuredBuffer` (#118536) The methods use the existing clang builtins `__builtin_hlsl_buffer_update_counter` and `__builtin_hlsl_resource_getpointer` to update the buffer counter and then store (`Append`) or load (`Consume`) the element at the index returned by the counter update. Fixes #112968 Added: Modified: clang/lib/Sema/HLSLExternalSemaSource.cpp clang/test/AST/HLSL/StructuredBuffers-AST.hlsl clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl Removed: ################################################################################ diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp index 8e57123c503cba..29672658525403 100644 --- a/clang/lib/Sema/HLSLExternalSemaSource.cpp +++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp @@ -246,6 +246,8 @@ class BuiltinTypeDeclBuilder { BuiltinTypeDeclBuilder &addDecrementCounterMethod(); BuiltinTypeDeclBuilder &addHandleAccessFunction(DeclarationName &Name, bool IsConst, bool IsRef); + BuiltinTypeDeclBuilder &addAppendMethod(); + BuiltinTypeDeclBuilder &addConsumeMethod(); }; struct TemplateParameterListBuilder { @@ -443,14 +445,26 @@ struct BuiltinTypeMethodBuilder { llvm::SmallVector<Stmt *> StmtsList; // Argument placeholders, inspired by std::placeholder. These are the indices - // of arguments to forward to `callBuiltin`, and additionally `Handle` which - // refers to the resource handle. - enum class PlaceHolder { _0, _1, _2, _3, Handle = 127 }; + // of arguments to forward to `callBuiltin` and other method builder methods. + // Additional special values are: + // Handle - refers to the resource handle.
+ // LastStmt - refers to the last statement in the method body; referencing + // LastStmt will remove the statement from the method body since + // it will be linked from the new expression being constructed. + enum class PlaceHolder { _0, _1, _2, _3, Handle = 128, LastStmt }; Expr *convertPlaceholder(PlaceHolder PH) { if (PH == PlaceHolder::Handle) return getResourceHandleExpr(); + if (PH == PlaceHolder::LastStmt) { + assert(!StmtsList.empty() && "no statements in the list"); + Stmt *LastStmt = StmtsList.pop_back_val(); + assert(isa<ValueStmt>(LastStmt) && + "last statement does not have a value"); + return cast<ValueStmt>(LastStmt)->getExprStmt(); + } + ASTContext &AST = DeclBuilder.SemaRef.getASTContext(); ParmVarDecl *ParamDecl = Method->getParamDecl(static_cast<unsigned>(PH)); return DeclRefExpr::Create( @@ -573,17 +587,25 @@ struct BuiltinTypeMethodBuilder { return *this; } - BuiltinTypeMethodBuilder &dereference() { - assert(!StmtsList.empty() && "Nothing to dereference"); - ASTContext &AST = DeclBuilder.SemaRef.getASTContext(); + template <typename TLHS, typename TRHS> + BuiltinTypeMethodBuilder &assign(TLHS LHS, TRHS RHS) { + Expr *LHSExpr = convertPlaceholder(LHS); + Expr *RHSExpr = convertPlaceholder(RHS); + Stmt *AssignStmt = BinaryOperator::Create( + DeclBuilder.SemaRef.getASTContext(), LHSExpr, RHSExpr, BO_Assign, + LHSExpr->getType(), ExprValueKind::VK_PRValue, + ExprObjectKind::OK_Ordinary, SourceLocation(), FPOptionsOverride()); + StmtsList.push_back(AssignStmt); + return *this; + } - Expr *LastExpr = dyn_cast<Expr>(StmtsList.back()); - assert(LastExpr && "No expression to dereference"); - Expr *Deref = UnaryOperator::Create( - AST, LastExpr, UO_Deref, LastExpr->getType()->getPointeeType(), - VK_PRValue, OK_Ordinary, SourceLocation(), - /*CanOverflow=*/false, FPOptionsOverride()); - StmtsList.pop_back(); + template <typename T> BuiltinTypeMethodBuilder &dereference(T Ptr) { + Expr *PtrExpr = convertPlaceholder(Ptr); + Expr *Deref = + 
UnaryOperator::Create(DeclBuilder.SemaRef.getASTContext(), PtrExpr, + UO_Deref, PtrExpr->getType()->getPointeeType(), + VK_PRValue, OK_Ordinary, SourceLocation(), + /*CanOverflow=*/false, FPOptionsOverride()); StmtsList.push_back(Deref); return *this; } @@ -685,7 +707,35 @@ BuiltinTypeDeclBuilder::addHandleAccessFunction(DeclarationName &Name, .addParam("Index", AST.UnsignedIntTy) .callBuiltin("__builtin_hlsl_resource_getpointer", ElemPtrTy, PH::Handle, PH::_0) - .dereference() + .dereference(PH::LastStmt) + .finalizeMethod(); +} + +BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addAppendMethod() { + using PH = BuiltinTypeMethodBuilder::PlaceHolder; + ASTContext &AST = SemaRef.getASTContext(); + QualType ElemTy = getHandleElementType(); + return BuiltinTypeMethodBuilder(*this, "Append", AST.VoidTy) + .addParam("value", ElemTy) + .callBuiltin("__builtin_hlsl_buffer_update_counter", AST.UnsignedIntTy, + PH::Handle, getConstantIntExpr(1)) + .callBuiltin("__builtin_hlsl_resource_getpointer", + AST.getPointerType(ElemTy), PH::Handle, PH::LastStmt) + .dereference(PH::LastStmt) + .assign(PH::LastStmt, PH::_0) + .finalizeMethod(); +} + +BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addConsumeMethod() { + using PH = BuiltinTypeMethodBuilder::PlaceHolder; + ASTContext &AST = SemaRef.getASTContext(); + QualType ElemTy = getHandleElementType(); + return BuiltinTypeMethodBuilder(*this, "Consume", ElemTy) + .callBuiltin("__builtin_hlsl_buffer_update_counter", AST.UnsignedIntTy, + PH::Handle, getConstantIntExpr(-1)) + .callBuiltin("__builtin_hlsl_resource_getpointer", + AST.getPointerType(ElemTy), PH::Handle, PH::LastStmt) + .dereference(PH::LastStmt) .finalizeMethod(); } @@ -915,6 +965,7 @@ void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() { onCompletion(Decl, [this](CXXRecordDecl *Decl) { setupBufferType(Decl, *SemaPtr, ResourceClass::UAV, ResourceKind::RawBuffer, /*IsROV=*/false, /*RawBuffer=*/true) + .addAppendMethod() .completeDefinition(); }); @@ 
-925,6 +976,7 @@ void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() { onCompletion(Decl, [this](CXXRecordDecl *Decl) { setupBufferType(Decl, *SemaPtr, ResourceClass::UAV, ResourceKind::RawBuffer, /*IsROV=*/false, /*RawBuffer=*/true) + .addConsumeMethod() .completeDefinition(); }); diff --git a/clang/test/AST/HLSL/StructuredBuffers-AST.hlsl b/clang/test/AST/HLSL/StructuredBuffers-AST.hlsl index afee0e120afdb1..6cb4379ef5f556 100644 --- a/clang/test/AST/HLSL/StructuredBuffers-AST.hlsl +++ b/clang/test/AST/HLSL/StructuredBuffers-AST.hlsl @@ -20,7 +20,7 @@ // // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump \ // RUN: -DRESOURCE=AppendStructuredBuffer %s | FileCheck -DRESOURCE=AppendStructuredBuffer \ -// RUN: -check-prefixes=CHECK,CHECK-UAV,CHECK-NOSUBSCRIPT %s +// RUN: -check-prefixes=CHECK,CHECK-UAV,CHECK-NOSUBSCRIPT,CHECK-APPEND %s // // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump -DEMPTY \ // RUN: -DRESOURCE=ConsumeStructuredBuffer %s | FileCheck -DRESOURCE=ConsumeStructuredBuffer \ @@ -28,7 +28,7 @@ // // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump \ // RUN: -DRESOURCE=ConsumeStructuredBuffer %s | FileCheck -DRESOURCE=ConsumeStructuredBuffer \ -// RUN: -check-prefixes=CHECK,CHECK-UAV,CHECK-NOSUBSCRIPT %s +// RUN: -check-prefixes=CHECK,CHECK-UAV,CHECK-NOSUBSCRIPT,CHECK-CONSUME %s // // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump -DEMPTY \ // RUN: -DRESOURCE=RasterizerOrderedStructuredBuffer %s | FileCheck -DRESOURCE=RasterizerOrderedStructuredBuffer \ @@ -135,6 +135,48 @@ RESOURCE<float> Buffer; // CHECK-COUNTER-NEXT: IntegerLiteral 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'int' -1 // CHECK-COUNTER-NEXT: AlwaysInlineAttr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> Implicit always_inline +// CHECK-APPEND: CXXMethodDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> Append 'void (element_type)' +// CHECK-APPEND-NEXT: ParmVarDecl 
0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> value 'element_type' +// CHECK-APPEND-NEXT: CompoundStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> +// CHECK-APPEND-NEXT: BinaryOperator 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' '=' +// CHECK-APPEND-NEXT: UnaryOperator 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' prefix '*' cannot overflow +// CHECK-APPEND-NEXT: CallExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type *' +// CHECK-APPEND-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '<builtin fn type>' Function 0x{{[0-9A-Fa-f]+}} '__builtin_hlsl_resource_getpointer' 'void (...) noexcept' +// CHECK-APPEND-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '__hlsl_resource_t +// CHECK-APPEND-SAME{LITERAL}: [[hlsl::resource_class(UAV)]] +// CHECK-APPEND-SAME{LITERAL}: [[hlsl::raw_buffer]] +// CHECK-APPEND-SAME{LITERAL}: [[hlsl::contained_type(element_type)]]' lvalue .__handle +// CHECK-APPEND-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '[[RESOURCE]]<element_type>' lvalue implicit this +// CHECK-APPEND-NEXT: CallExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'unsigned int' +// CHECK-APPEND-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '<builtin fn type>' Function 0x{{[0-9A-Fa-f]+}} '__builtin_hlsl_buffer_update_counter' 'unsigned int (...) 
noexcept' +// CHECK-APPEND-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '__hlsl_resource_t +// CHECK-APPEND-SAME{LITERAL}: [[hlsl::resource_class(UAV)]] +// CHECK-APPEND-SAME{LITERAL}: [[hlsl::raw_buffer]] +// CHECK-APPEND-SAME{LITERAL}: [[hlsl::contained_type(element_type)]]' lvalue .__handle +// CHECK-APPEND-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '[[RESOURCE]]<element_type>' lvalue implicit this +// CHECK-APPEND-NEXT: IntegerLiteral 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'int' 1 +// CHECK-APPEND-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' ParmVar 0x{{[0-9A-Fa-f]+}} 'value' 'element_type' + +// CHECK-CONSUME: CXXMethodDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> Consume 'element_type ()' +// CHECK-CONSUME-NEXT: CompoundStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> +// CHECK-CONSUME-NEXT: ReturnStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> +// CHECK-CONSUME-NEXT: UnaryOperator 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' prefix '*' cannot overflow +// CHECK-CONSUME-NEXT: CallExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type *' +// CHECK-CONSUME-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '<builtin fn type>' Function 0x{{[0-9A-Fa-f]+}} '__builtin_hlsl_resource_getpointer' 'void (...) noexcept' +// CHECK-CONSUME-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '__hlsl_resource_t +// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::resource_class(UAV)]] +// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::raw_buffer]] +// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::contained_type(element_type)]]' lvalue .__handle +// CHECK-CONSUME-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '[[RESOURCE]]<element_type>' lvalue implicit this +// CHECK-CONSUME-NEXT: CallExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'unsigned int' +// CHECK-CONSUME-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '<builtin fn type>' Function 0x{{[0-9A-Fa-f]+}} '__builtin_hlsl_buffer_update_counter' 'unsigned int (...) 
noexcept' +// CHECK-CONSUME-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '__hlsl_resource_t +// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::resource_class(UAV)]] +// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::raw_buffer]] +// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::contained_type(element_type)]]' lvalue .__handle +// CHECK-CONSUME-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '[[RESOURCE]]<element_type>' lvalue implicit this +// CHECK-CONSUME-NEXT: IntegerLiteral 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'int' -1 + // CHECK: ClassTemplateSpecializationDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> class [[RESOURCE]] definition // CHECK: TemplateArgument type 'float' diff --git a/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl b/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl index b7986ae7dda1c2..53abdc71bdd4b8 100644 --- a/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl +++ b/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl @@ -5,21 +5,45 @@ RWStructuredBuffer<float> RWSB1 : register(u0); RWStructuredBuffer<float> RWSB2 : register(u1); +AppendStructuredBuffer<float> ASB : register(u2); +ConsumeStructuredBuffer<float> CSB : register(u3); // CHECK: %"class.hlsl::RWStructuredBuffer" = type { target("dx.RawBuffer", float, 1, 0) } -export void TestIncrementCounter() { - RWSB1.IncrementCounter(); +export int TestIncrementCounter() { + return RWSB1.IncrementCounter(); } -// CHECK: define void @_Z20TestIncrementCounterv() -// CHECK-DXIL: call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 1) +// CHECK: define noundef i32 @_Z20TestIncrementCounterv() +// CHECK-DXIL: %[[INDEX:.*]] = call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 1) +// CHECK-DXIL: ret i32 %[[INDEX]] +export int TestDecrementCounter() { + return RWSB2.DecrementCounter(); +} + +// CHECK: define noundef i32 
@_Z20TestDecrementCounterv() +// CHECK-DXIL: %[[INDEX:.*]] = call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 -1) +// CHECK-DXIL: ret i32 %[[INDEX]] + +export void TestAppend(float value) { + ASB.Append(value); +} + +// CHECK: define void @_Z10TestAppendf(float noundef %value) +// CHECK-DXIL: %[[VALUE:.*]] = load float, ptr %value.addr, align 4 +// CHECK-DXIL: %[[INDEX:.*]] = call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 1) +// CHECK-DXIL: %[[RESPTR:.*]] = call ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i32 %[[INDEX]]) +// CHECK-DXIL: store float %[[VALUE]], ptr %[[RESPTR]], align 4 -export void TestDecrementCounter() { - RWSB2.DecrementCounter(); +export float TestConsume() { + return CSB.Consume(); } -// CHECK: define void @_Z20TestDecrementCounterv() -// CHECK-DXIL: call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 -1) +// CHECK: define noundef float @_Z11TestConsumev() +// CHECK-DXIL: %[[INDEX:.*]] = call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %1, i8 -1) +// CHECK-DXIL: %[[RESPTR:.*]] = call ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %0, i32 %[[INDEX]]) +// CHECK-DXIL: %[[VALUE:.*]] = load float, ptr %[[RESPTR]], align 4 +// CHECK-DXIL: ret float %[[VALUE]] // CHECK: declare i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0), i8) +// CHECK: declare ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0), i32) _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits