https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/99282
>From 9c8163db0df6f3d89f32239fbbd6dd47f5eec1a6 Mon Sep 17 00:00:00 2001 From: Yuxuan Chen <yuxuanchen1...@outlook.com> Date: Tue, 4 Jun 2024 23:22:00 -0700 Subject: [PATCH] [Clang] Introduce [[clang::coro_await_elidable]] --- clang/docs/ReleaseNotes.rst | 3 + clang/include/clang/AST/Expr.h | 3 + clang/include/clang/AST/Stmt.h | 5 +- clang/include/clang/Basic/Attr.td | 8 ++ clang/include/clang/Basic/AttrDocs.td | 33 ++++++- clang/lib/AST/Expr.cpp | 2 + clang/lib/CodeGen/CGBlocks.cpp | 5 +- clang/lib/CodeGen/CGCUDARuntime.cpp | 5 +- clang/lib/CodeGen/CGCUDARuntime.h | 8 +- clang/lib/CodeGen/CGCXXABI.h | 10 +-- clang/lib/CodeGen/CGClass.cpp | 16 ++-- clang/lib/CodeGen/CGExpr.cpp | 55 ++++++++---- clang/lib/CodeGen/CGExprCXX.cpp | 60 +++++++------ clang/lib/CodeGen/CodeGenFunction.cpp | 6 ++ clang/lib/CodeGen/CodeGenFunction.h | 64 ++++++++------ clang/lib/CodeGen/ItaniumCXXABI.cpp | 16 ++-- clang/lib/CodeGen/MicrosoftCXXABI.cpp | 18 ++-- clang/lib/Sema/SemaCoroutine.cpp | 24 ++++- clang/test/CodeGenCoroutines/Inputs/utility.h | 13 +++ .../CodeGenCoroutines/coro-await-elidable.cpp | 88 +++++++++++++++++++ ...a-attribute-supported-attributes-list.test | 1 + llvm/include/llvm/Bitcode/LLVMBitCodes.h | 2 + llvm/include/llvm/IR/Attributes.td | 9 ++ llvm/lib/Bitcode/Reader/BitcodeReader.cpp | 2 + llvm/lib/Bitcode/Writer/BitcodeWriter.cpp | 4 + llvm/lib/Transforms/Utils/CodeExtractor.cpp | 2 + 26 files changed, 352 insertions(+), 110 deletions(-) create mode 100644 clang/test/CodeGenCoroutines/Inputs/utility.h create mode 100644 clang/test/CodeGenCoroutines/coro-await-elidable.cpp diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst index 1df3f0e7e75ca3..d68efa647d413f 100644 --- a/clang/docs/ReleaseNotes.rst +++ b/clang/docs/ReleaseNotes.rst @@ -209,6 +209,9 @@ Attribute Changes in Clang - ``[[clang::lifetimebound]]`` is now explicitly disallowed on explicit object member functions where they were previously silently ignored. 
+- Introduced a new attribute ``[[clang::coro_await_elidable]]`` on coroutine return types + to express elidability at call sites where the coroutine is co_awaited as a prvalue. + Improvements to Clang's diagnostics ----------------------------------- diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h index 7bacf028192c65..77ce378df26dcf 100644 --- a/clang/include/clang/AST/Expr.h +++ b/clang/include/clang/AST/Expr.h @@ -2991,6 +2991,9 @@ class CallExpr : public Expr { bool hasStoredFPFeatures() const { return CallExprBits.HasFPFeatures; } + bool isCoroElideSafe() const { return CallExprBits.IsCoroElideSafe; } + void setCoroElideSafe(bool V = true) { CallExprBits.IsCoroElideSafe = V; } + Decl *getCalleeDecl() { return getCallee()->getReferencedDeclOfCallee(); } const Decl *getCalleeDecl() const { return getCallee()->getReferencedDeclOfCallee(); diff --git a/clang/include/clang/AST/Stmt.h b/clang/include/clang/AST/Stmt.h index f1a2aac0a8b2f8..7aed83e9c68bb7 100644 --- a/clang/include/clang/AST/Stmt.h +++ b/clang/include/clang/AST/Stmt.h @@ -561,8 +561,11 @@ class alignas(void *) Stmt { LLVM_PREFERRED_TYPE(bool) unsigned HasFPFeatures : 1; + /// True if the coroutine frame of this call is safe to elide. + unsigned IsCoroElideSafe : 1; + /// Padding used to align OffsetToTrailingObjects to a byte multiple. - unsigned : 24 - 3 - NumExprBits; + unsigned : 24 - 4 - NumExprBits; /// The offset in bytes from the this pointer to the start of the /// trailing objects belonging to CallExpr. 
Intentionally byte sized diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index 10a9d9e899e007..8a36cb7cde51ef 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -1220,6 +1220,14 @@ def CoroDisableLifetimeBound : InheritableAttr { let SimpleHandler = 1; } +def CoroAwaitElidable : InheritableAttr { + let Spellings = [Clang<"coro_await_elidable">]; + let Subjects = SubjectList<[CXXRecord]>; + let LangOpts = [CPlusPlus]; + let Documentation = [CoroAwaitElidableDoc]; + let SimpleHandler = 1; +} + // OSObject-based attributes. def OSConsumed : InheritableParamAttr { let Spellings = [Clang<"os_consumed">]; diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td index 19cbb9a0111a28..b6c004e4344c5c 100644 --- a/clang/include/clang/Basic/AttrDocs.td +++ b/clang/include/clang/Basic/AttrDocs.td @@ -8173,6 +8173,38 @@ but do not pass them to the underlying coroutine or pass them by value. }]; } +def CoroAwaitElidableDoc : Documentation { + let Category = DocCatDecl; + let Content = [{ +``[[clang::coro_await_elidable]]`` is a class attribute that can be applied +to a coroutine return type. + +When a coroutine function that returns such a type calls another coroutine function, +the compiler performs heap allocation elision when the call to the coroutine function +is immediately co_awaited as a prvalue. In this case, the coroutine frame for the +callee will be a local variable within the enclosing braces in the caller's stack +frame. That local variable, like other local variables in coroutines, may be stored +in the caller's coroutine frame, which may be allocated on the heap. + +Example: + +.. code-block:: c++ + + class [[clang::coro_await_elidable]] Task { ... 
}; + + Task foo(); + Task bar() { + co_await foo(); // foo()'s coroutine frame on this line is elidable + auto t = foo(); // foo()'s coroutine frame on this line is NOT elidable + co_await t; + } + +The behavior is undefined if the caller coroutine is destroyed earlier than the +callee coroutine. + +}]; +} + def CountedByDocs : Documentation { let Category = DocCatField; let Content = [{ @@ -8332,4 +8364,3 @@ Declares that a function potentially allocates heap memory, and prevents any pot of ``nonallocating`` by the compiler. }]; } - diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp index 25ab6f3b2addfb..5ff0fb0b1d88d2 100644 --- a/clang/lib/AST/Expr.cpp +++ b/clang/lib/AST/Expr.cpp @@ -1474,6 +1474,7 @@ CallExpr::CallExpr(StmtClass SC, Expr *Fn, ArrayRef<Expr *> PreArgs, this->computeDependence(); CallExprBits.HasFPFeatures = FPFeatures.requiresTrailingStorage(); + CallExprBits.IsCoroElideSafe = false; if (hasStoredFPFeatures()) setStoredFPFeatures(FPFeatures); } @@ -1489,6 +1490,7 @@ CallExpr::CallExpr(StmtClass SC, unsigned NumPreArgs, unsigned NumArgs, assert((CallExprBits.OffsetToTrailingObjects == OffsetToTrailingObjects) && "OffsetToTrailingObjects overflow!"); CallExprBits.HasFPFeatures = HasFPFeatures; + CallExprBits.IsCoroElideSafe = false; } CallExpr *CallExpr::Create(const ASTContext &Ctx, Expr *Fn, diff --git a/clang/lib/CodeGen/CGBlocks.cpp b/clang/lib/CodeGen/CGBlocks.cpp index 066139b1c78c7f..684fda74407313 100644 --- a/clang/lib/CodeGen/CGBlocks.cpp +++ b/clang/lib/CodeGen/CGBlocks.cpp @@ -1163,7 +1163,8 @@ llvm::Type *CodeGenModule::getGenericBlockLiteralType() { } RValue CodeGenFunction::EmitBlockCallExpr(const CallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { const auto *BPT = E->getCallee()->getType()->castAs<BlockPointerType>(); llvm::Value *BlockPtr = EmitScalarExpr(E->getCallee()); llvm::Type *GenBlockTy = CGM.getGenericBlockLiteralType(); @@ -1220,7 +1221,7 @@ RValue 
CodeGenFunction::EmitBlockCallExpr(const CallExpr *E, CGCallee Callee(CGCalleeInfo(), Func); // And call the block. - return EmitCall(FnInfo, Callee, ReturnValue, Args); + return EmitCall(FnInfo, Callee, ReturnValue, Args, CallOrInvoke); } Address CodeGenFunction::GetAddrOfBlockDecl(const VarDecl *variable) { diff --git a/clang/lib/CodeGen/CGCUDARuntime.cpp b/clang/lib/CodeGen/CGCUDARuntime.cpp index c14a9d3f2bbbcf..1e1da1e2411a76 100644 --- a/clang/lib/CodeGen/CGCUDARuntime.cpp +++ b/clang/lib/CodeGen/CGCUDARuntime.cpp @@ -25,7 +25,8 @@ CGCUDARuntime::~CGCUDARuntime() {} RValue CGCUDARuntime::EmitCUDAKernelCallExpr(CodeGenFunction &CGF, const CUDAKernelCallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { llvm::BasicBlock *ConfigOKBlock = CGF.createBasicBlock("kcall.configok"); llvm::BasicBlock *ContBlock = CGF.createBasicBlock("kcall.end"); @@ -35,7 +36,7 @@ RValue CGCUDARuntime::EmitCUDAKernelCallExpr(CodeGenFunction &CGF, eval.begin(CGF); CGF.EmitBlock(ConfigOKBlock); - CGF.EmitSimpleCallExpr(E, ReturnValue); + CGF.EmitSimpleCallExpr(E, ReturnValue, CallOrInvoke); CGF.EmitBranch(ContBlock); CGF.EmitBlock(ContBlock); diff --git a/clang/lib/CodeGen/CGCUDARuntime.h b/clang/lib/CodeGen/CGCUDARuntime.h index 8030d632cc3d28..86f776004ee7c1 100644 --- a/clang/lib/CodeGen/CGCUDARuntime.h +++ b/clang/lib/CodeGen/CGCUDARuntime.h @@ -21,6 +21,7 @@ #include "llvm/IR/GlobalValue.h" namespace llvm { +class CallBase; class Function; class GlobalVariable; } @@ -82,9 +83,10 @@ class CGCUDARuntime { CGCUDARuntime(CodeGenModule &CGM) : CGM(CGM) {} virtual ~CGCUDARuntime(); - virtual RValue EmitCUDAKernelCallExpr(CodeGenFunction &CGF, - const CUDAKernelCallExpr *E, - ReturnValueSlot ReturnValue); + virtual RValue + EmitCUDAKernelCallExpr(CodeGenFunction &CGF, const CUDAKernelCallExpr *E, + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke = nullptr); /// Emits a kernel launch stub. 
virtual void emitDeviceStub(CodeGenFunction &CGF, FunctionArgList &Args) = 0; diff --git a/clang/lib/CodeGen/CGCXXABI.h b/clang/lib/CodeGen/CGCXXABI.h index 7dcc539111996b..687ff7fb844445 100644 --- a/clang/lib/CodeGen/CGCXXABI.h +++ b/clang/lib/CodeGen/CGCXXABI.h @@ -485,11 +485,11 @@ class CGCXXABI { llvm::PointerUnion<const CXXDeleteExpr *, const CXXMemberCallExpr *>; /// Emit the ABI-specific virtual destructor call. - virtual llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF, - const CXXDestructorDecl *Dtor, - CXXDtorType DtorType, - Address This, - DeleteOrMemberCallExpr E) = 0; + virtual llvm::Value * + EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, + CXXDtorType DtorType, Address This, + DeleteOrMemberCallExpr E, + llvm::CallBase **CallOrInvoke) = 0; virtual void adjustCallArgsForDestructorThunk(CodeGenFunction &CGF, GlobalDecl GD, diff --git a/clang/lib/CodeGen/CGClass.cpp b/clang/lib/CodeGen/CGClass.cpp index 667e260f2228dc..67ff06e1f8dd59 100644 --- a/clang/lib/CodeGen/CGClass.cpp +++ b/clang/lib/CodeGen/CGClass.cpp @@ -2192,15 +2192,11 @@ static bool canEmitDelegateCallArgs(CodeGenFunction &CGF, return true; } -void CodeGenFunction::EmitCXXConstructorCall(const CXXConstructorDecl *D, - CXXCtorType Type, - bool ForVirtualBase, - bool Delegating, - Address This, - CallArgList &Args, - AggValueSlot::Overlap_t Overlap, - SourceLocation Loc, - bool NewPointerIsChecked) { +void CodeGenFunction::EmitCXXConstructorCall( + const CXXConstructorDecl *D, CXXCtorType Type, bool ForVirtualBase, + bool Delegating, Address This, CallArgList &Args, + AggValueSlot::Overlap_t Overlap, SourceLocation Loc, + bool NewPointerIsChecked, llvm::CallBase **CallOrInvoke) { const CXXRecordDecl *ClassDecl = D->getParent(); if (!NewPointerIsChecked) @@ -2248,7 +2244,7 @@ void CodeGenFunction::EmitCXXConstructorCall(const CXXConstructorDecl *D, const CGFunctionInfo &Info = CGM.getTypes().arrangeCXXConstructorCall( Args, D, Type, 
ExtraArgs.Prefix, ExtraArgs.Suffix, PassPrototypeArgs); CGCallee Callee = CGCallee::forDirect(CalleePtr, GlobalDecl(D, Type)); - EmitCall(Info, Callee, ReturnValueSlot(), Args, nullptr, false, Loc); + EmitCall(Info, Callee, ReturnValueSlot(), Args, CallOrInvoke, false, Loc); // Generate vtable assumptions if we're constructing a complete object // with a vtable. We don't do this for base subobjects for two reasons: diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index 48d9a3b8a5acb3..63ba478c8088a6 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -33,6 +33,7 @@ #include "clang/Basic/SourceManager.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/StringExtras.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Intrinsics.h" @@ -5478,16 +5479,30 @@ RValue CodeGenFunction::EmitRValueForField(LValue LV, //===--------------------------------------------------------------------===// RValue CodeGenFunction::EmitCallExpr(const CallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { + llvm::CallBase *CallOrInvokeStorage = nullptr; + if (!CallOrInvoke) { + CallOrInvoke = &CallOrInvokeStorage; + } + + auto AddCoroElideSafeOnExit = llvm::make_scope_exit([&] { + if (E->isCoroElideSafe()) { + auto *I = *CallOrInvoke; + if (I) + I->addFnAttr(llvm::Attribute::CoroElideSafe); + } + }); + // Builtins never have block type.
if (E->getCallee()->getType()->isBlockPointerType()) - return EmitBlockCallExpr(E, ReturnValue); + return EmitBlockCallExpr(E, ReturnValue, CallOrInvoke); if (const auto *CE = dyn_cast<CXXMemberCallExpr>(E)) - return EmitCXXMemberCallExpr(CE, ReturnValue); + return EmitCXXMemberCallExpr(CE, ReturnValue, CallOrInvoke); if (const auto *CE = dyn_cast<CUDAKernelCallExpr>(E)) - return EmitCUDAKernelCallExpr(CE, ReturnValue); + return EmitCUDAKernelCallExpr(CE, ReturnValue, CallOrInvoke); // A CXXOperatorCallExpr is created even for explicit object methods, but // these should be treated like static function call. @@ -5495,7 +5510,7 @@ RValue CodeGenFunction::EmitCallExpr(const CallExpr *E, if (const auto *MD = dyn_cast_if_present<CXXMethodDecl>(CE->getCalleeDecl()); MD && MD->isImplicitObjectMemberFunction()) - return EmitCXXOperatorMemberCallExpr(CE, MD, ReturnValue); + return EmitCXXOperatorMemberCallExpr(CE, MD, ReturnValue, CallOrInvoke); CGCallee callee = EmitCallee(E->getCallee()); @@ -5508,14 +5523,17 @@ RValue CodeGenFunction::EmitCallExpr(const CallExpr *E, return EmitCXXPseudoDestructorExpr(callee.getPseudoDestructorExpr()); } - return EmitCall(E->getCallee()->getType(), callee, E, ReturnValue); + return EmitCall(E->getCallee()->getType(), callee, E, ReturnValue, + /*Chain=*/nullptr, CallOrInvoke); } /// Emit a CallExpr without considering whether it might be a subclass. 
RValue CodeGenFunction::EmitSimpleCallExpr(const CallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { CGCallee Callee = EmitCallee(E->getCallee()); - return EmitCall(E->getCallee()->getType(), Callee, E, ReturnValue); + return EmitCall(E->getCallee()->getType(), Callee, E, ReturnValue, + /*Chain=*/nullptr, CallOrInvoke); } // Detect the unusual situation where an inline version is shadowed by a @@ -5719,8 +5737,9 @@ LValue CodeGenFunction::EmitBinaryOperatorLValue(const BinaryOperator *E) { llvm_unreachable("bad evaluation kind"); } -LValue CodeGenFunction::EmitCallExprLValue(const CallExpr *E) { - RValue RV = EmitCallExpr(E); +LValue CodeGenFunction::EmitCallExprLValue(const CallExpr *E, + llvm::CallBase **CallOrInvoke) { + RValue RV = EmitCallExpr(E, ReturnValueSlot(), CallOrInvoke); if (!RV.isScalar()) return MakeAddrLValue(RV.getAggregateAddress(), E->getType(), @@ -5843,9 +5862,11 @@ LValue CodeGenFunction::EmitStmtExprLValue(const StmtExpr *E) { AlignmentSource::Decl); } -RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee, - const CallExpr *E, ReturnValueSlot ReturnValue, - llvm::Value *Chain) { +RValue CodeGenFunction::EmitCall(QualType CalleeType, + const CGCallee &OrigCallee, const CallExpr *E, + ReturnValueSlot ReturnValue, + llvm::Value *Chain, + llvm::CallBase **CallOrInvoke) { // Get the actual function type. The callee type will always be a pointer to // function type or a block pointer type. 
assert(CalleeType->isFunctionPointerType() && @@ -6065,8 +6086,8 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee Address(Handle, Handle->getType(), CGM.getPointerAlign())); Callee.setFunctionPointer(Stub); } - llvm::CallBase *CallOrInvoke = nullptr; - RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &CallOrInvoke, + llvm::CallBase *LocalCallOrInvoke = nullptr; + RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &LocalCallOrInvoke, E == MustTailCall, E->getExprLoc()); // Generate function declaration DISuprogram in order to be used @@ -6075,11 +6096,13 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee if (auto *CalleeDecl = dyn_cast_or_null<FunctionDecl>(TargetDecl)) { FunctionArgList Args; QualType ResTy = BuildFunctionArgList(CalleeDecl, Args); - DI->EmitFuncDeclForCallSite(CallOrInvoke, + DI->EmitFuncDeclForCallSite(LocalCallOrInvoke, DI->getFunctionType(CalleeDecl, ResTy, Args), CalleeDecl); } } + if (CallOrInvoke) + *CallOrInvoke = LocalCallOrInvoke; return Call; } diff --git a/clang/lib/CodeGen/CGExprCXX.cpp b/clang/lib/CodeGen/CGExprCXX.cpp index 8eb6ab7381acbc..1214bb054fb8df 100644 --- a/clang/lib/CodeGen/CGExprCXX.cpp +++ b/clang/lib/CodeGen/CGExprCXX.cpp @@ -84,23 +84,24 @@ commonEmitCXXMemberOrOperatorCall(CodeGenFunction &CGF, GlobalDecl GD, RValue CodeGenFunction::EmitCXXMemberOrOperatorCall( const CXXMethodDecl *MD, const CGCallee &Callee, - ReturnValueSlot ReturnValue, - llvm::Value *This, llvm::Value *ImplicitParam, QualType ImplicitParamTy, - const CallExpr *CE, CallArgList *RtlArgs) { + ReturnValueSlot ReturnValue, llvm::Value *This, llvm::Value *ImplicitParam, + QualType ImplicitParamTy, const CallExpr *CE, CallArgList *RtlArgs, + llvm::CallBase **CallOrInvoke) { const FunctionProtoType *FPT = MD->getType()->castAs<FunctionProtoType>(); CallArgList Args; MemberCallInfo CallInfo = commonEmitCXXMemberOrOperatorCall( *this, MD, This, ImplicitParam, 
ImplicitParamTy, CE, Args, RtlArgs); auto &FnInfo = CGM.getTypes().arrangeCXXMethodCall( Args, FPT, CallInfo.ReqArgs, CallInfo.PrefixSize); - return EmitCall(FnInfo, Callee, ReturnValue, Args, nullptr, + return EmitCall(FnInfo, Callee, ReturnValue, Args, CallOrInvoke, CE && CE == MustTailCall, CE ? CE->getExprLoc() : SourceLocation()); } RValue CodeGenFunction::EmitCXXDestructorCall( GlobalDecl Dtor, const CGCallee &Callee, llvm::Value *This, QualType ThisTy, - llvm::Value *ImplicitParam, QualType ImplicitParamTy, const CallExpr *CE) { + llvm::Value *ImplicitParam, QualType ImplicitParamTy, const CallExpr *CE, + llvm::CallBase **CallOrInvoke) { const CXXMethodDecl *DtorDecl = cast<CXXMethodDecl>(Dtor.getDecl()); assert(!ThisTy.isNull()); @@ -120,7 +121,8 @@ RValue CodeGenFunction::EmitCXXDestructorCall( commonEmitCXXMemberOrOperatorCall(*this, Dtor, This, ImplicitParam, ImplicitParamTy, CE, Args, nullptr); return EmitCall(CGM.getTypes().arrangeCXXStructorDeclaration(Dtor), Callee, - ReturnValueSlot(), Args, nullptr, CE && CE == MustTailCall, + ReturnValueSlot(), Args, CallOrInvoke, + CE && CE == MustTailCall, CE ? CE->getExprLoc() : SourceLocation{}); } @@ -186,11 +188,12 @@ static CXXRecordDecl *getCXXRecord(const Expr *E) { // Note: This function also emit constructor calls to support a MSVC // extensions allowing explicit constructor function call. 
RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { const Expr *callee = CE->getCallee()->IgnoreParens(); if (isa<BinaryOperator>(callee)) - return EmitCXXMemberPointerCallExpr(CE, ReturnValue); + return EmitCXXMemberPointerCallExpr(CE, ReturnValue, CallOrInvoke); const MemberExpr *ME = cast<MemberExpr>(callee); const CXXMethodDecl *MD = cast<CXXMethodDecl>(ME->getMemberDecl()); @@ -200,7 +203,7 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE, CGCallee callee = CGCallee::forDirect(CGM.GetAddrOfFunction(MD), GlobalDecl(MD)); return EmitCall(getContext().getPointerType(MD->getType()), callee, CE, - ReturnValue); + ReturnValue, /*Chain=*/nullptr, CallOrInvoke); } bool HasQualifier = ME->hasQualifier(); @@ -208,14 +211,15 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE, bool IsArrow = ME->isArrow(); const Expr *Base = ME->getBase(); - return EmitCXXMemberOrOperatorMemberCallExpr( - CE, MD, ReturnValue, HasQualifier, Qualifier, IsArrow, Base); + return EmitCXXMemberOrOperatorMemberCallExpr(CE, MD, ReturnValue, + HasQualifier, Qualifier, IsArrow, + Base, CallOrInvoke); } RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( const CallExpr *CE, const CXXMethodDecl *MD, ReturnValueSlot ReturnValue, bool HasQualifier, NestedNameSpecifier *Qualifier, bool IsArrow, - const Expr *Base) { + const Expr *Base, llvm::CallBase **CallOrInvoke) { assert(isa<CXXMemberCallExpr>(CE) || isa<CXXOperatorCallExpr>(CE)); // Compute the object pointer. 
@@ -300,7 +304,7 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( EmitCXXConstructorCall(Ctor, Ctor_Complete, /*ForVirtualBase=*/false, /*Delegating=*/false, This.getAddress(), Args, AggValueSlot::DoesNotOverlap, CE->getExprLoc(), - /*NewPointerIsChecked=*/false); + /*NewPointerIsChecked=*/false, CallOrInvoke); return RValue::get(nullptr); } @@ -374,9 +378,9 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( "Destructor shouldn't have explicit parameters"); assert(ReturnValue.isNull() && "Destructor shouldn't have return value"); if (UseVirtualCall) { - CGM.getCXXABI().EmitVirtualDestructorCall(*this, Dtor, Dtor_Complete, - This.getAddress(), - cast<CXXMemberCallExpr>(CE)); + CGM.getCXXABI().EmitVirtualDestructorCall( + *this, Dtor, Dtor_Complete, This.getAddress(), + cast<CXXMemberCallExpr>(CE), CallOrInvoke); } else { GlobalDecl GD(Dtor, Dtor_Complete); CGCallee Callee; @@ -393,7 +397,7 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( IsArrow ? 
Base->getType()->getPointeeType() : Base->getType(); EmitCXXDestructorCall(GD, Callee, This.getPointer(*this), ThisTy, /*ImplicitParam=*/nullptr, - /*ImplicitParamTy=*/QualType(), CE); + /*ImplicitParamTy=*/QualType(), CE, CallOrInvoke); } return RValue::get(nullptr); } @@ -435,12 +439,13 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( return EmitCXXMemberOrOperatorCall( CalleeDecl, Callee, ReturnValue, This.getPointer(*this), - /*ImplicitParam=*/nullptr, QualType(), CE, RtlArgs); + /*ImplicitParam=*/nullptr, QualType(), CE, RtlArgs, CallOrInvoke); } RValue CodeGenFunction::EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { const BinaryOperator *BO = cast<BinaryOperator>(E->getCallee()->IgnoreParens()); const Expr *BaseExpr = BO->getLHS(); @@ -484,24 +489,25 @@ CodeGenFunction::EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E, EmitCallArgs(Args, FPT, E->arguments()); return EmitCall(CGM.getTypes().arrangeCXXMethodCall(Args, FPT, required, /*PrefixSize=*/0), - Callee, ReturnValue, Args, nullptr, E == MustTailCall, + Callee, ReturnValue, Args, CallOrInvoke, E == MustTailCall, E->getExprLoc()); } -RValue -CodeGenFunction::EmitCXXOperatorMemberCallExpr(const CXXOperatorCallExpr *E, - const CXXMethodDecl *MD, - ReturnValueSlot ReturnValue) { +RValue CodeGenFunction::EmitCXXOperatorMemberCallExpr( + const CXXOperatorCallExpr *E, const CXXMethodDecl *MD, + ReturnValueSlot ReturnValue, llvm::CallBase **CallOrInvoke) { assert(MD->isImplicitObjectMemberFunction() && "Trying to emit a member call expr on a static method!"); return EmitCXXMemberOrOperatorMemberCallExpr( E, MD, ReturnValue, /*HasQualifier=*/false, /*Qualifier=*/nullptr, - /*IsArrow=*/false, E->getArg(0)); + /*IsArrow=*/false, E->getArg(0), CallOrInvoke); } RValue CodeGenFunction::EmitCUDAKernelCallExpr(const CUDAKernelCallExpr *E, - ReturnValueSlot ReturnValue) { - return 
CGM.getCUDARuntime().EmitCUDAKernelCallExpr(*this, E, ReturnValue); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { + return CGM.getCUDARuntime().EmitCUDAKernelCallExpr(*this, E, ReturnValue, + CallOrInvoke); } static void EmitNullBaseClassInitialization(CodeGenFunction &CGF, diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp index eff8c9f5694084..b58ec2d7e1a554 100644 --- a/clang/lib/CodeGen/CodeGenFunction.cpp +++ b/clang/lib/CodeGen/CodeGenFunction.cpp @@ -523,6 +523,12 @@ void CodeGenFunction::FinishFunction(SourceLocation EndLoc) { NormalCleanupDest = Address::invalid(); } + if (getLangOpts().Coroutines && isCoroutine()) { + auto *Record = FnRetTy->getAsCXXRecordDecl(); + if (Record && Record->hasAttr<CoroAwaitElidableAttr>()) + CurFn->addFnAttr(llvm::Attribute::CoroGenNoallocRamp); + } + // Scan function arguments for vector width. for (llvm::Argument &A : CurFn->args()) if (auto *VT = dyn_cast<llvm::VectorType>(A.getType())) diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index 57e0b7f91e9bf8..1941a35f273037 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -3150,7 +3150,8 @@ class CodeGenFunction : public CodeGenTypeCache { bool ForVirtualBase, bool Delegating, Address This, CallArgList &Args, AggValueSlot::Overlap_t Overlap, - SourceLocation Loc, bool NewPointerIsChecked); + SourceLocation Loc, bool NewPointerIsChecked, + llvm::CallBase **CallOrInvoke = nullptr); /// Emit assumption load for all bases. Requires to be called only on /// most-derived class and not under construction of the object. 
@@ -4270,7 +4271,8 @@ class CodeGenFunction : public CodeGenTypeCache { LValue EmitBinaryOperatorLValue(const BinaryOperator *E); LValue EmitCompoundAssignmentLValue(const CompoundAssignOperator *E); // Note: only available for agg return types - LValue EmitCallExprLValue(const CallExpr *E); + LValue EmitCallExprLValue(const CallExpr *E, + llvm::CallBase **CallOrInvoke = nullptr); // Note: only available for agg return types LValue EmitVAArgExprLValue(const VAArgExpr *E); LValue EmitDeclRefLValue(const DeclRefExpr *E); @@ -4380,21 +4382,27 @@ class CodeGenFunction : public CodeGenTypeCache { /// LLVM arguments and the types they were derived from. RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee, ReturnValueSlot ReturnValue, const CallArgList &Args, - llvm::CallBase **callOrInvoke, bool IsMustTail, + llvm::CallBase **CallOrInvoke, bool IsMustTail, SourceLocation Loc, bool IsVirtualFunctionPointerThunk = false); RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee, ReturnValueSlot ReturnValue, const CallArgList &Args, - llvm::CallBase **callOrInvoke = nullptr, + llvm::CallBase **CallOrInvoke = nullptr, bool IsMustTail = false) { - return EmitCall(CallInfo, Callee, ReturnValue, Args, callOrInvoke, + return EmitCall(CallInfo, Callee, ReturnValue, Args, CallOrInvoke, IsMustTail, SourceLocation()); } RValue EmitCall(QualType FnType, const CGCallee &Callee, const CallExpr *E, - ReturnValueSlot ReturnValue, llvm::Value *Chain = nullptr); + ReturnValueSlot ReturnValue, llvm::Value *Chain = nullptr, + llvm::CallBase **CallOrInvoke = nullptr); + + // If a Call or Invoke instruction was emitted for this CallExpr, this method + // writes the pointer to `CallOrInvoke` if it's not null. 
RValue EmitCallExpr(const CallExpr *E, - ReturnValueSlot ReturnValue = ReturnValueSlot()); - RValue EmitSimpleCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue); + ReturnValueSlot ReturnValue = ReturnValueSlot(), + llvm::CallBase **CallOrInvoke = nullptr); + RValue EmitSimpleCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke = nullptr); CGCallee EmitCallee(const Expr *E); void checkTargetFeatures(const CallExpr *E, const FunctionDecl *TargetDecl); @@ -4498,25 +4506,23 @@ class CodeGenFunction : public CodeGenTypeCache { void callCStructCopyAssignmentOperator(LValue Dst, LValue Src); void callCStructMoveAssignmentOperator(LValue Dst, LValue Src); - RValue - EmitCXXMemberOrOperatorCall(const CXXMethodDecl *Method, - const CGCallee &Callee, - ReturnValueSlot ReturnValue, llvm::Value *This, - llvm::Value *ImplicitParam, - QualType ImplicitParamTy, const CallExpr *E, - CallArgList *RtlArgs); + RValue EmitCXXMemberOrOperatorCall( + const CXXMethodDecl *Method, const CGCallee &Callee, + ReturnValueSlot ReturnValue, llvm::Value *This, + llvm::Value *ImplicitParam, QualType ImplicitParamTy, const CallExpr *E, + CallArgList *RtlArgs, llvm::CallBase **CallOrInvoke); RValue EmitCXXDestructorCall(GlobalDecl Dtor, const CGCallee &Callee, llvm::Value *This, QualType ThisTy, llvm::Value *ImplicitParam, - QualType ImplicitParamTy, const CallExpr *E); + QualType ImplicitParamTy, const CallExpr *E, + llvm::CallBase **CallOrInvoke = nullptr); RValue EmitCXXMemberCallExpr(const CXXMemberCallExpr *E, - ReturnValueSlot ReturnValue); - RValue EmitCXXMemberOrOperatorMemberCallExpr(const CallExpr *CE, - const CXXMethodDecl *MD, - ReturnValueSlot ReturnValue, - bool HasQualifier, - NestedNameSpecifier *Qualifier, - bool IsArrow, const Expr *Base); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke = nullptr); + RValue EmitCXXMemberOrOperatorMemberCallExpr( + const CallExpr *CE, const CXXMethodDecl *MD, ReturnValueSlot ReturnValue, + 
bool HasQualifier, NestedNameSpecifier *Qualifier, bool IsArrow, + const Expr *Base, llvm::CallBase **CallOrInvoke); // Compute the object pointer. Address EmitCXXMemberDataPointerAddress(const Expr *E, Address base, llvm::Value *memberPtr, @@ -4524,15 +4530,18 @@ class CodeGenFunction : public CodeGenTypeCache { LValueBaseInfo *BaseInfo = nullptr, TBAAAccessInfo *TBAAInfo = nullptr); RValue EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E, - ReturnValueSlot ReturnValue); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke); RValue EmitCXXOperatorMemberCallExpr(const CXXOperatorCallExpr *E, const CXXMethodDecl *MD, - ReturnValueSlot ReturnValue); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke); RValue EmitCXXPseudoDestructorExpr(const CXXPseudoDestructorExpr *E); RValue EmitCUDAKernelCallExpr(const CUDAKernelCallExpr *E, - ReturnValueSlot ReturnValue); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke); RValue EmitNVPTXDevicePrintfCallExpr(const CallExpr *E); RValue EmitAMDGPUDevicePrintfCallExpr(const CallExpr *E); @@ -4554,7 +4563,8 @@ class CodeGenFunction : public CodeGenTypeCache { const analyze_os_log::OSLogBufferLayout &Layout, CharUnits BufferAlignment); - RValue EmitBlockCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue); + RValue EmitBlockCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke); /// EmitTargetBuiltinExpr - Emit the given builtin call. Returns 0 if the call /// is unhandled by the current target. 
diff --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp index 0cde8a192eda08..28b0880d013e5b 100644 --- a/clang/lib/CodeGen/ItaniumCXXABI.cpp +++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp @@ -315,10 +315,11 @@ class ItaniumCXXABI : public CodeGen::CGCXXABI { Address This, llvm::Type *Ty, SourceLocation Loc) override; - llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF, - const CXXDestructorDecl *Dtor, - CXXDtorType DtorType, Address This, - DeleteOrMemberCallExpr E) override; + llvm::Value * + EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, + CXXDtorType DtorType, Address This, + DeleteOrMemberCallExpr E, + llvm::CallBase **CallOrInvoke) override; void emitVirtualInheritanceTables(const CXXRecordDecl *RD) override; @@ -1396,7 +1397,8 @@ void ItaniumCXXABI::emitVirtualObjectDelete(CodeGenFunction &CGF, // FIXME: Provide a source location here even though there's no // CXXMemberCallExpr for dtor call. CXXDtorType DtorType = UseGlobalDelete ? 
Dtor_Complete : Dtor_Deleting; - EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE); + EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE, + /*CallOrInvoke=*/nullptr); if (UseGlobalDelete) CGF.PopCleanupBlock(); @@ -2233,7 +2235,7 @@ CGCallee ItaniumCXXABI::getVirtualFunctionPointer(CodeGenFunction &CGF, llvm::Value *ItaniumCXXABI::EmitVirtualDestructorCall( CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, CXXDtorType DtorType, - Address This, DeleteOrMemberCallExpr E) { + Address This, DeleteOrMemberCallExpr E, llvm::CallBase **CallOrInvoke) { auto *CE = E.dyn_cast<const CXXMemberCallExpr *>(); auto *D = E.dyn_cast<const CXXDeleteExpr *>(); assert((CE != nullptr) ^ (D != nullptr)); @@ -2254,7 +2256,7 @@ llvm::Value *ItaniumCXXABI::EmitVirtualDestructorCall( } CGF.EmitCXXDestructorCall(GD, Callee, This.emitRawPointer(CGF), ThisTy, - nullptr, QualType(), nullptr); + nullptr, QualType(), nullptr, CallOrInvoke); return nullptr; } diff --git a/clang/lib/CodeGen/MicrosoftCXXABI.cpp b/clang/lib/CodeGen/MicrosoftCXXABI.cpp index cc6740edabcd3c..24ae0ece4d7006 100644 --- a/clang/lib/CodeGen/MicrosoftCXXABI.cpp +++ b/clang/lib/CodeGen/MicrosoftCXXABI.cpp @@ -334,10 +334,11 @@ class MicrosoftCXXABI : public CGCXXABI { Address This, llvm::Type *Ty, SourceLocation Loc) override; - llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF, - const CXXDestructorDecl *Dtor, - CXXDtorType DtorType, Address This, - DeleteOrMemberCallExpr E) override; + llvm::Value * + EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, + CXXDtorType DtorType, Address This, + DeleteOrMemberCallExpr E, + llvm::CallBase **CallOrInvoke) override; void adjustCallArgsForDestructorThunk(CodeGenFunction &CGF, GlobalDecl GD, CallArgList &CallArgs) override { @@ -901,7 +902,8 @@ void MicrosoftCXXABI::emitVirtualObjectDelete(CodeGenFunction &CGF, // CXXMemberCallExpr for dtor call. bool UseGlobalDelete = DE->isGlobalDelete(); CXXDtorType DtorType = UseGlobalDelete ? 
Dtor_Complete : Dtor_Deleting; - llvm::Value *MDThis = EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE); + llvm::Value *MDThis = EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE, + /*CallOrInvoke=*/nullptr); if (UseGlobalDelete) CGF.EmitDeleteCall(DE->getOperatorDelete(), MDThis, ElementType); } @@ -1685,7 +1687,7 @@ void MicrosoftCXXABI::EmitDestructorCall(CodeGenFunction &CGF, CGF.EmitCXXDestructorCall(GD, Callee, CGF.getAsNaturalPointerTo(This, ThisTy), ThisTy, /*ImplicitParam=*/Implicit, - /*ImplicitParamTy=*/QualType(), nullptr); + /*ImplicitParamTy=*/QualType(), /*E=*/nullptr); if (BaseDtorEndBB) { // Complete object handler should continue to be the remaining CGF.Builder.CreateBr(BaseDtorEndBB); @@ -2001,7 +2003,7 @@ CGCallee MicrosoftCXXABI::getVirtualFunctionPointer(CodeGenFunction &CGF, llvm::Value *MicrosoftCXXABI::EmitVirtualDestructorCall( CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, CXXDtorType DtorType, - Address This, DeleteOrMemberCallExpr E) { + Address This, DeleteOrMemberCallExpr E, llvm::CallBase **CallOrInvoke) { auto *CE = E.dyn_cast<const CXXMemberCallExpr *>(); auto *D = E.dyn_cast<const CXXDeleteExpr *>(); assert((CE != nullptr) ^ (D != nullptr)); @@ -2031,7 +2033,7 @@ llvm::Value *MicrosoftCXXABI::EmitVirtualDestructorCall( This = adjustThisArgumentForVirtualFunctionCall(CGF, GD, This, true); RValue RV = CGF.EmitCXXDestructorCall(GD, Callee, This.emitRawPointer(CGF), ThisTy, - ImplicitParam, Context.IntTy, CE); + ImplicitParam, Context.IntTy, CE, CallOrInvoke); return RV.getScalarVal(); } diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp index 1bb8955f6f8792..ea8dd97cef1173 100644 --- a/clang/lib/Sema/SemaCoroutine.cpp +++ b/clang/lib/Sema/SemaCoroutine.cpp @@ -844,6 +844,19 @@ ExprResult Sema::BuildOperatorCoawaitLookupExpr(Scope *S, SourceLocation Loc) { return CoawaitOp; } +static bool isAttributedCoroInplaceTask(const QualType &QT) { + auto *Record = QT->getAsCXXRecordDecl(); + 
return Record && Record->hasAttr<CoroAwaitElidableAttr>(); +} + +static bool isCoroInplaceCall(Expr *Operand) { + if (!Operand->isPRValue()) { + return false; + } + + return isAttributedCoroInplaceTask(Operand->getType()); +} + // Attempts to resolve and build a CoawaitExpr from "raw" inputs, bailing out to // DependentCoawaitExpr if needed. ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand, @@ -867,7 +880,16 @@ ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand, } auto *RD = Promise->getType()->getAsCXXRecordDecl(); - auto *Transformed = Operand; + bool InplaceCall = + isCoroInplaceCall(Operand) && + isAttributedCoroInplaceTask( + getCurFunctionDecl(/*AllowLambda=*/true)->getReturnType()); + + if (InplaceCall) + if (auto *Call = dyn_cast<CallExpr>(Operand->IgnoreImplicit())) + Call->setCoroElideSafe(); + + Expr *Transformed = Operand; if (lookupMember(*this, "await_transform", RD, Loc)) { ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", Operand); diff --git a/clang/test/CodeGenCoroutines/Inputs/utility.h b/clang/test/CodeGenCoroutines/Inputs/utility.h new file mode 100644 index 00000000000000..43c6d27823bd47 --- /dev/null +++ b/clang/test/CodeGenCoroutines/Inputs/utility.h @@ -0,0 +1,13 @@ +// This is a mock file for <utility> + +namespace std { + +template <typename T> struct remove_reference { using type = T; }; +template <typename T> struct remove_reference<T &> { using type = T; }; +template <typename T> struct remove_reference<T &&> { using type = T; }; + +template <typename T> +constexpr typename std::remove_reference<T>::type&& move(T &&t) noexcept { + return static_cast<typename std::remove_reference<T>::type &&>(t); +} +} diff --git a/clang/test/CodeGenCoroutines/coro-await-elidable.cpp b/clang/test/CodeGenCoroutines/coro-await-elidable.cpp new file mode 100644 index 00000000000000..e81d8f8ab361c7 --- /dev/null +++ b/clang/test/CodeGenCoroutines/coro-await-elidable.cpp @@ -0,0 
+1,88 @@ +// This file tests the coro_await_elidable attribute semantics. +// RUN: %clang_cc1 -triple=x86_64-unknown-linux-gnu -std=c++20 -disable-llvm-passes -emit-llvm %s -o - | FileCheck %s + +#include "Inputs/coroutine.h" +#include "Inputs/utility.h" + +template <typename T> +struct [[clang::coro_await_elidable]] Task { + struct promise_type { + struct FinalAwaiter { + bool await_ready() const noexcept { return false; } + + template <typename P> + std::coroutine_handle<> await_suspend(std::coroutine_handle<P> coro) noexcept { + if (!coro) + return std::noop_coroutine(); + return coro.promise().continuation; + } + void await_resume() noexcept {} + }; + + Task get_return_object() noexcept { + return std::coroutine_handle<promise_type>::from_promise(*this); + } + + std::suspend_always initial_suspend() noexcept { return {}; } + FinalAwaiter final_suspend() noexcept { return {}; } + void unhandled_exception() noexcept {} + void return_value(T x) noexcept { + value = x; + } + + std::coroutine_handle<> continuation; + T value; + }; + + Task(std::coroutine_handle<promise_type> handle) : handle(handle) {} + ~Task() { + if (handle) + handle.destroy(); + } + + struct Awaiter { + Awaiter(Task *t) : task(t) {} + bool await_ready() const noexcept { return false; } + void await_suspend(std::coroutine_handle<void> continuation) noexcept {} + T await_resume() noexcept { + return task->handle.promise().value; + } + + Task *task; + }; + + auto operator co_await() { + return Awaiter{this}; + } + +private: + std::coroutine_handle<promise_type> handle; +}; + +// CHECK-LABEL: define{{.*}} @_Z6calleev{{.*}} #0 { +Task<int> callee() { + co_return 1; +} + +// CHECK-LABEL: define{{.*}} @_Z8elidablev{{.*}} #0 { +Task<int> elidable() { + // CHECK: %[[TASK_OBJ:.+]] = alloca %struct.Task + // CHECK: call void @_Z6calleev(ptr dead_on_unwind writable sret(%struct.Task) align 8 %[[TASK_OBJ]]) #[[ELIDE_SAFE:.+]] + co_return co_await callee(); +} + +// CHECK-LABEL: define{{.*}} 
@_Z11nonelidablev{{.*}} #0 { +Task<int> nonelidable() { + // CHECK: %[[TASK_OBJ:.+]] = alloca %struct.Task + auto t = callee(); + // Because we aren't co_awaiting a prvalue, we cannot elide here. + // CHECK: call void @_Z6calleev(ptr dead_on_unwind writable sret(%struct.Task) align 8 %[[TASK_OBJ]]) + // CHECK-NOT: #[[ELIDE_SAFE]] + co_await t; + co_await std::move(t); + + co_return 1; +} + +// CHECK: attributes #0 = { coro_gen_noalloc_ramp {{.*}} } +// CHECK: attributes #[[ELIDE_SAFE]] = { coro_elide_safe } diff --git a/clang/test/Misc/pragma-attribute-supported-attributes-list.test b/clang/test/Misc/pragma-attribute-supported-attributes-list.test index 1a71556213bb16..5430924664fac0 100644 --- a/clang/test/Misc/pragma-attribute-supported-attributes-list.test +++ b/clang/test/Misc/pragma-attribute-supported-attributes-list.test @@ -58,6 +58,7 @@ // CHECK-NEXT: ConsumableAutoCast (SubjectMatchRule_record) // CHECK-NEXT: ConsumableSetOnRead (SubjectMatchRule_record) // CHECK-NEXT: Convergent (SubjectMatchRule_function) +// CHECK-NEXT: CoroAwaitElidable (SubjectMatchRule_record) // CHECK-NEXT: CoroDisableLifetimeBound (SubjectMatchRule_function) // CHECK-NEXT: CoroLifetimeBound (SubjectMatchRule_record) // CHECK-NEXT: CoroOnlyDestroyWhenComplete (SubjectMatchRule_record) diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h index 4beac37a583445..327c6606bc2b79 100644 --- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h +++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h @@ -759,6 +759,8 @@ enum AttributeKindCodes { ATTR_KIND_INITIALIZES = 94, ATTR_KIND_HYBRID_PATCHABLE = 95, ATTR_KIND_SANITIZE_REALTIME = 96, + ATTR_KIND_CORO_ELIDE_SAFE = 97, + ATTR_KIND_CORO_GEN_NOALLOC_RAMP = 98, }; enum ComdatSelectionKindCodes { diff --git a/llvm/include/llvm/IR/Attributes.td b/llvm/include/llvm/IR/Attributes.td index 891e34fec0c798..d04c26231173c7 100644 --- a/llvm/include/llvm/IR/Attributes.td +++ b/llvm/include/llvm/IR/Attributes.td @@ -345,6 
+345,15 @@ def PresplitCoroutine : EnumAttr<"presplitcoroutine", [FnAttr]>; /// The coroutine would only be destroyed when it is complete. def CoroDestroyOnlyWhenComplete : EnumAttr<"coro_only_destroy_when_complete", [FnAttr]>; +/// The coroutine call meets the elide requirement. Hint the optimization +/// pipeline to perform elision on the call or invoke instruction. +def CoroElideSafe : EnumAttr<"coro_elide_safe", [FnAttr]>; + +/// Generate a .noalloc ramp function for coroutine. The frontend emits this +/// attribute only on coroutines that return a "coro_await_elidable" type. +/// CoroSplit reads this attribute and conditionally builds the noalloc ramp. +def CoroGenNoallocRamp : EnumAttr<"coro_gen_noalloc_ramp", [FnAttr]>; + /// Target-independent string attributes. def LessPreciseFPMAD : StrBoolAttr<"less-precise-fpmad">; def NoInfsFPMath : StrBoolAttr<"no-infs-fp-math">; diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp index d4dbab04e8ecdb..e30f25027b1f39 100644 --- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp +++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp @@ -2187,6 +2187,10 @@ static Attribute::AttrKind getAttrFromCode(uint64_t Code) { return Attribute::Range; case bitc::ATTR_KIND_INITIALIZES: return Attribute::Initializes; + case bitc::ATTR_KIND_CORO_ELIDE_SAFE: + return Attribute::CoroElideSafe; + case bitc::ATTR_KIND_CORO_GEN_NOALLOC_RAMP: + return Attribute::CoroGenNoallocRamp; } } diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp index 03d0537291dada..b2cdf480b4824d 100644 --- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp +++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp @@ -881,6 +881,10 @@ static uint64_t getAttrKindEncoding(Attribute::AttrKind Kind) { return bitc::ATTR_KIND_WRITABLE; case Attribute::CoroDestroyOnlyWhenComplete: return bitc::ATTR_KIND_CORO_ONLY_DESTROY_WHEN_COMPLETE; + case Attribute::CoroElideSafe: + return bitc::ATTR_KIND_CORO_ELIDE_SAFE; + case Attribute::CoroGenNoallocRamp: + return bitc::ATTR_KIND_CORO_GEN_NOALLOC_RAMP;
case Attribute::DeadOnUnwind: return bitc::ATTR_KIND_DEAD_ON_UNWIND; case Attribute::Range: diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp index 81d3243c887fce..ff328135a41315 100644 --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -913,6 +913,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::Memory: case Attribute::NoFPClass: case Attribute::CoroDestroyOnlyWhenComplete: + case Attribute::CoroElideSafe: + case Attribute::CoroGenNoallocRamp: continue; // Those attributes should be safe to propagate to the extracted function. case Attribute::AlwaysInline: _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits