https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/94693
>From 4e4bcdc19268543a8348736dede46d8f8cad0066 Mon Sep 17 00:00:00 2001 From: Yuxuan Chen <yuxuanchen1...@outlook.com> Date: Tue, 4 Jun 2024 23:22:00 -0700 Subject: [PATCH 1/3] [Clang] Introduce [[clang::coro_inplace_task]] --- clang/include/clang/AST/ExprCXX.h | 26 ++++-- clang/include/clang/Basic/Attr.td | 8 ++ clang/include/clang/Basic/AttrDocs.td | 19 +++++ clang/lib/CodeGen/CGBlocks.cpp | 5 +- clang/lib/CodeGen/CGCUDARuntime.cpp | 5 +- clang/lib/CodeGen/CGCUDARuntime.h | 8 +- clang/lib/CodeGen/CGCXXABI.h | 10 +-- clang/lib/CodeGen/CGClass.cpp | 16 ++-- clang/lib/CodeGen/CGCoroutine.cpp | 30 +++++-- clang/lib/CodeGen/CGExpr.cpp | 41 +++++---- clang/lib/CodeGen/CGExprCXX.cpp | 60 +++++++------ clang/lib/CodeGen/CodeGenFunction.h | 64 ++++++++------ clang/lib/CodeGen/ItaniumCXXABI.cpp | 16 ++-- clang/lib/CodeGen/MicrosoftCXXABI.cpp | 18 ++-- clang/lib/Sema/SemaCoroutine.cpp | 58 ++++++++++++- clang/lib/Serialization/ASTReaderStmt.cpp | 10 ++- clang/lib/Serialization/ASTWriterStmt.cpp | 3 +- clang/test/CodeGenCoroutines/Inputs/utility.h | 13 +++ .../coro-structured-concurrency.cpp | 84 +++++++++++++++++++ ...a-attribute-supported-attributes-list.test | 1 + llvm/include/llvm/IR/Intrinsics.td | 3 + .../lib/Transforms/Coroutines/CoroCleanup.cpp | 9 +- llvm/lib/Transforms/Coroutines/CoroElide.cpp | 58 ++++++++++++- llvm/lib/Transforms/Coroutines/Coroutines.cpp | 1 + .../coro-elide-structured-concurrency.ll | 64 ++++++++++++++ 25 files changed, 496 insertions(+), 134 deletions(-) create mode 100644 clang/test/CodeGenCoroutines/Inputs/utility.h create mode 100644 clang/test/CodeGenCoroutines/coro-structured-concurrency.cpp create mode 100644 llvm/test/Transforms/Coroutines/coro-elide-structured-concurrency.ll diff --git a/clang/include/clang/AST/ExprCXX.h b/clang/include/clang/AST/ExprCXX.h index c2feac525c1e..0cf62aee41b6 100644 --- a/clang/include/clang/AST/ExprCXX.h +++ b/clang/include/clang/AST/ExprCXX.h @@ -5082,7 +5082,8 @@ class CoroutineSuspendExpr : public Expr 
{ enum SubExpr { Operand, Common, Ready, Suspend, Resume, Count }; Stmt *SubExprs[SubExpr::Count]; - OpaqueValueExpr *OpaqueValue = nullptr; + OpaqueValueExpr *CommonExprOpaqueValue = nullptr; + OpaqueValueExpr *InplaceCallOpaqueValue = nullptr; public: // These types correspond to the three C++ 'await_suspend' return variants @@ -5090,10 +5091,10 @@ class CoroutineSuspendExpr : public Expr { CoroutineSuspendExpr(StmtClass SC, SourceLocation KeywordLoc, Expr *Operand, Expr *Common, Expr *Ready, Expr *Suspend, Expr *Resume, - OpaqueValueExpr *OpaqueValue) + OpaqueValueExpr *CommonExprOpaqueValue) : Expr(SC, Resume->getType(), Resume->getValueKind(), Resume->getObjectKind()), - KeywordLoc(KeywordLoc), OpaqueValue(OpaqueValue) { + KeywordLoc(KeywordLoc), CommonExprOpaqueValue(CommonExprOpaqueValue) { SubExprs[SubExpr::Operand] = Operand; SubExprs[SubExpr::Common] = Common; SubExprs[SubExpr::Ready] = Ready; @@ -5128,7 +5129,16 @@ class CoroutineSuspendExpr : public Expr { } /// getOpaqueValue - Return the opaque value placeholder. 
- OpaqueValueExpr *getOpaqueValue() const { return OpaqueValue; } + OpaqueValueExpr *getCommonExprOpaqueValue() const { + return CommonExprOpaqueValue; + } + + OpaqueValueExpr *getInplaceCallOpaqueValue() const { + return InplaceCallOpaqueValue; + } + void setInplaceCallOpaqueValue(OpaqueValueExpr *E) { + InplaceCallOpaqueValue = E; + } Expr *getReadyExpr() const { return static_cast<Expr*>(SubExprs[SubExpr::Ready]); @@ -5194,9 +5204,9 @@ class CoawaitExpr : public CoroutineSuspendExpr { public: CoawaitExpr(SourceLocation CoawaitLoc, Expr *Operand, Expr *Common, Expr *Ready, Expr *Suspend, Expr *Resume, - OpaqueValueExpr *OpaqueValue, bool IsImplicit = false) + OpaqueValueExpr *CommonExprOpaqueValue, bool IsImplicit = false) : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Operand, Common, - Ready, Suspend, Resume, OpaqueValue) { + Ready, Suspend, Resume, CommonExprOpaqueValue) { CoawaitBits.IsImplicit = IsImplicit; } @@ -5275,9 +5285,9 @@ class CoyieldExpr : public CoroutineSuspendExpr { public: CoyieldExpr(SourceLocation CoyieldLoc, Expr *Operand, Expr *Common, Expr *Ready, Expr *Suspend, Expr *Resume, - OpaqueValueExpr *OpaqueValue) + OpaqueValueExpr *CommonExprOpaqueValue) : CoroutineSuspendExpr(CoyieldExprClass, CoyieldLoc, Operand, Common, - Ready, Suspend, Resume, OpaqueValue) {} + Ready, Suspend, Resume, CommonExprOpaqueValue) {} CoyieldExpr(SourceLocation CoyieldLoc, QualType Ty, Expr *Operand, Expr *Common) : CoroutineSuspendExpr(CoyieldExprClass, CoyieldLoc, Ty, Operand, diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index d2d9dd24536c..2f02a1f9e3a0 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -1217,6 +1217,14 @@ def CoroDisableLifetimeBound : InheritableAttr { let SimpleHandler = 1; } +def CoroInplaceTask : InheritableAttr { + let Spellings = [Clang<"coro_inplace_task">]; + let Subjects = SubjectList<[CXXRecord]>; + let LangOpts = [CPlusPlus]; + let Documentation = 
[CoroInplaceTaskDoc]; + let SimpleHandler = 1; +} + // OSObject-based attributes. def OSConsumed : InheritableParamAttr { let Spellings = [Clang<"os_consumed">]; diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td index ab4bd003541f..f8b3d1b0f19c 100644 --- a/clang/include/clang/Basic/AttrDocs.td +++ b/clang/include/clang/Basic/AttrDocs.td @@ -8013,6 +8013,25 @@ but do not pass them to the underlying coroutine or pass them by value. }]; } +def CoroInplaceTaskDoc : Documentation { + let Category = DocCatDecl; + let Content = [{ +The ``[[clang::coro_inplace_task]]`` is a class attribute which can be applied +to a coroutine return type. + +When a coroutine function that returns such a type calls another coroutine function, +the compiler performs heap allocation elision when the following conditions are all met: +- The callee coroutine function returns a type that is annotated with ``[[clang::coro_inplace_task]]``. +- The callee coroutine function is inlined. +- In the caller coroutine, the return value of the callee is a prvalue or an xvalue, and +- The temporary expression containing the callee coroutine object is immediately co_awaited. + +The behavior is undefined if any of the following conditions is met: +- The caller coroutine is destroyed earlier than the callee coroutine. 
+ + }]; +} + def CountedByDocs : Documentation { let Category = DocCatField; let Content = [{ diff --git a/clang/lib/CodeGen/CGBlocks.cpp b/clang/lib/CodeGen/CGBlocks.cpp index 066139b1c78c..684fda744073 100644 --- a/clang/lib/CodeGen/CGBlocks.cpp +++ b/clang/lib/CodeGen/CGBlocks.cpp @@ -1163,7 +1163,8 @@ llvm::Type *CodeGenModule::getGenericBlockLiteralType() { } RValue CodeGenFunction::EmitBlockCallExpr(const CallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { const auto *BPT = E->getCallee()->getType()->castAs<BlockPointerType>(); llvm::Value *BlockPtr = EmitScalarExpr(E->getCallee()); llvm::Type *GenBlockTy = CGM.getGenericBlockLiteralType(); @@ -1220,7 +1221,7 @@ RValue CodeGenFunction::EmitBlockCallExpr(const CallExpr *E, CGCallee Callee(CGCalleeInfo(), Func); // And call the block. - return EmitCall(FnInfo, Callee, ReturnValue, Args); + return EmitCall(FnInfo, Callee, ReturnValue, Args, CallOrInvoke); } Address CodeGenFunction::GetAddrOfBlockDecl(const VarDecl *variable) { diff --git a/clang/lib/CodeGen/CGCUDARuntime.cpp b/clang/lib/CodeGen/CGCUDARuntime.cpp index c14a9d3f2bbb..1e1da1e2411a 100644 --- a/clang/lib/CodeGen/CGCUDARuntime.cpp +++ b/clang/lib/CodeGen/CGCUDARuntime.cpp @@ -25,7 +25,8 @@ CGCUDARuntime::~CGCUDARuntime() {} RValue CGCUDARuntime::EmitCUDAKernelCallExpr(CodeGenFunction &CGF, const CUDAKernelCallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { llvm::BasicBlock *ConfigOKBlock = CGF.createBasicBlock("kcall.configok"); llvm::BasicBlock *ContBlock = CGF.createBasicBlock("kcall.end"); @@ -35,7 +36,7 @@ RValue CGCUDARuntime::EmitCUDAKernelCallExpr(CodeGenFunction &CGF, eval.begin(CGF); CGF.EmitBlock(ConfigOKBlock); - CGF.EmitSimpleCallExpr(E, ReturnValue); + CGF.EmitSimpleCallExpr(E, ReturnValue, CallOrInvoke); CGF.EmitBranch(ContBlock); CGF.EmitBlock(ContBlock); diff --git a/clang/lib/CodeGen/CGCUDARuntime.h 
b/clang/lib/CodeGen/CGCUDARuntime.h index 8030d632cc3d..86f776004ee7 100644 --- a/clang/lib/CodeGen/CGCUDARuntime.h +++ b/clang/lib/CodeGen/CGCUDARuntime.h @@ -21,6 +21,7 @@ #include "llvm/IR/GlobalValue.h" namespace llvm { +class CallBase; class Function; class GlobalVariable; } @@ -82,9 +83,10 @@ class CGCUDARuntime { CGCUDARuntime(CodeGenModule &CGM) : CGM(CGM) {} virtual ~CGCUDARuntime(); - virtual RValue EmitCUDAKernelCallExpr(CodeGenFunction &CGF, - const CUDAKernelCallExpr *E, - ReturnValueSlot ReturnValue); + virtual RValue + EmitCUDAKernelCallExpr(CodeGenFunction &CGF, const CUDAKernelCallExpr *E, + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke = nullptr); /// Emits a kernel launch stub. virtual void emitDeviceStub(CodeGenFunction &CGF, FunctionArgList &Args) = 0; diff --git a/clang/lib/CodeGen/CGCXXABI.h b/clang/lib/CodeGen/CGCXXABI.h index 7dcc53911199..687ff7fb8444 100644 --- a/clang/lib/CodeGen/CGCXXABI.h +++ b/clang/lib/CodeGen/CGCXXABI.h @@ -485,11 +485,11 @@ class CGCXXABI { llvm::PointerUnion<const CXXDeleteExpr *, const CXXMemberCallExpr *>; /// Emit the ABI-specific virtual destructor call. 
- virtual llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF, - const CXXDestructorDecl *Dtor, - CXXDtorType DtorType, - Address This, - DeleteOrMemberCallExpr E) = 0; + virtual llvm::Value * + EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, + CXXDtorType DtorType, Address This, + DeleteOrMemberCallExpr E, + llvm::CallBase **CallOrInvoke) = 0; virtual void adjustCallArgsForDestructorThunk(CodeGenFunction &CGF, GlobalDecl GD, diff --git a/clang/lib/CodeGen/CGClass.cpp b/clang/lib/CodeGen/CGClass.cpp index 0a595bb998d2..c56716fbd059 100644 --- a/clang/lib/CodeGen/CGClass.cpp +++ b/clang/lib/CodeGen/CGClass.cpp @@ -2191,15 +2191,11 @@ static bool canEmitDelegateCallArgs(CodeGenFunction &CGF, return true; } -void CodeGenFunction::EmitCXXConstructorCall(const CXXConstructorDecl *D, - CXXCtorType Type, - bool ForVirtualBase, - bool Delegating, - Address This, - CallArgList &Args, - AggValueSlot::Overlap_t Overlap, - SourceLocation Loc, - bool NewPointerIsChecked) { +void CodeGenFunction::EmitCXXConstructorCall( + const CXXConstructorDecl *D, CXXCtorType Type, bool ForVirtualBase, + bool Delegating, Address This, CallArgList &Args, + AggValueSlot::Overlap_t Overlap, SourceLocation Loc, + bool NewPointerIsChecked, llvm::CallBase **CallOrInvoke) { const CXXRecordDecl *ClassDecl = D->getParent(); if (!NewPointerIsChecked) @@ -2247,7 +2243,7 @@ void CodeGenFunction::EmitCXXConstructorCall(const CXXConstructorDecl *D, const CGFunctionInfo &Info = CGM.getTypes().arrangeCXXConstructorCall( Args, D, Type, ExtraArgs.Prefix, ExtraArgs.Suffix, PassPrototypeArgs); CGCallee Callee = CGCallee::forDirect(CalleePtr, GlobalDecl(D, Type)); - EmitCall(Info, Callee, ReturnValueSlot(), Args, nullptr, false, Loc); + EmitCall(Info, Callee, ReturnValueSlot(), Args, CallOrInvoke, false, Loc); // Generate vtable assumptions if we're constructing a complete object // with a vtable. 
We don't do this for base subobjects for two reasons: diff --git a/clang/lib/CodeGen/CGCoroutine.cpp b/clang/lib/CodeGen/CGCoroutine.cpp index a8a70186c2c5..656c1e905317 100644 --- a/clang/lib/CodeGen/CGCoroutine.cpp +++ b/clang/lib/CodeGen/CGCoroutine.cpp @@ -12,9 +12,11 @@ #include "CGCleanup.h" #include "CodeGenFunction.h" -#include "llvm/ADT/ScopeExit.h" +#include "clang/AST/ExprCXX.h" #include "clang/AST/StmtCXX.h" #include "clang/AST/StmtVisitor.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/IR/Intrinsics.h" using namespace clang; using namespace CodeGen; @@ -223,12 +225,22 @@ static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Co CoroutineSuspendExpr const &S, AwaitKind Kind, AggValueSlot aggSlot, bool ignoreResult, bool forLValue) { - auto *E = S.getCommonExpr(); + auto &Builder = CGF.Builder; - auto CommonBinder = - CodeGenFunction::OpaqueValueMappingData::bind(CGF, S.getOpaqueValue(), E); - auto UnbindCommonOnExit = - llvm::make_scope_exit([&] { CommonBinder.unbind(CGF); }); + // If S.getInplaceCallOpaqueValue() is null, we don't have a nested opaque + // value for the common expression. + std::optional<CodeGenFunction::OpaqueValueMapping> OperandMapping; + if (auto *CallOV = S.getInplaceCallOpaqueValue()) { + auto *CE = cast<CallExpr>(CallOV->getSourceExpr()); + // TODO: don't use the intrinsic coro_safe_elide in the next version. 
+ LValue CallResult = CGF.EmitCallExprLValue(CE, nullptr); + OperandMapping.emplace(CGF, CallOV, CallResult); + llvm::Value *Value = CallResult.getPointer(CGF); + auto SafeElide = CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_safe_elide); + Builder.CreateCall(SafeElide, Value); + } + CodeGenFunction::OpaqueValueMapping BindCommon(CGF, + S.getCommonExprOpaqueValue()); auto Prefix = buildSuspendPrefixStr(Coro, Kind); BasicBlock *ReadyBlock = CGF.createBasicBlock(Prefix + Twine(".ready")); @@ -241,7 +253,6 @@ static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Co // Otherwise, emit suspend logic. CGF.EmitBlock(SuspendBlock); - auto &Builder = CGF.Builder; llvm::Function *CoroSave = CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_save); auto *NullPtr = llvm::ConstantPointerNull::get(CGF.CGM.Int8PtrTy); auto *SaveCall = Builder.CreateCall(CoroSave, {NullPtr}); @@ -256,7 +267,8 @@ static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Co SmallVector<llvm::Value *, 3> SuspendIntrinsicCallArgs; SuspendIntrinsicCallArgs.push_back( - CGF.getOrCreateOpaqueLValueMapping(S.getOpaqueValue()).getPointer(CGF)); + CGF.getOrCreateOpaqueLValueMapping(S.getCommonExprOpaqueValue()) + .getPointer(CGF)); SuspendIntrinsicCallArgs.push_back(CGF.CurCoro.Data->CoroBegin); SuspendIntrinsicCallArgs.push_back(SuspendWrapper); @@ -455,7 +467,7 @@ CodeGenFunction::generateAwaitSuspendWrapper(Twine const &CoroName, Builder.CreateLoad(GetAddrOfLocalVar(&FrameDecl)); auto AwaiterBinder = CodeGenFunction::OpaqueValueMappingData::bind( - *this, S.getOpaqueValue(), AwaiterLValue); + *this, S.getCommonExprOpaqueValue(), AwaiterLValue); auto *SuspendRet = EmitScalarExpr(S.getSuspendExpr()); diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index 039f60c77459..a4ff896fca6a 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -5429,16 +5429,17 @@ RValue CodeGenFunction::EmitRValueForField(LValue LV, 
//===--------------------------------------------------------------------===// RValue CodeGenFunction::EmitCallExpr(const CallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { // Builtins never have block type. if (E->getCallee()->getType()->isBlockPointerType()) - return EmitBlockCallExpr(E, ReturnValue); + return EmitBlockCallExpr(E, ReturnValue, CallOrInvoke); if (const auto *CE = dyn_cast<CXXMemberCallExpr>(E)) - return EmitCXXMemberCallExpr(CE, ReturnValue); + return EmitCXXMemberCallExpr(CE, ReturnValue, CallOrInvoke); if (const auto *CE = dyn_cast<CUDAKernelCallExpr>(E)) - return EmitCUDAKernelCallExpr(CE, ReturnValue); + return EmitCUDAKernelCallExpr(CE, ReturnValue, CallOrInvoke); // A CXXOperatorCallExpr is created even for explicit object methods, but // these should be treated like static function call. @@ -5446,7 +5447,7 @@ RValue CodeGenFunction::EmitCallExpr(const CallExpr *E, if (const auto *MD = dyn_cast_if_present<CXXMethodDecl>(CE->getCalleeDecl()); MD && MD->isImplicitObjectMemberFunction()) - return EmitCXXOperatorMemberCallExpr(CE, MD, ReturnValue); + return EmitCXXOperatorMemberCallExpr(CE, MD, ReturnValue, CallOrInvoke); CGCallee callee = EmitCallee(E->getCallee()); @@ -5459,14 +5460,17 @@ RValue CodeGenFunction::EmitCallExpr(const CallExpr *E, return EmitCXXPseudoDestructorExpr(callee.getPseudoDestructorExpr()); } - return EmitCall(E->getCallee()->getType(), callee, E, ReturnValue); + return EmitCall(E->getCallee()->getType(), callee, E, ReturnValue, + /*Chain=*/nullptr, CallOrInvoke); } /// Emit a CallExpr without considering whether it might be a subclass. 
RValue CodeGenFunction::EmitSimpleCallExpr(const CallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { CGCallee Callee = EmitCallee(E->getCallee()); - return EmitCall(E->getCallee()->getType(), Callee, E, ReturnValue); + return EmitCall(E->getCallee()->getType(), Callee, E, ReturnValue, + /*Chain=*/nullptr, CallOrInvoke); } // Detect the unusual situation where an inline version is shadowed by a @@ -5670,8 +5674,9 @@ LValue CodeGenFunction::EmitBinaryOperatorLValue(const BinaryOperator *E) { llvm_unreachable("bad evaluation kind"); } -LValue CodeGenFunction::EmitCallExprLValue(const CallExpr *E) { - RValue RV = EmitCallExpr(E); +LValue CodeGenFunction::EmitCallExprLValue(const CallExpr *E, + llvm::CallBase **CallOrInvoke) { + RValue RV = EmitCallExpr(E, ReturnValueSlot(), CallOrInvoke); if (!RV.isScalar()) return MakeAddrLValue(RV.getAggregateAddress(), E->getType(), @@ -5794,9 +5799,11 @@ LValue CodeGenFunction::EmitStmtExprLValue(const StmtExpr *E) { AlignmentSource::Decl); } -RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee, - const CallExpr *E, ReturnValueSlot ReturnValue, - llvm::Value *Chain) { +RValue CodeGenFunction::EmitCall(QualType CalleeType, + const CGCallee &OrigCallee, const CallExpr *E, + ReturnValueSlot ReturnValue, + llvm::Value *Chain, + llvm::CallBase **CallOrInvoke) { // Get the actual function type. The callee type will always be a pointer to // function type or a block pointer type. 
assert(CalleeType->isFunctionPointerType() && @@ -6007,8 +6014,8 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee Address(Handle, Handle->getType(), CGM.getPointerAlign())); Callee.setFunctionPointer(Stub); } - llvm::CallBase *CallOrInvoke = nullptr; - RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &CallOrInvoke, + llvm::CallBase *LocalCallOrInvoke = nullptr; + RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &LocalCallOrInvoke, E == MustTailCall, E->getExprLoc()); // Generate function declaration DISuprogram in order to be used @@ -6017,11 +6024,13 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee if (auto *CalleeDecl = dyn_cast_or_null<FunctionDecl>(TargetDecl)) { FunctionArgList Args; QualType ResTy = BuildFunctionArgList(CalleeDecl, Args); - DI->EmitFuncDeclForCallSite(CallOrInvoke, + DI->EmitFuncDeclForCallSite(LocalCallOrInvoke, DI->getFunctionType(CalleeDecl, ResTy, Args), CalleeDecl); } } + if (CallOrInvoke) + *CallOrInvoke = LocalCallOrInvoke; return Call; } diff --git a/clang/lib/CodeGen/CGExprCXX.cpp b/clang/lib/CodeGen/CGExprCXX.cpp index 8eb6ab7381ac..1214bb054fb8 100644 --- a/clang/lib/CodeGen/CGExprCXX.cpp +++ b/clang/lib/CodeGen/CGExprCXX.cpp @@ -84,23 +84,24 @@ commonEmitCXXMemberOrOperatorCall(CodeGenFunction &CGF, GlobalDecl GD, RValue CodeGenFunction::EmitCXXMemberOrOperatorCall( const CXXMethodDecl *MD, const CGCallee &Callee, - ReturnValueSlot ReturnValue, - llvm::Value *This, llvm::Value *ImplicitParam, QualType ImplicitParamTy, - const CallExpr *CE, CallArgList *RtlArgs) { + ReturnValueSlot ReturnValue, llvm::Value *This, llvm::Value *ImplicitParam, + QualType ImplicitParamTy, const CallExpr *CE, CallArgList *RtlArgs, + llvm::CallBase **CallOrInvoke) { const FunctionProtoType *FPT = MD->getType()->castAs<FunctionProtoType>(); CallArgList Args; MemberCallInfo CallInfo = commonEmitCXXMemberOrOperatorCall( *this, MD, This, ImplicitParam, ImplicitParamTy, 
CE, Args, RtlArgs); auto &FnInfo = CGM.getTypes().arrangeCXXMethodCall( Args, FPT, CallInfo.ReqArgs, CallInfo.PrefixSize); - return EmitCall(FnInfo, Callee, ReturnValue, Args, nullptr, + return EmitCall(FnInfo, Callee, ReturnValue, Args, CallOrInvoke, CE && CE == MustTailCall, CE ? CE->getExprLoc() : SourceLocation()); } RValue CodeGenFunction::EmitCXXDestructorCall( GlobalDecl Dtor, const CGCallee &Callee, llvm::Value *This, QualType ThisTy, - llvm::Value *ImplicitParam, QualType ImplicitParamTy, const CallExpr *CE) { + llvm::Value *ImplicitParam, QualType ImplicitParamTy, const CallExpr *CE, + llvm::CallBase **CallOrInvoke) { const CXXMethodDecl *DtorDecl = cast<CXXMethodDecl>(Dtor.getDecl()); assert(!ThisTy.isNull()); @@ -120,7 +121,8 @@ RValue CodeGenFunction::EmitCXXDestructorCall( commonEmitCXXMemberOrOperatorCall(*this, Dtor, This, ImplicitParam, ImplicitParamTy, CE, Args, nullptr); return EmitCall(CGM.getTypes().arrangeCXXStructorDeclaration(Dtor), Callee, - ReturnValueSlot(), Args, nullptr, CE && CE == MustTailCall, + ReturnValueSlot(), Args, CallOrInvoke, + CE && CE == MustTailCall, CE ? CE->getExprLoc() : SourceLocation{}); } @@ -186,11 +188,12 @@ static CXXRecordDecl *getCXXRecord(const Expr *E) { // Note: This function also emit constructor calls to support a MSVC // extensions allowing explicit constructor function call. 
RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { const Expr *callee = CE->getCallee()->IgnoreParens(); if (isa<BinaryOperator>(callee)) - return EmitCXXMemberPointerCallExpr(CE, ReturnValue); + return EmitCXXMemberPointerCallExpr(CE, ReturnValue, CallOrInvoke); const MemberExpr *ME = cast<MemberExpr>(callee); const CXXMethodDecl *MD = cast<CXXMethodDecl>(ME->getMemberDecl()); @@ -200,7 +203,7 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE, CGCallee callee = CGCallee::forDirect(CGM.GetAddrOfFunction(MD), GlobalDecl(MD)); return EmitCall(getContext().getPointerType(MD->getType()), callee, CE, - ReturnValue); + ReturnValue, /*Chain=*/nullptr, CallOrInvoke); } bool HasQualifier = ME->hasQualifier(); @@ -208,14 +211,15 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE, bool IsArrow = ME->isArrow(); const Expr *Base = ME->getBase(); - return EmitCXXMemberOrOperatorMemberCallExpr( - CE, MD, ReturnValue, HasQualifier, Qualifier, IsArrow, Base); + return EmitCXXMemberOrOperatorMemberCallExpr(CE, MD, ReturnValue, + HasQualifier, Qualifier, IsArrow, + Base, CallOrInvoke); } RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( const CallExpr *CE, const CXXMethodDecl *MD, ReturnValueSlot ReturnValue, bool HasQualifier, NestedNameSpecifier *Qualifier, bool IsArrow, - const Expr *Base) { + const Expr *Base, llvm::CallBase **CallOrInvoke) { assert(isa<CXXMemberCallExpr>(CE) || isa<CXXOperatorCallExpr>(CE)); // Compute the object pointer. 
@@ -300,7 +304,7 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( EmitCXXConstructorCall(Ctor, Ctor_Complete, /*ForVirtualBase=*/false, /*Delegating=*/false, This.getAddress(), Args, AggValueSlot::DoesNotOverlap, CE->getExprLoc(), - /*NewPointerIsChecked=*/false); + /*NewPointerIsChecked=*/false, CallOrInvoke); return RValue::get(nullptr); } @@ -374,9 +378,9 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( "Destructor shouldn't have explicit parameters"); assert(ReturnValue.isNull() && "Destructor shouldn't have return value"); if (UseVirtualCall) { - CGM.getCXXABI().EmitVirtualDestructorCall(*this, Dtor, Dtor_Complete, - This.getAddress(), - cast<CXXMemberCallExpr>(CE)); + CGM.getCXXABI().EmitVirtualDestructorCall( + *this, Dtor, Dtor_Complete, This.getAddress(), + cast<CXXMemberCallExpr>(CE), CallOrInvoke); } else { GlobalDecl GD(Dtor, Dtor_Complete); CGCallee Callee; @@ -393,7 +397,7 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( IsArrow ? 
Base->getType()->getPointeeType() : Base->getType(); EmitCXXDestructorCall(GD, Callee, This.getPointer(*this), ThisTy, /*ImplicitParam=*/nullptr, - /*ImplicitParamTy=*/QualType(), CE); + /*ImplicitParamTy=*/QualType(), CE, CallOrInvoke); } return RValue::get(nullptr); } @@ -435,12 +439,13 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( return EmitCXXMemberOrOperatorCall( CalleeDecl, Callee, ReturnValue, This.getPointer(*this), - /*ImplicitParam=*/nullptr, QualType(), CE, RtlArgs); + /*ImplicitParam=*/nullptr, QualType(), CE, RtlArgs, CallOrInvoke); } RValue CodeGenFunction::EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { const BinaryOperator *BO = cast<BinaryOperator>(E->getCallee()->IgnoreParens()); const Expr *BaseExpr = BO->getLHS(); @@ -484,24 +489,25 @@ CodeGenFunction::EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E, EmitCallArgs(Args, FPT, E->arguments()); return EmitCall(CGM.getTypes().arrangeCXXMethodCall(Args, FPT, required, /*PrefixSize=*/0), - Callee, ReturnValue, Args, nullptr, E == MustTailCall, + Callee, ReturnValue, Args, CallOrInvoke, E == MustTailCall, E->getExprLoc()); } -RValue -CodeGenFunction::EmitCXXOperatorMemberCallExpr(const CXXOperatorCallExpr *E, - const CXXMethodDecl *MD, - ReturnValueSlot ReturnValue) { +RValue CodeGenFunction::EmitCXXOperatorMemberCallExpr( + const CXXOperatorCallExpr *E, const CXXMethodDecl *MD, + ReturnValueSlot ReturnValue, llvm::CallBase **CallOrInvoke) { assert(MD->isImplicitObjectMemberFunction() && "Trying to emit a member call expr on a static method!"); return EmitCXXMemberOrOperatorMemberCallExpr( E, MD, ReturnValue, /*HasQualifier=*/false, /*Qualifier=*/nullptr, - /*IsArrow=*/false, E->getArg(0)); + /*IsArrow=*/false, E->getArg(0), CallOrInvoke); } RValue CodeGenFunction::EmitCUDAKernelCallExpr(const CUDAKernelCallExpr *E, - ReturnValueSlot ReturnValue) { - return 
CGM.getCUDARuntime().EmitCUDAKernelCallExpr(*this, E, ReturnValue); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { + return CGM.getCUDARuntime().EmitCUDAKernelCallExpr(*this, E, ReturnValue, + CallOrInvoke); } static void EmitNullBaseClassInitialization(CodeGenFunction &CGF, diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index 13f12b5d878a..883504c10927 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -3142,7 +3142,8 @@ class CodeGenFunction : public CodeGenTypeCache { bool ForVirtualBase, bool Delegating, Address This, CallArgList &Args, AggValueSlot::Overlap_t Overlap, - SourceLocation Loc, bool NewPointerIsChecked); + SourceLocation Loc, bool NewPointerIsChecked, + llvm::CallBase **CallOrInvoke = nullptr); /// Emit assumption load for all bases. Requires to be called only on /// most-derived class and not under construction of the object. @@ -4256,7 +4257,8 @@ class CodeGenFunction : public CodeGenTypeCache { LValue EmitBinaryOperatorLValue(const BinaryOperator *E); LValue EmitCompoundAssignmentLValue(const CompoundAssignOperator *E); // Note: only available for agg return types - LValue EmitCallExprLValue(const CallExpr *E); + LValue EmitCallExprLValue(const CallExpr *E, + llvm::CallBase **CallOrInvoke = nullptr); // Note: only available for agg return types LValue EmitVAArgExprLValue(const VAArgExpr *E); LValue EmitDeclRefLValue(const DeclRefExpr *E); @@ -4366,20 +4368,26 @@ class CodeGenFunction : public CodeGenTypeCache { /// LLVM arguments and the types they were derived from. 
RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee, ReturnValueSlot ReturnValue, const CallArgList &Args, - llvm::CallBase **callOrInvoke, bool IsMustTail, + llvm::CallBase **CallOrInvoke, bool IsMustTail, SourceLocation Loc); RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee, ReturnValueSlot ReturnValue, const CallArgList &Args, - llvm::CallBase **callOrInvoke = nullptr, + llvm::CallBase **CallOrInvoke = nullptr, bool IsMustTail = false) { - return EmitCall(CallInfo, Callee, ReturnValue, Args, callOrInvoke, + return EmitCall(CallInfo, Callee, ReturnValue, Args, CallOrInvoke, IsMustTail, SourceLocation()); } RValue EmitCall(QualType FnType, const CGCallee &Callee, const CallExpr *E, - ReturnValueSlot ReturnValue, llvm::Value *Chain = nullptr); + ReturnValueSlot ReturnValue, llvm::Value *Chain = nullptr, + llvm::CallBase **CallOrInvoke = nullptr); + + // If a Call or Invoke instruction was emitted for this CallExpr, this method + // writes the pointer to `CallOrInvoke` if it's not null. 
RValue EmitCallExpr(const CallExpr *E, - ReturnValueSlot ReturnValue = ReturnValueSlot()); - RValue EmitSimpleCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue); + ReturnValueSlot ReturnValue = ReturnValueSlot(), + llvm::CallBase **CallOrInvoke = nullptr); + RValue EmitSimpleCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke = nullptr); CGCallee EmitCallee(const Expr *E); void checkTargetFeatures(const CallExpr *E, const FunctionDecl *TargetDecl); @@ -4467,25 +4475,23 @@ class CodeGenFunction : public CodeGenTypeCache { void callCStructCopyAssignmentOperator(LValue Dst, LValue Src); void callCStructMoveAssignmentOperator(LValue Dst, LValue Src); - RValue - EmitCXXMemberOrOperatorCall(const CXXMethodDecl *Method, - const CGCallee &Callee, - ReturnValueSlot ReturnValue, llvm::Value *This, - llvm::Value *ImplicitParam, - QualType ImplicitParamTy, const CallExpr *E, - CallArgList *RtlArgs); + RValue EmitCXXMemberOrOperatorCall( + const CXXMethodDecl *Method, const CGCallee &Callee, + ReturnValueSlot ReturnValue, llvm::Value *This, + llvm::Value *ImplicitParam, QualType ImplicitParamTy, const CallExpr *E, + CallArgList *RtlArgs, llvm::CallBase **CallOrInvoke); RValue EmitCXXDestructorCall(GlobalDecl Dtor, const CGCallee &Callee, llvm::Value *This, QualType ThisTy, llvm::Value *ImplicitParam, - QualType ImplicitParamTy, const CallExpr *E); + QualType ImplicitParamTy, const CallExpr *E, + llvm::CallBase **CallOrInvoke = nullptr); RValue EmitCXXMemberCallExpr(const CXXMemberCallExpr *E, - ReturnValueSlot ReturnValue); - RValue EmitCXXMemberOrOperatorMemberCallExpr(const CallExpr *CE, - const CXXMethodDecl *MD, - ReturnValueSlot ReturnValue, - bool HasQualifier, - NestedNameSpecifier *Qualifier, - bool IsArrow, const Expr *Base); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke = nullptr); + RValue EmitCXXMemberOrOperatorMemberCallExpr( + const CallExpr *CE, const CXXMethodDecl *MD, ReturnValueSlot ReturnValue, + 
bool HasQualifier, NestedNameSpecifier *Qualifier, bool IsArrow, + const Expr *Base, llvm::CallBase **CallOrInvoke); // Compute the object pointer. Address EmitCXXMemberDataPointerAddress(const Expr *E, Address base, llvm::Value *memberPtr, @@ -4493,15 +4499,18 @@ class CodeGenFunction : public CodeGenTypeCache { LValueBaseInfo *BaseInfo = nullptr, TBAAAccessInfo *TBAAInfo = nullptr); RValue EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E, - ReturnValueSlot ReturnValue); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke); RValue EmitCXXOperatorMemberCallExpr(const CXXOperatorCallExpr *E, const CXXMethodDecl *MD, - ReturnValueSlot ReturnValue); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke); RValue EmitCXXPseudoDestructorExpr(const CXXPseudoDestructorExpr *E); RValue EmitCUDAKernelCallExpr(const CUDAKernelCallExpr *E, - ReturnValueSlot ReturnValue); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke); RValue EmitNVPTXDevicePrintfCallExpr(const CallExpr *E); RValue EmitAMDGPUDevicePrintfCallExpr(const CallExpr *E); @@ -4524,7 +4533,8 @@ class CodeGenFunction : public CodeGenTypeCache { const analyze_os_log::OSLogBufferLayout &Layout, CharUnits BufferAlignment); - RValue EmitBlockCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue); + RValue EmitBlockCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke); /// EmitTargetBuiltinExpr - Emit the given builtin call. Returns 0 if the call /// is unhandled by the current target. 
diff --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp index e1d056765a86..8c2ad7e03cea 100644 --- a/clang/lib/CodeGen/ItaniumCXXABI.cpp +++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp @@ -315,10 +315,11 @@ class ItaniumCXXABI : public CodeGen::CGCXXABI { Address This, llvm::Type *Ty, SourceLocation Loc) override; - llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF, - const CXXDestructorDecl *Dtor, - CXXDtorType DtorType, Address This, - DeleteOrMemberCallExpr E) override; + llvm::Value * + EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, + CXXDtorType DtorType, Address This, + DeleteOrMemberCallExpr E, + llvm::CallBase **CallOrInvoke) override; void emitVirtualInheritanceTables(const CXXRecordDecl *RD) override; @@ -1257,7 +1258,8 @@ void ItaniumCXXABI::emitVirtualObjectDelete(CodeGenFunction &CGF, // FIXME: Provide a source location here even though there's no // CXXMemberCallExpr for dtor call. CXXDtorType DtorType = UseGlobalDelete ? 
Dtor_Complete : Dtor_Deleting; - EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE); + EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE, + /*CallOrInvoke=*/nullptr); if (UseGlobalDelete) CGF.PopCleanupBlock(); @@ -2086,7 +2088,7 @@ CGCallee ItaniumCXXABI::getVirtualFunctionPointer(CodeGenFunction &CGF, llvm::Value *ItaniumCXXABI::EmitVirtualDestructorCall( CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, CXXDtorType DtorType, - Address This, DeleteOrMemberCallExpr E) { + Address This, DeleteOrMemberCallExpr E, llvm::CallBase **CallOrInvoke) { auto *CE = E.dyn_cast<const CXXMemberCallExpr *>(); auto *D = E.dyn_cast<const CXXDeleteExpr *>(); assert((CE != nullptr) ^ (D != nullptr)); @@ -2107,7 +2109,7 @@ llvm::Value *ItaniumCXXABI::EmitVirtualDestructorCall( } CGF.EmitCXXDestructorCall(GD, Callee, This.emitRawPointer(CGF), ThisTy, - nullptr, QualType(), nullptr); + nullptr, QualType(), nullptr, CallOrInvoke); return nullptr; } diff --git a/clang/lib/CodeGen/MicrosoftCXXABI.cpp b/clang/lib/CodeGen/MicrosoftCXXABI.cpp index cc6740edabcd..24ae0ece4d70 100644 --- a/clang/lib/CodeGen/MicrosoftCXXABI.cpp +++ b/clang/lib/CodeGen/MicrosoftCXXABI.cpp @@ -334,10 +334,11 @@ class MicrosoftCXXABI : public CGCXXABI { Address This, llvm::Type *Ty, SourceLocation Loc) override; - llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF, - const CXXDestructorDecl *Dtor, - CXXDtorType DtorType, Address This, - DeleteOrMemberCallExpr E) override; + llvm::Value * + EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, + CXXDtorType DtorType, Address This, + DeleteOrMemberCallExpr E, + llvm::CallBase **CallOrInvoke) override; void adjustCallArgsForDestructorThunk(CodeGenFunction &CGF, GlobalDecl GD, CallArgList &CallArgs) override { @@ -901,7 +902,8 @@ void MicrosoftCXXABI::emitVirtualObjectDelete(CodeGenFunction &CGF, // CXXMemberCallExpr for dtor call. bool UseGlobalDelete = DE->isGlobalDelete(); CXXDtorType DtorType = UseGlobalDelete ? 
Dtor_Complete : Dtor_Deleting; - llvm::Value *MDThis = EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE); + llvm::Value *MDThis = EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE, + /*CallOrInvoke=*/nullptr); if (UseGlobalDelete) CGF.EmitDeleteCall(DE->getOperatorDelete(), MDThis, ElementType); } @@ -1685,7 +1687,7 @@ void MicrosoftCXXABI::EmitDestructorCall(CodeGenFunction &CGF, CGF.EmitCXXDestructorCall(GD, Callee, CGF.getAsNaturalPointerTo(This, ThisTy), ThisTy, /*ImplicitParam=*/Implicit, - /*ImplicitParamTy=*/QualType(), nullptr); + /*ImplicitParamTy=*/QualType(), /*E=*/nullptr); if (BaseDtorEndBB) { // Complete object handler should continue to be the remaining CGF.Builder.CreateBr(BaseDtorEndBB); @@ -2001,7 +2003,7 @@ CGCallee MicrosoftCXXABI::getVirtualFunctionPointer(CodeGenFunction &CGF, llvm::Value *MicrosoftCXXABI::EmitVirtualDestructorCall( CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, CXXDtorType DtorType, - Address This, DeleteOrMemberCallExpr E) { + Address This, DeleteOrMemberCallExpr E, llvm::CallBase **CallOrInvoke) { auto *CE = E.dyn_cast<const CXXMemberCallExpr *>(); auto *D = E.dyn_cast<const CXXDeleteExpr *>(); assert((CE != nullptr) ^ (D != nullptr)); @@ -2031,7 +2033,7 @@ llvm::Value *MicrosoftCXXABI::EmitVirtualDestructorCall( This = adjustThisArgumentForVirtualFunctionCall(CGF, GD, This, true); RValue RV = CGF.EmitCXXDestructorCall(GD, Callee, This.emitRawPointer(CGF), ThisTy, - ImplicitParam, Context.IntTy, CE); + ImplicitParam, Context.IntTy, CE, CallOrInvoke); return RV.getScalarVal(); } diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp index 81334c817b2a..726138577620 100644 --- a/clang/lib/Sema/SemaCoroutine.cpp +++ b/clang/lib/Sema/SemaCoroutine.cpp @@ -15,6 +15,7 @@ #include "CoroutineStmtBuilder.h" #include "clang/AST/ASTLambda.h" +#include "clang/AST/ComputeDependence.h" #include "clang/AST/Decl.h" #include "clang/AST/Expr.h" #include "clang/AST/ExprCXX.h" @@ -825,6 +826,32 @@ 
ExprResult Sema::BuildOperatorCoawaitLookupExpr(Scope *S, SourceLocation Loc) { return CoawaitOp; } +static bool isAttributedCoroInplaceTask(const QualType &QT) { + auto *Record = QT->getAsCXXRecordDecl(); + return Record && Record->hasAttr<CoroInplaceTaskAttr>(); +} + +static bool isCoroInplaceCall(Expr *Operand) { + if (!Operand->isPRValue()) { + return false; + } + + return isAttributedCoroInplaceTask(Operand->getType()); +} + +template <typename DesiredExpr> +DesiredExpr *getExprWrappedByTemporary(Expr *E) { + if (auto *BTE = dyn_cast<CXXBindTemporaryExpr>(E)) { + E = BTE->getSubExpr(); + } + + if (auto *S = dyn_cast<DesiredExpr>(E)) { + return S; + } + + return nullptr; +} + // Attempts to resolve and build a CoawaitExpr from "raw" inputs, bailing out to // DependentCoawaitExpr if needed. ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand, @@ -848,10 +875,31 @@ ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand, } auto *RD = Promise->getType()->getAsCXXRecordDecl(); + bool InplaceCall = + isCoroInplaceCall(Operand) && + isAttributedCoroInplaceTask( + getCurFunctionDecl(/*AllowLambda=*/true)->getReturnType()); + + OpaqueValueExpr *OpaqueCallExpr = nullptr; + auto *Transformed = Operand; + + if (InplaceCall) { + if (auto *Temporary = dyn_cast<CXXBindTemporaryExpr>(Operand)) { + auto *SubExpr = Temporary->getSubExpr(); + if (CallExpr *Call = dyn_cast<CallExpr>(SubExpr)) { + OpaqueCallExpr = new (Context) + OpaqueValueExpr(Call->getRParenLoc(), Call->getType(), + Call->getValueKind(), Call->getObjectKind(), Call); + Transformed = CXXBindTemporaryExpr::Create( + Context, Temporary->getTemporary(), OpaqueCallExpr); + } + } + } + if (lookupMember(*this, "await_transform", RD, Loc)) { ExprResult R = - buildPromiseCall(*this, Promise, Loc, "await_transform", Operand); + buildPromiseCall(*this, Promise, Loc, "await_transform", Transformed); if (R.isInvalid()) { Diag(Loc, 
diag::note_coroutine_promise_implicit_await_transform_required_here) @@ -864,7 +912,13 @@ ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand, if (Awaiter.isInvalid()) return ExprError(); - return BuildResolvedCoawaitExpr(Loc, Operand, Awaiter.get()); + auto Res = BuildResolvedCoawaitExpr(Loc, Operand, Awaiter.get()); + if (!Res.isInvalid() && InplaceCall) { + // BuildResolvedCoawaitExpr must return a CoawaitExpr, if valid. + CoawaitExpr *CE = Res.getAs<CoawaitExpr>(); + CE->setInplaceCallOpaqueValue(OpaqueCallExpr); + } + return Res; } ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *Operand, diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp index e23ceffb10bf..4986360651a1 100644 --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -483,7 +483,10 @@ void ASTStmtReader::VisitCoawaitExpr(CoawaitExpr *E) { E->KeywordLoc = readSourceLocation(); for (auto &SubExpr: E->SubExprs) SubExpr = Record.readSubStmt(); - E->OpaqueValue = cast_or_null<OpaqueValueExpr>(Record.readSubStmt()); + E->CommonExprOpaqueValue = + cast_or_null<OpaqueValueExpr>(Record.readSubStmt()); + E->InplaceCallOpaqueValue = + cast_or_null<OpaqueValueExpr>(Record.readSubStmt()); E->setIsImplicit(Record.readInt() != 0); } @@ -492,7 +495,10 @@ void ASTStmtReader::VisitCoyieldExpr(CoyieldExpr *E) { E->KeywordLoc = readSourceLocation(); for (auto &SubExpr: E->SubExprs) SubExpr = Record.readSubStmt(); - E->OpaqueValue = cast_or_null<OpaqueValueExpr>(Record.readSubStmt()); + E->CommonExprOpaqueValue = + cast_or_null<OpaqueValueExpr>(Record.readSubStmt()); + E->InplaceCallOpaqueValue = + cast_or_null<OpaqueValueExpr>(Record.readSubStmt()); } void ASTStmtReader::VisitDependentCoawaitExpr(DependentCoawaitExpr *E) { diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp index ea499019c9d1..57084125a851 100644 --- 
a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -444,7 +444,8 @@ void ASTStmtWriter::VisitCoroutineSuspendExpr(CoroutineSuspendExpr *E) { Record.AddSourceLocation(E->getKeywordLoc()); for (Stmt *S : E->children()) Record.AddStmt(S); - Record.AddStmt(E->getOpaqueValue()); + Record.AddStmt(E->getCommonExprOpaqueValue()); + Record.AddStmt(E->getInplaceCallOpaqueValue()); } void ASTStmtWriter::VisitCoawaitExpr(CoawaitExpr *E) { diff --git a/clang/test/CodeGenCoroutines/Inputs/utility.h b/clang/test/CodeGenCoroutines/Inputs/utility.h new file mode 100644 index 000000000000..43c6d27823bd --- /dev/null +++ b/clang/test/CodeGenCoroutines/Inputs/utility.h @@ -0,0 +1,13 @@ +// This is a mock file for <utility> + +namespace std { + +template <typename T> struct remove_reference { using type = T; }; +template <typename T> struct remove_reference<T &> { using type = T; }; +template <typename T> struct remove_reference<T &&> { using type = T; }; + +template <typename T> +constexpr typename std::remove_reference<T>::type&& move(T &&t) noexcept { + return static_cast<typename std::remove_reference<T>::type &&>(t); +} +} diff --git a/clang/test/CodeGenCoroutines/coro-structured-concurrency.cpp b/clang/test/CodeGenCoroutines/coro-structured-concurrency.cpp new file mode 100644 index 000000000000..2569643221da --- /dev/null +++ b/clang/test/CodeGenCoroutines/coro-structured-concurrency.cpp @@ -0,0 +1,84 @@ +// This file tests the coro_structured_concurrency attribute semantics. 
+// RUN: %clang_cc1 -std=c++20 -disable-llvm-passes -emit-llvm %s -o - | FileCheck %s + +#include "Inputs/coroutine.h" +#include "Inputs/utility.h" + +template <typename T> +struct [[clang::coro_inplace_task]] Task { + struct promise_type { + struct FinalAwaiter { + bool await_ready() const noexcept { return false; } + + template <typename P> + std::coroutine_handle<> await_suspend(std::coroutine_handle<P> coro) noexcept { + if (!coro) + return std::noop_coroutine(); + return coro.promise().continuation; + } + void await_resume() noexcept {} + }; + + Task get_return_object() noexcept { + return std::coroutine_handle<promise_type>::from_promise(*this); + } + + std::suspend_always initial_suspend() noexcept { return {}; } + FinalAwaiter final_suspend() noexcept { return {}; } + void unhandled_exception() noexcept {} + void return_value(T x) noexcept { + value = x; + } + + std::coroutine_handle<> continuation; + T value; + }; + + Task(std::coroutine_handle<promise_type> handle) : handle(handle) {} + ~Task() { + if (handle) + handle.destroy(); + } + + struct Awaiter { + Awaiter(Task *t) : task(t) {} + bool await_ready() const noexcept { return false; } + void await_suspend(std::coroutine_handle<void> continuation) noexcept {} + T await_resume() noexcept { + return task->handle.promise().value; + } + + Task *task; + }; + + auto operator co_await() { + return Awaiter{this}; + } + +private: + std::coroutine_handle<promise_type> handle; +}; + +// CHECK-LABEL: define{{.*}} @_Z6calleev +Task<int> callee() { + co_return 1; +} + +// CHECK-LABEL: define{{.*}} @_Z8elidablev +Task<int> elidable() { + // CHECK: %[[TARK_OBJ:.+]] = alloca %struct.Task + // CHECK: call void @llvm.coro.safe.elide(ptr %[[TARK_OBJ:.+]]) + co_return co_await callee(); +} + +// CHECK-LABEL: define{{.*}} @_Z11nonelidablev +Task<int> nonelidable() { + // CHECK: %[[TARK_OBJ:.+]] = alloca %struct.Task + auto t = callee(); + // Because we aren't co_awaiting a prvalue, we cannot elide here. 
+ // CHECK-NOT: call void @llvm.coro.safe.elide(ptr %[[TARK_OBJ:.+]]) + co_await t; + co_await std::move(t); + + co_return 1; +} diff --git a/clang/test/Misc/pragma-attribute-supported-attributes-list.test b/clang/test/Misc/pragma-attribute-supported-attributes-list.test index 28df04c5e33e..c9c14a29843f 100644 --- a/clang/test/Misc/pragma-attribute-supported-attributes-list.test +++ b/clang/test/Misc/pragma-attribute-supported-attributes-list.test @@ -59,6 +59,7 @@ // CHECK-NEXT: ConsumableSetOnRead (SubjectMatchRule_record) // CHECK-NEXT: Convergent (SubjectMatchRule_function) // CHECK-NEXT: CoroDisableLifetimeBound (SubjectMatchRule_function) +// CHECK-NEXT: CoroInplaceTask (SubjectMatchRule_record) // CHECK-NEXT: CoroLifetimeBound (SubjectMatchRule_record) // CHECK-NEXT: CoroOnlyDestroyWhenComplete (SubjectMatchRule_record) // CHECK-NEXT: CoroReturnType (SubjectMatchRule_record) diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td index 01e379dfcebc..09cdf5e4fd7e 100644 --- a/llvm/include/llvm/IR/Intrinsics.td +++ b/llvm/include/llvm/IR/Intrinsics.td @@ -1758,6 +1758,9 @@ def int_coro_subfn_addr : DefaultAttrsIntrinsic< [IntrReadMem, IntrArgMemOnly, ReadOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>]>; +def int_coro_safe_elide : DefaultAttrsIntrinsic< + [], [llvm_ptr_ty], []>; + ///===-------------------------- Other Intrinsics --------------------------===// // // TODO: We should introduce a new memory kind fo traps (and other side effects diff --git a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp index dd92b3593af9..35455ea87ae1 100644 --- a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp @@ -11,6 +11,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include 
"llvm/Transforms/Scalar/SimplifyCFG.h" @@ -81,7 +82,7 @@ bool Lowerer::lower(Function &F) { } else continue; break; - case Intrinsic::coro_async_size_replace: + case Intrinsic::coro_async_size_replace: { auto *Target = cast<ConstantStruct>( cast<GlobalVariable>(II->getArgOperand(0)->stripPointerCasts()) ->getInitializer()); @@ -99,6 +100,9 @@ bool Lowerer::lower(Function &F) { Target->replaceAllUsesWith(NewFuncPtrStruct); break; } + case Intrinsic::coro_safe_elide: + break; + } II->eraseFromParent(); Changed = true; } @@ -112,7 +116,8 @@ static bool declaresCoroCleanupIntrinsics(const Module &M) { M, {"llvm.coro.alloc", "llvm.coro.begin", "llvm.coro.subfn.addr", "llvm.coro.free", "llvm.coro.id", "llvm.coro.id.retcon", "llvm.coro.id.async", "llvm.coro.id.retcon.once", - "llvm.coro.async.size.replace", "llvm.coro.async.resume"}); + "llvm.coro.async.size.replace", "llvm.coro.async.resume", + "llvm.coro.safe.elide"}); } PreservedAnalyses CoroCleanupPass::run(Module &M, diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp index 598ef7779d77..20e4fb176a19 100644 --- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp @@ -7,12 +7,14 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Coroutines/CoroElide.h" +#include "CoroInstr.h" #include "CoroInternal.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/PostDominators.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/InstIterator.h" #include "llvm/Support/ErrorHandling.h" @@ -56,7 +58,8 @@ class FunctionElideInfo { class CoroIdElider { public: CoroIdElider(CoroIdInst *CoroId, FunctionElideInfo &FEI, AAResults &AA, - DominatorTree &DT, OptimizationRemarkEmitter &ORE); + DominatorTree &DT, 
PostDominatorTree &PDT, + OptimizationRemarkEmitter &ORE); void elideHeapAllocations(uint64_t FrameSize, Align FrameAlign); bool lifetimeEligibleForElide() const; bool attemptElide(); @@ -68,6 +71,7 @@ class CoroIdElider { FunctionElideInfo &FEI; AAResults &AA; DominatorTree &DT; + PostDominatorTree &PDT; OptimizationRemarkEmitter &ORE; SmallVector<CoroBeginInst *, 1> CoroBegins; @@ -183,8 +187,9 @@ void FunctionElideInfo::collectPostSplitCoroIds() { CoroIdElider::CoroIdElider(CoroIdInst *CoroId, FunctionElideInfo &FEI, AAResults &AA, DominatorTree &DT, + PostDominatorTree &PDT, OptimizationRemarkEmitter &ORE) - : CoroId(CoroId), FEI(FEI), AA(AA), DT(DT), ORE(ORE) { + : CoroId(CoroId), FEI(FEI), AA(AA), DT(DT), PDT(PDT), ORE(ORE) { // Collect all coro.begin and coro.allocs associated with this coro.id. for (User *U : CoroId->users()) { if (auto *CB = dyn_cast<CoroBeginInst>(U)) @@ -336,6 +341,41 @@ bool CoroIdElider::canCoroBeginEscape( return false; } +// FIXME: This is not accounting for the stores to tasks whose handle is not +// zero offset. +static const StoreInst *getPostDominatingStoreToTask(const CoroBeginInst *CB, + PostDominatorTree &PDT) { + const StoreInst *OnlyStore = nullptr; + + for (auto *U : CB->users()) { + auto *Store = dyn_cast<StoreInst>(U); + if (Store && Store->getValueOperand() == CB) { + if (OnlyStore) { + // Store must be unique. one coro begin getting stored to multiple + // stores is not accepted. + return nullptr; + } + OnlyStore = Store; + } + } + + if (!OnlyStore || !PDT.dominates(OnlyStore, CB)) { + return nullptr; + } + + return OnlyStore; +} + +static bool isMarkedSafeElide(const llvm::Value *V) { + for (auto *U : V->users()) { + auto *II = dyn_cast<IntrinsicInst>(U); + if (II && (II->getIntrinsicID() == Intrinsic::coro_safe_elide)) { + return true; + } + } + return false; +} + bool CoroIdElider::lifetimeEligibleForElide() const { // If no CoroAllocs, we cannot suppress allocation, so elision is not // possible. 
@@ -364,6 +404,17 @@ bool CoroIdElider::lifetimeEligibleForElide() const { // Filter out the coro.destroy that lie along exceptional paths. for (const auto *CB : CoroBegins) { + // This might be too strong of a condition but should be very safe. + // If the CB is unconditionally stored into a "Task Like Object", + // and such object is "safe elide". + if (FEI.ContainingFunction->isPresplitCoroutine()) { + if (auto *MaybeStoreToTask = getPostDominatingStoreToTask(CB, PDT)) { + auto Dest = MaybeStoreToTask->getPointerOperand(); + if (isMarkedSafeElide(Dest)) + continue; + } + } + auto It = DestroyAddr.find(CB); // FIXME: If we have not found any destroys for this coro.begin, we @@ -476,11 +527,12 @@ PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) { AAResults &AA = AM.getResult<AAManager>(F); DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); + PostDominatorTree &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); bool Changed = false; for (auto *CII : FEI.getCoroIds()) { - CoroIdElider CIE(CII, FEI, AA, DT, ORE); + CoroIdElider CIE(CII, FEI, AA, DT, PDT, ORE); Changed |= CIE.attemptElide(); } diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp index 1a92bc163625..48c02e5406b7 100644 --- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp +++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp @@ -86,6 +86,7 @@ static const char *const CoroIntrinsics[] = { "llvm.coro.prepare.retcon", "llvm.coro.promise", "llvm.coro.resume", + "llvm.coro.safe.elide", "llvm.coro.save", "llvm.coro.size", "llvm.coro.subfn.addr", diff --git a/llvm/test/Transforms/Coroutines/coro-elide-structured-concurrency.ll b/llvm/test/Transforms/Coroutines/coro-elide-structured-concurrency.ll new file mode 100644 index 000000000000..b19886d549d9 --- /dev/null +++ b/llvm/test/Transforms/Coroutines/coro-elide-structured-concurrency.ll @@ -0,0 +1,64 @@ +; 
Testing elide performed its job for calls to coroutines marked safe. +; RUN: opt < %s -S -passes='inline,coro-elide' | FileCheck %s + +%struct.Task = type { ptr } + +declare void @print(i32) nounwind + +; resume part of the coroutine +define fastcc void @callee.resume(ptr dereferenceable(1)) { + tail call void @print(i32 0) + ret void +} + +; destroy part of the coroutine +define fastcc void @callee.destroy(ptr) { + tail call void @print(i32 1) + ret void +} + +; cleanup part of the coroutine +define fastcc void @callee.cleanup(ptr) { + tail call void @print(i32 2) + ret void +} + +@callee.resumers = internal constant [3 x ptr] [ + ptr @callee.resume, ptr @callee.destroy, ptr @callee.cleanup] + +declare void @alloc(i1) nounwind + +; CHECK: define ptr @callee() +define ptr @callee() { +entry: + %task = alloca %struct.Task, align 8 + %id = call token @llvm.coro.id(i32 0, ptr null, + ptr @callee, + ptr @callee.resumers) + %alloc = call i1 @llvm.coro.alloc(token %id) + %hdl = call ptr @llvm.coro.begin(token %id, ptr null) + store ptr %hdl, ptr %task + ret ptr %task +} + +; CHECK: define ptr @caller() +; Function Attrs: presplitcoroutine +define ptr @caller() #0 { +entry: + %task = call ptr @callee() + + ; CHECK: %[[id:.+]] = call token @llvm.coro.id(i32 0, ptr null, ptr @callee, ptr @callee.resumers) + ; CHECK-NOT: call i1 @llvm.coro.alloc(token %[[id]]) + call void @llvm.coro.safe.elide(ptr %task) + + ret ptr %task +} + +attributes #0 = { presplitcoroutine } + +declare token @llvm.coro.id(i32, ptr, ptr, ptr) +declare ptr @llvm.coro.begin(token, ptr) +declare ptr @llvm.coro.frame() +declare ptr @llvm.coro.subfn.addr(ptr, i8) +declare i1 @llvm.coro.alloc(token) +declare void @llvm.coro.safe.elide(ptr) >From 1543f3ed1b3c6b1ad5a00f3e6b6eb251ef7b4d83 Mon Sep 17 00:00:00 2001 From: Yuxuan Chen <y...@meta.com> Date: Tue, 9 Jul 2024 17:12:09 -0700 Subject: [PATCH 2/3] Implement noalloc copy --- clang/lib/CodeGen/CGCoroutine.cpp | 7 +- 
llvm/lib/Transforms/Coroutines/CoroElide.cpp | 12 +- llvm/lib/Transforms/Coroutines/CoroInternal.h | 4 + llvm/lib/Transforms/Coroutines/CoroSplit.cpp | 105 ++++++++++++++---- llvm/lib/Transforms/Coroutines/Coroutines.cpp | 27 +++++ 5 files changed, 119 insertions(+), 36 deletions(-) diff --git a/clang/lib/CodeGen/CGCoroutine.cpp b/clang/lib/CodeGen/CGCoroutine.cpp index 656c1e905317..7ad1470b52f4 100644 --- a/clang/lib/CodeGen/CGCoroutine.cpp +++ b/clang/lib/CodeGen/CGCoroutine.cpp @@ -232,8 +232,11 @@ static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Co std::optional<CodeGenFunction::OpaqueValueMapping> OperandMapping; if (auto *CallOV = S.getInplaceCallOpaqueValue()) { auto *CE = cast<CallExpr>(CallOV->getSourceExpr()); - // TODO: don't use the intrisic coro_safe_elide in the next version. - LValue CallResult = CGF.EmitCallExprLValue(CE, nullptr); + llvm::CallBase *CallOrInvoke = nullptr; + LValue CallResult = CGF.EmitCallExprLValue(CE, &CallOrInvoke); + if (CallOrInvoke) + CallOrInvoke->addAnnotationMetadata("coro_must_elide"); + OperandMapping.emplace(CGF, CallOV, CallResult); llvm::Value *Value = CallResult.getPointer(CGF); auto SafeElide = CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_safe_elide); diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp index 20e4fb176a19..467cdf491e5d 100644 --- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp @@ -225,17 +225,7 @@ void CoroIdElider::elideHeapAllocations(uint64_t FrameSize, Align FrameAlign) { BasicBlock::iterator InsertPt = getFirstNonAllocaInTheEntryBlock(FEI.ContainingFunction)->getIterator(); - // Replacing llvm.coro.alloc with false will suppress dynamic - // allocation as it is expected for the frontend to generate the code that - // looks like: - // id = coro.id(...) - // mem = coro.alloc(id) ? 
malloc(coro.size()) : 0; - // coro.begin(id, mem) - auto *False = ConstantInt::getFalse(C); - for (auto *CA : CoroAllocs) { - CA->replaceAllUsesWith(False); - CA->eraseFromParent(); - } + coro::suppressCoroAllocs(C, CoroAllocs); // FIXME: Design how to transmit alignment information for every alloca that // is spilled into the coroutine frame and recreate the alignment information diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h index 5716fd0ea4ab..d91cccd99a70 100644 --- a/llvm/lib/Transforms/Coroutines/CoroInternal.h +++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h @@ -26,6 +26,10 @@ bool declaresIntrinsics(const Module &M, const std::initializer_list<StringRef>); void replaceCoroFree(CoroIdInst *CoroId, bool Elide); +void suppressCoroAllocs(CoroIdInst *CoroId); +void suppressCoroAllocs(LLVMContext &Context, + ArrayRef<CoroAllocInst *> CoroAllocs); + /// Attempts to rewrite the location operand of debug intrinsics in terms of /// the coroutine frame pointer, folding pointer offsets into the DIExpression /// of the intrinsic. diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp index 0b52d1e4490c..5b994573336f 100644 --- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -25,6 +25,7 @@ #include "llvm/ADT/PriorityWorklist.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/CFG.h" @@ -1179,6 +1180,14 @@ static void updateAsyncFuncPointerContextSize(coro::Shape &Shape) { Shape.AsyncLowering.AsyncFuncPointer->setInitializer(NewFuncPtrStruct); } +static TypeSize getFrameSizeForShape(coro::Shape &Shape) { + // In the same function all coro.sizes should have the same result type. 
+ auto *SizeIntrin = Shape.CoroSizes.back(); + Module *M = SizeIntrin->getModule(); + const DataLayout &DL = M->getDataLayout(); + return DL.getTypeAllocSize(Shape.FrameTy); +} + static void replaceFrameSizeAndAlignment(coro::Shape &Shape) { if (Shape.ABI == coro::ABI::Async) updateAsyncFuncPointerContextSize(Shape); @@ -1194,10 +1203,8 @@ static void replaceFrameSizeAndAlignment(coro::Shape &Shape) { // In the same function all coro.sizes should have the same result type. auto *SizeIntrin = Shape.CoroSizes.back(); - Module *M = SizeIntrin->getModule(); - const DataLayout &DL = M->getDataLayout(); - auto Size = DL.getTypeAllocSize(Shape.FrameTy); - auto *SizeConstant = ConstantInt::get(SizeIntrin->getType(), Size); + auto *SizeConstant = + ConstantInt::get(SizeIntrin->getType(), getFrameSizeForShape(Shape)); for (CoroSizeInst *CS : Shape.CoroSizes) { CS->replaceAllUsesWith(SizeConstant); @@ -1248,6 +1255,7 @@ static void handleNoSuspendCoroutine(coro::Shape &Shape) { } CoroBegin->eraseFromParent(); + Shape.CoroBegin = nullptr; } // SimplifySuspendPoint needs to check that there is no calls between @@ -1970,9 +1978,17 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones, } /// Remove calls to llvm.coro.end in the original function. 
-static void removeCoroEnds(const coro::Shape &Shape) { - for (auto *End : Shape.CoroEnds) { - replaceCoroEnd(End, Shape, Shape.FramePtr, /*in resume*/ false, nullptr); +static void removeCoroEndsFromRampFunction(const coro::Shape &Shape) { + if (Shape.ABI != coro::ABI::Switch) { + for (auto *End : Shape.CoroEnds) { + replaceCoroEnd(End, Shape, Shape.FramePtr, /*in resume*/ false, nullptr); + } + } else { + for (llvm::AnyCoroEndInst *End : Shape.CoroEnds) { + auto &Context = End->getContext(); + End->replaceAllUsesWith(ConstantInt::getFalse(Context)); + End->eraseFromParent(); + } } } @@ -1981,18 +1997,6 @@ static void updateCallGraphAfterCoroutineSplit( const SmallVectorImpl<Function *> &Clones, LazyCallGraph::SCC &C, LazyCallGraph &CG, CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR, FunctionAnalysisManager &FAM) { - if (!Shape.CoroBegin) - return; - - if (Shape.ABI != coro::ABI::Switch) - removeCoroEnds(Shape); - else { - for (llvm::AnyCoroEndInst *End : Shape.CoroEnds) { - auto &Context = End->getContext(); - End->replaceAllUsesWith(ConstantInt::getFalse(Context)); - End->eraseFromParent(); - } - } if (!Clones.empty()) { switch (Shape.ABI) { @@ -2080,6 +2084,45 @@ static void addPrepareFunction(const Module &M, Fns.push_back(PrepareFn); } +static Function *createNoAllocVariant(Function &F, ValueToValueMapTy &VMap, + coro::Shape &Shape) { + auto *OrigFnTy = F.getFunctionType(); + auto OldParams = OrigFnTy->params(); + + SmallVector<Type *> NewParams; + NewParams.reserve(OldParams.size() + 1); + NewParams.push_back(PointerType::getUnqual(Shape.FrameTy)); + for (Type *T : OldParams) { + NewParams.push_back(T); + } + auto *NewFnTy = FunctionType::get(OrigFnTy->getReturnType(), NewParams, + OrigFnTy->isVarArg()); + Function *NoAllocF = + Function::Create(NewFnTy, F.getLinkage(), F.getName() + ".noalloc"); + unsigned int Idx = 1; + for (const auto &I : F.args()) { + VMap[&I] = NoAllocF->getArg(Idx++); + } + SmallVector<ReturnInst *, 4> Returns; + 
CloneFunctionInto(NoAllocF, &F, VMap, + CloneFunctionChangeType::LocalChangesOnly, Returns); + + if (Shape.CoroBegin) { + auto *NewCoroBegin = cast_if_present<CoroBeginInst>(VMap[Shape.CoroBegin]); + auto *NewCoroId = cast<CoroIdInst>(NewCoroBegin->getId()); + coro::replaceCoroFree(NewCoroId, /*Elide=*/true); + coro::suppressCoroAllocs(NewCoroId); + NewCoroBegin->replaceAllUsesWith(NoAllocF->getArg(0)); + NewCoroBegin->eraseFromParent(); + } + + Module *M = F.getParent(); + M->getFunctionList().insert(M->end(), NoAllocF); + + removeUnreachableBlocks(*NoAllocF); + return NoAllocF; +} + CoroSplitPass::CoroSplitPass(bool OptimizeFrame) : MaterializableCallback(coro::defaultMaterializable), OptimizeFrame(OptimizeFrame) {} @@ -2111,15 +2154,19 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C, // Split all the coroutines. for (LazyCallGraph::Node *N : Coroutines) { Function &F = N->getFunction(); + + LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F.getName() << "\n"); F.setSplittedCoroutine(); SmallVector<Function *, 4> Clones; auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); - const coro::Shape Shape = + coro::Shape Shape = splitCoroutine(F, Clones, FAM.getResult<TargetIRAnalysis>(F), OptimizeFrame, MaterializableCallback); + + removeCoroEndsFromRampFunction(Shape); updateCallGraphAfterCoroutineSplit(*N, Shape, Clones, C, CG, AM, UR, FAM); ORE.emit([&]() { @@ -2135,11 +2182,23 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C, for (Function *Clone : Clones) UR.CWorklist.insert(CG.lookupSCC(CG.get(*Clone))); } - } - for (auto *PrepareFn : PrepareFns) { - replaceAllPrepares(PrepareFn, CG, C); + if (Shape.ABI == coro::ABI::Switch) { + ValueToValueMapTy VMap; + auto *NoAllocF = createNoAllocVariant(F, VMap, Shape); + NoAllocF->addFnAttr("elided-coro"); + auto NewAttrs = NoAllocF->getAttributes(); + + addFramePointerAttrs(NewAttrs, NoAllocF->getContext(), 0, Shape.FrameSize, + Shape.FrameAlign, /*NoAlias=*/false); + + 
NoAllocF->setAttributes(NewAttrs); } + } + + for (auto *PrepareFn : PrepareFns) { + replaceAllPrepares(PrepareFn, CG, C); + } return PreservedAnalyses::none(); } diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp index 48c02e5406b7..36869316a35a 100644 --- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp +++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp @@ -146,6 +146,33 @@ void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) { } } +void coro::suppressCoroAllocs(CoroIdInst *CoroId) { + SmallVector<CoroAllocInst *, 4> CoroAllocs; + for (User *U : CoroId->users()) + if (auto *CA = dyn_cast<CoroAllocInst>(U)) + CoroAllocs.push_back(CA); + + if (CoroAllocs.empty()) + return; + + coro::suppressCoroAllocs(CoroId->getContext(), CoroAllocs); +} + +// Replacing llvm.coro.alloc with false will suppress dynamic +// allocation as it is expected for the frontend to generate the code that +// looks like: +// id = coro.id(...) +// mem = coro.alloc(id) ? 
malloc(coro.size()) : 0; +// coro.begin(id, mem) +void coro::suppressCoroAllocs(LLVMContext &Context, + ArrayRef<CoroAllocInst *> CoroAllocs) { + auto *False = ConstantInt::getFalse(Context); + for (auto *CA : CoroAllocs) { + CA->replaceAllUsesWith(False); + CA->eraseFromParent(); + } +} + static void clear(coro::Shape &Shape) { Shape.CoroBegin = nullptr; Shape.CoroEnds.clear(); >From d2af7a470774f20df0ec54380e4af7f3584956cf Mon Sep 17 00:00:00 2001 From: Yuxuan Chen <y...@meta.com> Date: Thu, 11 Jul 2024 19:02:25 -0700 Subject: [PATCH 3/3] add CoroAnnotationElidePass --- llvm/include/llvm/IR/Instruction.h | 4 + .../Coroutines/CoroAnnotationElide.h | 34 ++++++ llvm/lib/IR/Metadata.cpp | 16 +++ llvm/lib/Passes/PassBuilder.cpp | 1 + llvm/lib/Passes/PassBuilderPipelines.cpp | 3 +- llvm/lib/Passes/PassRegistry.def | 1 + llvm/lib/Transforms/Coroutines/CMakeLists.txt | 1 + .../Coroutines/CoroAnnotationElide.cpp | 108 ++++++++++++++++++ 8 files changed, 167 insertions(+), 1 deletion(-) create mode 100644 llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h create mode 100644 llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp diff --git a/llvm/include/llvm/IR/Instruction.h b/llvm/include/llvm/IR/Instruction.h index 7a9b95f23465..8d22d4cbf8ad 100644 --- a/llvm/include/llvm/IR/Instruction.h +++ b/llvm/include/llvm/IR/Instruction.h @@ -455,6 +455,10 @@ class Instruction : public User, /// !annotation metadata, append the tuple to /// the existing node. void addAnnotationMetadata(SmallVector<StringRef> Annotations); + + /// Returns true if this instruction's !annotation metadata contains \p Name. + bool hasAnnotationMetadata(StringRef Name) const; + /// Returns the AA metadata for this instruction. 
AAMDNodes getAAMetadata() const; diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h b/llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h new file mode 100644 index 000000000000..dae1cc0c689a --- /dev/null +++ b/llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h @@ -0,0 +1,34 @@ +//===- CoroAnnotationElide.h - Optimizing a coro_must_elide call ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// \file +// This pass transforms all Call or Invoke instructions that are annotated +// "coro_must_elide" to call the `.noalloc` variant of coroutine instead. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H +#define LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H + +#include "llvm/Analysis/CGSCCPassManager.h" +#include "llvm/Analysis/LazyCallGraph.h" +#include "llvm/IR/PassManager.h" + +namespace llvm { + +struct CoroAnnotationElidePass : PassInfoMixin<CoroAnnotationElidePass> { + CoroAnnotationElidePass() {} + + PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, + LazyCallGraph &CG, CGSCCUpdateResult &UR); + + static bool isRequired() { return false; } +}; +} // end namespace llvm + +#endif // LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H diff --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp index 5f42ce22f72f..ba2b7850c364 100644 --- a/llvm/lib/IR/Metadata.cpp +++ b/llvm/lib/IR/Metadata.cpp @@ -1703,6 +1703,22 @@ void Instruction::addAnnotationMetadata(StringRef Name) { setMetadata(LLVMContext::MD_annotation, MD); } +bool Instruction::hasAnnotationMetadata(StringRef Name) const { + auto *Metadata = getMetadata(LLVMContext::MD_annotation); 
+ if (!Metadata) + return false; + + auto *Tuple = cast<MDTuple>(Metadata); + for (auto &N : Tuple->operands()) { + if (auto *S = dyn_cast<MDString>(N.get())) { + if (S->getString() == Name) { + return true; + } + } + } + return false; +} + AAMDNodes Instruction::getAAMetadata() const { AAMDNodes Result; // Not using Instruction::hasMetadata() because we're not interested in diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp index 219e8b75450e..2b2ace1330fc 100644 --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -130,6 +130,7 @@ #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h" #include "llvm/Transforms/CFGuard.h" +#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h" #include "llvm/Transforms/Coroutines/CoroCleanup.h" #include "llvm/Transforms/Coroutines/CoroConditionalWrapper.h" #include "llvm/Transforms/Coroutines/CoroEarly.h" diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp index 4fd5ee1946bb..2d5a6e6861f2 100644 --- a/llvm/lib/Passes/PassBuilderPipelines.cpp +++ b/llvm/lib/Passes/PassBuilderPipelines.cpp @@ -32,6 +32,7 @@ #include "llvm/Support/VirtualFileSystem.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h" +#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h" #include "llvm/Transforms/Coroutines/CoroCleanup.h" #include "llvm/Transforms/Coroutines/CoroConditionalWrapper.h" #include "llvm/Transforms/Coroutines/CoroEarly.h" @@ -968,8 +969,8 @@ PassBuilder::buildInlinerPipeline(OptimizationLevel Level, // it's been modified since. 
MainCGPipeline.addPass(createCGSCCToFunctionPassAdaptor( RequireAnalysisPass<ShouldNotRunFunctionPassesAnalysis, Function>())); - MainCGPipeline.addPass(CoroSplitPass(Level != OptimizationLevel::O0)); + MainCGPipeline.addPass(CoroAnnotationElidePass()); // Make sure we don't affect potential future NoRerun CGSCC adaptors. MIWP.addLateModulePass(createModuleToFunctionPassAdaptor( diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def index 3b92823cd283..f3d77f8def4a 100644 --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -234,6 +234,7 @@ CGSCC_PASS("attributor-light-cgscc", AttributorLightCGSCCPass()) CGSCC_PASS("invalidate<all>", InvalidateAllAnalysesPass()) CGSCC_PASS("no-op-cgscc", NoOpCGSCCPass()) CGSCC_PASS("openmp-opt-cgscc", OpenMPOptCGSCCPass()) +CGSCC_PASS("coro-annotation-elide", CoroAnnotationElidePass()) #undef CGSCC_PASS #ifndef CGSCC_PASS_WITH_PARAMS diff --git a/llvm/lib/Transforms/Coroutines/CMakeLists.txt b/llvm/lib/Transforms/Coroutines/CMakeLists.txt index 2139446e5ff9..b4b5812d97d8 100644 --- a/llvm/lib/Transforms/Coroutines/CMakeLists.txt +++ b/llvm/lib/Transforms/Coroutines/CMakeLists.txt @@ -1,5 +1,6 @@ add_llvm_component_library(LLVMCoroutines Coroutines.cpp + CoroAnnotationElide.cpp CoroCleanup.cpp CoroConditionalWrapper.cpp CoroEarly.cpp diff --git a/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp b/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp new file mode 100644 index 000000000000..e380af2cc3a3 --- /dev/null +++ b/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp @@ -0,0 +1,108 @@ +//===- CoroAnnotationElide.cpp - Optimizing a coro_must_elide call --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h" + +#include "llvm/Analysis/LazyCallGraph.h" +#include "llvm/IR/Analysis.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Module.h" + +#include <cassert> + +using namespace llvm; + +#define DEBUG_TYPE "coro-annotation-elide" + +#define CORO_MUST_ELIDE_ANNOTATION "coro_must_elide" + +static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) { + for (Instruction &I : F->getEntryBlock()) + if (!isa<AllocaInst>(&I)) + return &I; + llvm_unreachable("no terminator in the entry block"); +} + +static Value *allocateFrameInCaller(Function *Caller, uint64_t FrameSize, + Align FrameAlign) { + LLVMContext &C = Caller->getContext(); + BasicBlock::iterator InsertPt = + getFirstNonAllocaInTheEntryBlock(Caller)->getIterator(); + const DataLayout &DL = Caller->getDataLayout(); + auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize); + auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt); + Frame->setAlignment(FrameAlign); + return new BitCastInst(Frame, PointerType::getUnqual(C), "vFrame", InsertPt); +} + +static void processCall(CallBase *CB, Function *Caller, Function *NewCallee, + uint64_t FrameSize, Align FrameAlign) { + auto *FramePtr = allocateFrameInCaller(Caller, FrameSize, FrameAlign); + CB->setCalledFunction(NewCallee->getFunctionType(), NewCallee); + CB->setArgOperand(0, FramePtr); + auto NewCBInsertPt = CB->getIterator(); + llvm::CallBase *NewCB = nullptr; + SmallVector<Value *, 4> NewArgs; + NewArgs.push_back(FramePtr); + NewArgs.append(CB->arg_begin(), CB->arg_end()); + + // TODO: bundles? 
+ if (auto *CI = dyn_cast<CallInst>(CB)) { + NewCB = CallInst::Create(NewCallee->getFunctionType(), NewCallee, NewArgs, + "", NewCBInsertPt); + } else if (auto *II = dyn_cast<InvokeInst>(CB)) { + NewCB = InvokeInst::Create(NewCallee->getFunctionType(), NewCallee, + II->getNormalDest(), II->getUnwindDest(), + NewArgs, std::nullopt, "", NewCBInsertPt); + } else { + llvm_unreachable("CallBase should either be Call or Invoke!"); + } + + CB->replaceAllUsesWith(NewCB); + CB->eraseFromParent(); +} + +PreservedAnalyses CoroAnnotationElidePass::run(LazyCallGraph::SCC &C, + CGSCCAnalysisManager &AM, + LazyCallGraph &CG, + CGSCCUpdateResult &UR) { + // NB: One invariant of a valid LazyCallGraph::SCC is that it must contain a + // non-zero number of nodes, so we assume that here and grab the first + // node's function's module. + Module &M = *C.begin()->getFunction().getParent(); + bool Changed = false; + // Find coroutines for processing. + SmallVector<LazyCallGraph::Node *> Coroutines; + for (LazyCallGraph::Node &N : C) { + Function *Callee = &N.getFunction(); + Function *NewCallee = Callee->getParent()->getFunction( + (Callee->getName() + ".noalloc").getSingleStringRef()); + if (!NewCallee) { + continue; + } + + auto FrameSize = NewCallee->getParamDereferenceableBytes(0); + auto FrameAlign = NewCallee->getParamAlign(0).valueOrOne(); + + for (auto *U : Callee->users()) { + if (auto *CB = dyn_cast<CallBase>(U)) { + auto *Caller = CB->getFunction(); + if (Caller && Caller->isPresplitCoroutine() && + CB->hasAnnotationMetadata(CORO_MUST_ELIDE_ANNOTATION)) { + processCall(CB, Caller, NewCallee, FrameSize, FrameAlign); + Changed = true; + } + } + } + } + return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits