EricWF updated this revision to Diff 76953. EricWF added a comment. - Address review comments about `DependentCoawaitExpr` and using `UnresolvedLookupExpr`. - Fix building of the initial/final coroutine suspend points. - Fix transformation of `CoroutineBodyStmt` so that it transforms the final/initial suspend points instead of rebuilding them fully.
@rsmith: This change is a little big, but it's not trivial for me to split it up. Please let me know if you would prefer this submitted as multiple patches. https://reviews.llvm.org/D26057 Files: include/clang/AST/ExprCXX.h include/clang/AST/RecursiveASTVisitor.h include/clang/AST/StmtCXX.h include/clang/Basic/DiagnosticSemaKinds.td include/clang/Basic/StmtNodes.td include/clang/Sema/ScopeInfo.h include/clang/Sema/Sema.h lib/AST/Expr.cpp lib/AST/ExprClassification.cpp lib/AST/ExprConstant.cpp lib/AST/ItaniumMangle.cpp lib/AST/StmtPrinter.cpp lib/AST/StmtProfile.cpp lib/Parse/ParseStmt.cpp lib/Sema/ScopeInfo.cpp lib/Sema/SemaCoroutine.cpp lib/Sema/SemaDecl.cpp lib/Sema/SemaExceptionSpec.cpp lib/Sema/SemaTemplateInstantiateDecl.cpp lib/Sema/TreeTransform.h lib/Serialization/ASTReaderStmt.cpp lib/Serialization/ASTWriterStmt.cpp lib/StaticAnalyzer/Core/ExprEngine.cpp test/SemaCXX/coroutines.cpp tools/libclang/CXCursor.cpp
Index: tools/libclang/CXCursor.cpp =================================================================== --- tools/libclang/CXCursor.cpp +++ tools/libclang/CXCursor.cpp @@ -231,6 +231,7 @@ case Stmt::TypeTraitExprClass: case Stmt::CoroutineBodyStmtClass: case Stmt::CoawaitExprClass: + case Stmt::DependentCoawaitExprClass: case Stmt::CoreturnStmtClass: case Stmt::CoyieldExprClass: case Stmt::CXXBindTemporaryExprClass: Index: test/SemaCXX/coroutines.cpp =================================================================== --- test/SemaCXX/coroutines.cpp +++ test/SemaCXX/coroutines.cpp @@ -59,25 +59,25 @@ template <typename... T> struct std::experimental::coroutine_traits<int, T...> {}; -int no_promise_type() { - co_await a; // expected-error {{this function cannot be a coroutine: 'std::experimental::coroutine_traits<int>' has no member named 'promise_type'}} +int no_promise_type() { // expected-error {{this function cannot be a coroutine: 'std::experimental::coroutine_traits<int>' has no member named 'promise_type'}} + co_await a; } template <> struct std::experimental::coroutine_traits<double, double> { typedef int promise_type; }; -double bad_promise_type(double) { - co_await a; // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits<double, double>::promise_type' (aka 'int') is not a class}} +double bad_promise_type(double) { // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits<double, double>::promise_type' (aka 'int') is not a class}} + co_await a; } template <> struct std::experimental::coroutine_traits<double, int> { struct promise_type {}; }; -double bad_promise_type_2(int) { - co_yield 0; // expected-error {{no member named 'yield_value' in 'std::experimental::coroutine_traits<double, int>::promise_type'}} +double bad_promise_type_2(int) { // expected-error {{no member named 'initial_suspend'}} + co_yield 0; // expected-error {{no member named 'yield_value'}} } -struct promise; // expected-note 
2{{forward declaration}} +struct promise; // expected-note {{forward declaration}} struct promise_void; struct void_tag {}; template <typename... T> @@ -94,9 +94,7 @@ } // FIXME: This diagnostic is terrible. -void undefined_promise() { // expected-error {{variable has incomplete type 'promise_type'}} - // FIXME: This diagnostic doesn't make any sense. - // expected-error@-2 {{incomplete definition of type 'promise'}} +void undefined_promise() { // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits<void>::promise_type' (aka 'promise') is an incomplete type}} co_await a; } @@ -217,6 +215,13 @@ } struct outer {}; +struct await_arg_1 {}; +struct await_arg_2 {}; + +namespace adl_ns { +struct coawait_arg_type {}; +awaitable operator co_await(coawait_arg_type); +} namespace dependent_operator_co_await_lookup { template<typename T> void await_template(T t) { @@ -239,6 +244,94 @@ }; template void await_template(outer); // expected-note {{instantiation}} template void await_template_2(outer); + + struct transform_awaitable {}; + struct transformed {}; + + struct transform_promise { + typedef transform_awaitable await_arg; + coro<transform_promise> get_return_object(); + transformed initial_suspend(); + ::adl_ns::coawait_arg_type final_suspend(); + transformed await_transform(transform_awaitable); + }; + template <class AwaitArg> + struct basic_promise { + typedef AwaitArg await_arg; + coro<basic_promise> get_return_object(); + awaitable initial_suspend(); + awaitable final_suspend(); + }; + + awaitable operator co_await(await_arg_1); + + template <typename T, typename U> + coro<T> await_template_3(U t) { + co_await t; + } + + template coro<basic_promise<await_arg_1>> await_template_3<basic_promise<await_arg_1>>(await_arg_1); + + template <class T, int I = 0> + struct dependent_member { + coro<T> mem_fn() const { + co_await typename T::await_arg{}; // expected-error {{call to function 'operator co_await'}}} + } + template <class U> + coro<T> 
dep_mem_fn(U t) { + co_await t; + } + }; + + template <> + struct dependent_member<long> { + // FIXME this diagnostic is terrible + coro<transform_promise> mem_fn() const { // expected-error {{no member named 'await_ready' in 'dependent_operator_co_await_lookup::transformed'}} + // expected-note@-1 {{call to 'initial_suspend' implicitly required by the initial suspend point}} + // expected-note@+1 {{function is a coroutine due to use of 'co_await' here}} + co_await transform_awaitable{}; + // expected-error@-1 {{no member named 'await_ready'}} + } + template <class R, class U> + coro<R> dep_mem_fn(U u) { co_await u; } + }; + + awaitable operator co_await(await_arg_2); // expected-note {{'operator co_await' should be declared prior to the call site}} + + template struct dependent_member<basic_promise<await_arg_1>, 0>; + template struct dependent_member<basic_promise<await_arg_2>, 0>; // expected-note {{in instantiation}} + + template <> + coro<transform_promise> + // FIXME this diagnostic is terrible + dependent_member<long>::dep_mem_fn<transform_promise>(int) { // expected-error {{no member named 'await_ready' in 'dependent_operator_co_await_lookup::transformed'}} + //expected-note@-1 {{call to 'initial_suspend' implicitly required by the initial suspend point}} + //expected-note@+1 {{function is a coroutine due to use of 'co_await' here}} + co_await transform_awaitable{}; + // expected-error@-1 {{no member named 'await_ready'}} + } + + void operator co_await(transform_awaitable) = delete; + awaitable operator co_await(transformed); + + template coro<transform_promise> + dependent_member<long>::dep_mem_fn<transform_promise>(transform_awaitable); + + template <> + coro<transform_promise> dependent_member<long>::dep_mem_fn<transform_promise>(long) { + co_await transform_awaitable{}; + } + + template <> + struct dependent_member<int> { + coro<transform_promise> mem_fn() const { + co_await transform_awaitable{}; + } + }; + + template coro<transform_promise> 
await_template_3<transform_promise>(transform_awaitable); + template struct dependent_member<transform_promise>; + template coro<transform_promise> dependent_member<transform_promise>::dep_mem_fn(transform_awaitable); } struct yield_fn_tag {}; @@ -314,7 +407,8 @@ }; // FIXME: This diagnostic is terrible. coro<bad_promise_4> bad_initial_suspend() { // expected-error {{no member named 'await_ready' in 'not_awaitable'}} - co_await a; + // expected-note@-1 {{'initial_suspend' implicitly required}} + co_await a; // expected-note {{use of 'co_await' here}} } struct bad_promise_5 { @@ -324,7 +418,8 @@ }; // FIXME: This diagnostic is terrible. coro<bad_promise_5> bad_final_suspend() { // expected-error {{no member named 'await_ready' in 'not_awaitable'}} - co_await a; + // expected-note@-1 {{'final_suspend' implicitly required}} + co_await a; // expected-note {{use of 'co_await' here}} } struct bad_promise_6 { @@ -355,20 +450,69 @@ int *current_exception(); } -struct bad_promise_8 { +struct bad_promise_base { +private: + void return_void(); // expected-note {{declared private here}} +}; +struct bad_promise_8 : bad_promise_base { coro<bad_promise_8> get_return_object(); suspend_always initial_suspend(); suspend_always final_suspend(); - void return_void(); void set_exception(); // expected-note {{function not viable}} void set_exception(int *) __attribute__((unavailable)); // expected-note {{explicitly made unavailable}} void set_exception(void *); // expected-note {{candidate function}} }; coro<bad_promise_8> calls_set_exception() { // expected-error@-1 {{call to unavailable member function 'set_exception'}} + // expected-error@-2 {{'return_void' is a private member of 'bad_promise_base'}} co_await a; } +struct bad_promise_9 { + coro<bad_promise_9> get_return_object(); + suspend_always initial_suspend(); + suspend_always final_suspend(); + void await_transform(void *); // expected-note {{candidate}} + awaitable await_transform(int) __attribute__((unavailable)); // 
expected-note {{explicitly made unavailable}} + void return_void(); +}; +coro<bad_promise_9> calls_await_transform() { + co_await 42; // expected-error {{call to unavailable member function 'await_transform'}} + // expected-note@-1 {{call to 'await_transform' implicitly required by 'co_await' here}} +} + +struct bad_promise_10 { + coro<bad_promise_10> get_return_object(); + suspend_always initial_suspend(); + suspend_always final_suspend(); + int await_transform; + void return_void(); +}; +coro<bad_promise_10> bad_coawait() { + // FIXME this diagnostic is terrible + co_await 42; // expected-error {{called object type 'int' is not a function or function pointer}} + // expected-note@-1 {{call to 'await_transform' implicitly required by 'co_await' here}} +} + +struct call_operator { + template <class... Args> + awaitable operator()(Args...) const { return a; } +}; +void ret_void(); +struct good_promise_1 { + coro<good_promise_1> get_return_object(); + suspend_always initial_suspend(); + suspend_always final_suspend(); + static const call_operator await_transform; + using Fn = void (*)(); + Fn return_void = ret_void; +}; +const call_operator good_promise_1::await_transform; +coro<good_promise_1> ok_static_coawait() { + // FIXME this diagnostic is terrible + co_await 42; +} + template<> struct std::experimental::coroutine_traits<int, int, const char**> { using promise_type = promise; }; Index: lib/StaticAnalyzer/Core/ExprEngine.cpp =================================================================== --- lib/StaticAnalyzer/Core/ExprEngine.cpp +++ lib/StaticAnalyzer/Core/ExprEngine.cpp @@ -774,6 +774,7 @@ case Stmt::FunctionParmPackExprClass: case Stmt::CoroutineBodyStmtClass: case Stmt::CoawaitExprClass: + case Stmt::DependentCoawaitExprClass: case Stmt::CoreturnStmtClass: case Stmt::CoyieldExprClass: case Stmt::SEHTryStmtClass: Index: lib/Serialization/ASTWriterStmt.cpp =================================================================== --- 
lib/Serialization/ASTWriterStmt.cpp +++ lib/Serialization/ASTWriterStmt.cpp @@ -315,6 +315,11 @@ llvm_unreachable("unimplemented"); } +void ASTStmtWriter::VisitDependentCoawaitExpr(DependentCoawaitExpr *S) { + // FIXME: Implement coroutine serialization. + llvm_unreachable("unimplemented"); +} + void ASTStmtWriter::VisitCoyieldExpr(CoyieldExpr *S) { // FIXME: Implement coroutine serialization. llvm_unreachable("unimplemented"); Index: lib/Serialization/ASTReaderStmt.cpp =================================================================== --- lib/Serialization/ASTReaderStmt.cpp +++ lib/Serialization/ASTReaderStmt.cpp @@ -400,6 +400,11 @@ llvm_unreachable("unimplemented"); } +void ASTStmtReader::VisitDependentCoawaitExpr(DependentCoawaitExpr *S) { + // FIXME: Implement coroutine serialization. + llvm_unreachable("unimplemented"); +} + void ASTStmtReader::VisitCoyieldExpr(CoyieldExpr *S) { // FIXME: Implement coroutine serialization. llvm_unreachable("unimplemented"); Index: lib/Sema/TreeTransform.h =================================================================== --- lib/Sema/TreeTransform.h +++ lib/Sema/TreeTransform.h @@ -1306,16 +1306,29 @@ /// /// By default, performs semantic analysis to build the new statement. /// Subclasses may override this routine to provide different behavior. - StmtResult RebuildCoreturnStmt(SourceLocation CoreturnLoc, Expr *Result) { - return getSema().BuildCoreturnStmt(CoreturnLoc, Result); + StmtResult RebuildCoreturnStmt(SourceLocation CoreturnLoc, Expr *Result, + bool IsImplicitlyCreated) { + return getSema().BuildCoreturnStmt(CoreturnLoc, Result, + IsImplicitlyCreated); } /// \brief Build a new co_await expression. /// /// By default, performs semantic analysis to build the new expression. /// Subclasses may override this routine to provide different behavior. 
- ExprResult RebuildCoawaitExpr(SourceLocation CoawaitLoc, Expr *Result) { - return getSema().BuildCoawaitExpr(CoawaitLoc, Result); + ExprResult RebuildCoawaitExpr(SourceLocation CoawaitLoc, Expr *Result, + bool IsImplicitlyCreated) { + return getSema().BuildCoawaitExpr(CoawaitLoc, Result, IsImplicitlyCreated); + } + + /// \brief Build a new co_await expression. + /// + /// By default, performs semantic analysis to build the new expression. + /// Subclasses may override this routine to provide different behavior. + ExprResult RebuildDependentCoawaitExpr(SourceLocation CoawaitLoc, + Expr *Result, + UnresolvedLookupExpr *Lookup) { + return getSema().BuildDependentCoawaitExpr(CoawaitLoc, Result, Lookup); } /// \brief Build a new co_yield expression. @@ -1326,6 +1339,15 @@ return getSema().BuildCoyieldExpr(CoyieldLoc, Result); } + StmtResult RebuildCoroutineBodyStmt(Stmt *Body, VarDecl *Promise, Stmt *InitSuspend, + Stmt *FinalSuspend, Stmt *OnException, + Stmt *OnFallthrough, + Expr *Allocation, + Stmt *Deallocation, Expr *ReturnObject) { + return getSema().BuildCoroutineBodyStmt( + Body, Promise, InitSuspend, FinalSuspend, OnException, OnFallthrough, + Allocation, Deallocation, ReturnObject); + } /// \brief Build a new Objective-C \@try statement. /// /// By default, performs semantic analysis to build the new statement. @@ -6655,7 +6677,87 @@ TreeTransform<Derived>::TransformCoroutineBodyStmt(CoroutineBodyStmt *S) { // The coroutine body should be re-formed by the caller if necessary. // FIXME: The coroutine body is always rebuilt by ActOnFinishFunctionBody - return getDerived().TransformStmt(S->getBody()); + + auto *ScopeInfo = SemaRef.getCurFunction(); + auto *FD = cast<FunctionDecl>(SemaRef.CurContext); + assert(ScopeInfo && !ScopeInfo->CoroutinePromise && + !ScopeInfo->HasCoroutineSuspends && + ScopeInfo->CoroutineStmts.empty() && "expected clean scope info"); + + // Set that we have (possibly-invalid) suspend points before we do anything + // that may fail. 
+ ScopeInfo->setCoroutineSuspendsInvalid(); + + // The new CoroutinePromise object needs to be built and put into the current + // FunctionScopeInfo before any transformations or rebuilding occurs. + auto *Promise = S->getPromiseDecl(); + auto *NewPromise = SemaRef.buildCoroutinePromise(FD->getLocation()); + if (!NewPromise) + return StmtError(); + getDerived().transformedLocalDecl(Promise, NewPromise); + ScopeInfo->CoroutinePromise = NewPromise; + + // Transform the implicit coroutine statements we built during the initial + // parse. + StmtResult InitSuspend = getDerived().TransformStmt(S->getInitSuspendStmt()); + if (InitSuspend.isInvalid()) + return StmtError(); + StmtResult FinalSuspend = + getDerived().TransformStmt(S->getFinalSuspendStmt()); + if (FinalSuspend.isInvalid()) + return StmtError(); + ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get()); + + StmtResult BodyRes = getDerived().TransformStmt(S->getBody()); + if (BodyRes.isInvalid()) + return StmtError(); + + Stmt *SetException = S->getExceptionHandler(); + Stmt *Fallthrough = S->getFallthroughHandler(); + if (Fallthrough) { + StmtResult Res = getDerived().TransformStmt(Fallthrough); + if (Res.isInvalid()) + return StmtError(); + Fallthrough = Res.get(); + } + + if (SetException) { + StmtResult Res = getDerived().TransformStmt(SetException); + if (Res.isInvalid()) + return StmtError(); + SetException = Res.get(); + } + + // Transform any additional statements we may have already built. 
+ Expr *Allocation = nullptr; + Stmt *Deallocation = nullptr; + if (S->getAllocate() && S->getDeallocate()) { + ExprResult AllocRes = getDerived().TransformExpr(S->getAllocate()); + if (AllocRes.isInvalid()) + return StmtError(); + Allocation = AllocRes.get(); + + StmtResult DeallocRes = getDerived().TransformStmt(S->getDeallocate()); + if (DeallocRes.isInvalid()) + return StmtError(); + Deallocation = DeallocRes.get(); + } + + Expr *ReturnObject = S->getReturnValueInit(); + if (ReturnObject) { + ExprResult Res = getDerived().TransformInitializer(ReturnObject, + /*NoCopyInit*/false); + if (Res.isInvalid()) + return StmtError(); + ReturnObject = Res.get(); + } + + // Do a partial rebuild of the coroutine body and stash it in the ScopeInfo + return getDerived().RebuildCoroutineBodyStmt( + BodyRes.get(), NewPromise, InitSuspend.get(), FinalSuspend.get(), + SetException, Fallthrough, Allocation, + Deallocation, ReturnObject); + } template<typename Derived> @@ -6668,7 +6770,8 @@ // Always rebuild; we don't know if this needs to be injected into a new // context or if the promise type has changed. - return getDerived().RebuildCoreturnStmt(S->getKeywordLoc(), Result.get()); + return getDerived().RebuildCoreturnStmt(S->getKeywordLoc(), Result.get(), + S->isImplicitlyCreated()); } template<typename Derived> @@ -6681,12 +6784,26 @@ // Always rebuild; we don't know if this needs to be injected into a new // context or if the promise type has changed. 
- return getDerived().RebuildCoawaitExpr(E->getKeywordLoc(), Result.get()); + return getDerived().RebuildCoawaitExpr(E->getKeywordLoc(), Result.get(), + E->isImplicitlyCreated()); } -template<typename Derived> +template <typename Derived> ExprResult -TreeTransform<Derived>::TransformCoyieldExpr(CoyieldExpr *E) { +TreeTransform<Derived>::TransformDependentCoawaitExpr(DependentCoawaitExpr *E) { + ExprResult Result = getDerived().TransformInitializer(E->getOperand(), + /*NotCopyInit*/ false); + if (Result.isInvalid()) + return ExprError(); + + // Always rebuild; we don't know if this needs to be injected into a new + // context or if the promise type has changed. + return getDerived().RebuildDependentCoawaitExpr( + E->getKeywordLoc(), Result.get(), E->getOperatorCoawaitLookup()); +} + +template <typename Derived> +ExprResult TreeTransform<Derived>::TransformCoyieldExpr(CoyieldExpr *E) { ExprResult Result = getDerived().TransformInitializer(E->getOperand(), /*NotCopyInit*/false); if (Result.isInvalid()) Index: lib/Sema/SemaTemplateInstantiateDecl.cpp =================================================================== --- lib/Sema/SemaTemplateInstantiateDecl.cpp +++ lib/Sema/SemaTemplateInstantiateDecl.cpp @@ -3714,6 +3714,8 @@ if (Body.isInvalid()) Function->setInvalidDecl(); + else + assert(Body.get()); ActOnFinishFunctionBody(Function, Body.get(), /*IsInstantiation=*/true); Index: lib/Sema/SemaExceptionSpec.cpp =================================================================== --- lib/Sema/SemaExceptionSpec.cpp +++ lib/Sema/SemaExceptionSpec.cpp @@ -1146,6 +1146,7 @@ case Expr::ArraySubscriptExprClass: case Expr::OMPArraySectionExprClass: case Expr::BinaryOperatorClass: + case Expr::DependentCoawaitExprClass: case Expr::CompoundAssignOperatorClass: case Expr::CStyleCastExprClass: case Expr::CXXStaticCastExprClass: Index: lib/Sema/SemaDecl.cpp =================================================================== --- lib/Sema/SemaDecl.cpp +++ lib/Sema/SemaDecl.cpp @@ 
-11383,7 +11383,7 @@ if (canRedefineFunction(Definition, getLangOpts())) return; - // If we don't have a visible definition of the function, and it's inline or + // If we don't have a viNsible definition of the function, and it's inline or // a template, skip the new definition. if (SkipBody && !hasVisibleDefinition(Definition) && (Definition->getFormalLinkage() == InternalLinkage || @@ -11675,7 +11675,7 @@ sema::AnalysisBasedWarnings::Policy WP = AnalysisWarnings.getDefaultPolicy(); sema::AnalysisBasedWarnings::Policy *ActivePolicy = nullptr; - if (getLangOpts().CoroutinesTS && !getCurFunction()->CoroutineStmts.empty()) + if (getLangOpts().CoroutinesTS && getCurFunction()->CoroutinePromise) CheckCompletedCoroutineBody(FD, Body); if (FD) { Index: lib/Sema/SemaCoroutine.cpp =================================================================== --- lib/Sema/SemaCoroutine.cpp +++ lib/Sema/SemaCoroutine.cpp @@ -21,21 +21,32 @@ using namespace clang; using namespace sema; +static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD, + SourceLocation Loc) { + DeclarationName DN = S.PP.getIdentifierInfo(Name); + LookupResult LR(S, DN, Loc, Sema::LookupMemberName); + // Suppress diagnostics when a private member is selected. The same warnings + // will be produced again when building the call. + LR.suppressDiagnostics(); + return S.LookupQualifiedName(LR, RD); +} + /// Look up the std::coroutine_traits<...>::promise_type for the given /// function type. static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType, - SourceLocation Loc) { + SourceLocation KwLoc, + SourceLocation FuncLoc) { // FIXME: Cache std::coroutine_traits once we've found it. 
NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace(); if (!StdExp) { - S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found); + S.Diag(KwLoc, diag::err_implied_std_coroutine_traits_not_found); return QualType(); } LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"), - Loc, Sema::LookupOrdinaryName); + FuncLoc, Sema::LookupOrdinaryName); if (!S.LookupQualifiedName(Result, StdExp)) { - S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found); + S.Diag(KwLoc, diag::err_implied_std_coroutine_traits_not_found); return QualType(); } @@ -49,52 +60,58 @@ } // Form template argument list for coroutine_traits<R, P1, P2, ...>. - TemplateArgumentListInfo Args(Loc, Loc); + TemplateArgumentListInfo Args(KwLoc, KwLoc); Args.addArgument(TemplateArgumentLoc( TemplateArgument(FnType->getReturnType()), - S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), Loc))); + S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), KwLoc))); // FIXME: If the function is a non-static member function, add the type // of the implicit object parameter before the formal parameters. for (QualType T : FnType->getParamTypes()) Args.addArgument(TemplateArgumentLoc( - TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, Loc))); + TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, KwLoc))); // Build the template-id. QualType CoroTrait = - S.CheckTemplateIdType(TemplateName(CoroTraits), Loc, Args); + S.CheckTemplateIdType(TemplateName(CoroTraits), KwLoc, Args); if (CoroTrait.isNull()) return QualType(); - if (S.RequireCompleteType(Loc, CoroTrait, + if (S.RequireCompleteType(KwLoc, CoroTrait, diag::err_coroutine_traits_missing_specialization)) return QualType(); - CXXRecordDecl *RD = CoroTrait->getAsCXXRecordDecl(); + auto *RD = CoroTrait->getAsCXXRecordDecl(); assert(RD && "specialization of class template is not a class?"); // Look up the ::promise_type member. 
- LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), Loc, + LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), KwLoc, Sema::LookupOrdinaryName); S.LookupQualifiedName(R, RD); auto *Promise = R.getAsSingle<TypeDecl>(); if (!Promise) { - S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_found) + S.Diag(FuncLoc, + diag::err_implied_std_coroutine_traits_promise_type_not_found) << RD; return QualType(); } - // The promise type is required to be a class type. QualType PromiseType = S.Context.getTypeDeclType(Promise); - if (!PromiseType->getAsCXXRecordDecl()) { - // Use the fully-qualified name of the type. + + auto buildNNS = [&]() { auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, StdExp); NNS = NestedNameSpecifier::Create(S.Context, NNS, false, CoroTrait.getTypePtr()); - PromiseType = S.Context.getElaboratedType(ETK_None, NNS, PromiseType); + return S.Context.getElaboratedType(ETK_None, NNS, PromiseType); + }; - S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_class) - << PromiseType; + if (!PromiseType->getAsCXXRecordDecl()) { + S.Diag(FuncLoc, + diag::err_implied_std_coroutine_traits_promise_type_not_class) + << buildNNS(); return QualType(); } + if (S.RequireCompleteType(FuncLoc, buildNNS(), + diag::err_coroutine_promise_type_incomplete)) + return QualType(); return PromiseType; } @@ -160,41 +177,49 @@ return !Diagnosed; } -/// Check that this is a context in which a coroutine suspension can appear. 
-static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc, - StringRef Keyword) { - if (!isValidCoroutineContext(S, Loc, Keyword)) - return nullptr; +static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S, + SourceLocation Loc) { + DeclarationName OpName = + SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait); + LookupResult Operators(SemaRef, OpName, SourceLocation(), + Sema::LookupOperatorName); + SemaRef.LookupName(Operators, S); + + assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous"); + const auto &Functions = Operators.asUnresolvedSet(); + bool IsOverloaded = + Functions.size() > 1 || + (Functions.size() == 1 && isa<FunctionTemplateDecl>(*Functions.begin())); + Expr *CoawaitOp = UnresolvedLookupExpr::Create( + SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(), + DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded, + Functions.begin(), Functions.end()); + assert(CoawaitOp); + return CoawaitOp; +} - assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope"); - auto *FD = cast<FunctionDecl>(S.CurContext); - auto *ScopeInfo = S.getCurFunction(); - assert(ScopeInfo && "missing function scope for function"); +/// Build a call to 'operator co_await' if there is a suitable operator for +/// the given expression. +static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc, + Expr *E, + UnresolvedLookupExpr *Lookup) { - // If we don't have a promise variable, build one now. - if (!ScopeInfo->CoroutinePromise) { - QualType T = FD->getType()->isDependentType() - ? S.Context.DependentTy - : lookupPromiseType( - S, FD->getType()->castAs<FunctionProtoType>(), Loc); - if (T.isNull()) - return nullptr; - - // Create and default-initialize the promise. 
- ScopeInfo->CoroutinePromise = - VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(), - &S.PP.getIdentifierTable().get("__promise"), T, - S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None); - S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise); - if (!ScopeInfo->CoroutinePromise->isInvalidDecl()) - S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise, false); - } + UnresolvedSet<16> Functions; + Functions.append(Lookup->decls_begin(), Lookup->decls_end()); + return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E); +} - return ScopeInfo; +static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S, + SourceLocation Loc, Expr *E) { + ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc); + if (R.isInvalid()) + return ExprError(); + return buildOperatorCoawaitCall(SemaRef, Loc, E, + cast<UnresolvedLookupExpr>(R.get())); } static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id, - MutableArrayRef<Expr *> CallArgs) { + MultiExprArg CallArgs) { StringRef Name = S.Context.BuiltinInfo.getName(Id); LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName); S.LookupName(R, S.TUScope, /*AllowBuiltinCreation=*/true); @@ -213,24 +238,14 @@ return Call.get(); } -/// Build a call to 'operator co_await' if there is a suitable operator for -/// the given expression. 
-static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S, - SourceLocation Loc, Expr *E) { - UnresolvedSet<16> Functions; - SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(), - Functions); - return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E); -} struct ReadySuspendResumeResult { bool IsInvalid; Expr *Results[3]; }; static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc, - StringRef Name, - MutableArrayRef<Expr *> Args) { + StringRef Name, MultiExprArg Args) { DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc); // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&. @@ -268,25 +283,174 @@ return Calls; } +static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise, + SourceLocation Loc, StringRef Name, + MultiExprArg Args) { + + // Form a reference to the promise. + ExprResult PromiseRef = S.BuildDeclRefExpr( + Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc); + if (PromiseRef.isInvalid()) + return ExprError(); + + // Call 'yield_value', passing in E. + return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args); +} + +VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) { + assert(isa<FunctionDecl>(CurContext) && "not in a function scope"); + auto *FD = cast<FunctionDecl>(CurContext); + + QualType T = + FD->getType()->isDependentType() + ? Context.DependentTy + : lookupPromiseType(*this, FD->getType()->castAs<FunctionProtoType>(), + Loc, FD->getLocation()); + if (T.isNull()) + return nullptr; + + auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(), + &PP.getIdentifierTable().get("__promise"), T, + Context.getTrivialTypeSourceInfo(T, Loc), SC_None); + CheckVariableDeclarationType(VD); + if (VD->isInvalidDecl()) + return nullptr; + ActOnUninitializedDecl(VD, false); + assert(!VD->isInvalidDecl()); + return VD; +} + +/// Check that this is a context in which a coroutine suspension can appear. 
+static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc, + StringRef Keyword) { + if (!isValidCoroutineContext(S, Loc, Keyword)) + return nullptr; + + assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope"); + auto *FD = cast<FunctionDecl>(S.CurContext); + + auto *ScopeInfo = S.getCurFunction(); + assert(ScopeInfo && "missing function scope for function"); + + if (ScopeInfo->CoroutinePromise) + return ScopeInfo; + + ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc); + if (!ScopeInfo->CoroutinePromise) + return nullptr; + + return ScopeInfo; +} + +static bool actOnCoroutineBodyStart(Sema &S, Scope *SC, SourceLocation KWLoc, + StringRef Keyword) { + if (!checkCoroutineContext(S, KWLoc, Keyword)) + return false; + auto *ScopeInfo = S.getCurFunction(); + assert(ScopeInfo->CoroutinePromise); + + // If we have existing coroutine statements then we have already built + // the initial and final suspend points. + if (ScopeInfo->HasCoroutineSuspends) + return true; + + ScopeInfo->setCoroutineSuspendsInvalid(); + + auto *Fn = cast<FunctionDecl>(S.CurContext); + SourceLocation Loc = Fn->getLocation(); + // Build the initial suspend point + auto buildSuspends = [&](StringRef Name) mutable -> StmtResult { + ExprResult Suspend = + buildPromiseCall(S, ScopeInfo->CoroutinePromise, Loc, Name, None); + if (Suspend.isInvalid()) + return StmtError(); + Suspend = buildOperatorCoawaitCall(S, SC, Loc, Suspend.get()); + if (Suspend.isInvalid()) + return StmtError(); + Suspend = S.BuildCoawaitExpr(Loc, Suspend.get(), + /*IsImplicitlyCreated*/ true); + Suspend = S.ActOnFinishFullExpr(Suspend.get()); + if (Suspend.isInvalid()) { + S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) + << ((Name == "initial_suspend") ? 
0 : 1); + S.Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword; + return StmtError(); + } + return cast<Stmt>(Suspend.get()); + }; + + StmtResult InitSuspend = buildSuspends("initial_suspend"); + if (InitSuspend.isInvalid()) + return true; + + StmtResult FinalSuspend = buildSuspends("final_suspend"); + if (FinalSuspend.isInvalid()) + return true; + + ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get()); + + return true; +} + ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await"); - if (!Coroutine) { + if (!actOnCoroutineBodyStart(*this, S, Loc, "co_await")) { CorrectDelayedTyposInExpr(E); return ExprError(); } + if (E->getType()->isPlaceholderType()) { ExprResult R = CheckPlaceholderExpr(E); if (R.isInvalid()) return ExprError(); E = R.get(); } + ExprResult Lookup = buildOperatorCoawaitLookupExpr(*this, S, Loc); + if (Lookup.isInvalid()) + return ExprError(); + return BuildDependentCoawaitExpr(Loc, E, + cast<UnresolvedLookupExpr>(Lookup.get())); +} + +ExprResult Sema::BuildDependentCoawaitExpr(SourceLocation Loc, Expr *E, + UnresolvedLookupExpr *Lookup) { + auto *FSI = checkCoroutineContext(*this, Loc, "co_await"); + if (!FSI) + return ExprError(); + + if (E->getType()->isPlaceholderType()) { + ExprResult R = CheckPlaceholderExpr(E); + if (R.isInvalid()) + return ExprError(); + E = R.get(); + } - ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E); + auto *Promise = FSI->CoroutinePromise; + if (Promise->getType()->isDependentType()) { + Expr *Res = + new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup); + FSI->CoroutineStmts.push_back(Res); + return Res; + } + + auto *RD = Promise->getType()->getAsCXXRecordDecl(); + if (lookupMember(*this, "await_transform", RD, Loc)) { + ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", E); + if (R.isInvalid()) { + Diag(Loc, + 
diag::note_coroutine_promise_implicit_await_transform_required_here) + << E->getSourceRange(); + return ExprError(); + } + E = R.get(); + } + ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup); if (Awaitable.isInvalid()) return ExprError(); return BuildCoawaitExpr(Loc, Awaitable.get()); } -ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) { + +ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E, + bool IsImplicitlyCreated) { auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await"); if (!Coroutine) return ExprError(); @@ -298,8 +462,10 @@ } if (E->getType()->isDependentType()) { - Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E); - Coroutine->CoroutineStmts.push_back(Res); + Expr *Res = new (Context) + CoawaitExpr(Loc, Context.DependentTy, E, IsImplicitlyCreated); + if (!IsImplicitlyCreated) + Coroutine->CoroutineStmts.push_back(Res); return Res; } @@ -314,37 +480,21 @@ return ExprError(); Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1], - RSS.Results[2]); - Coroutine->CoroutineStmts.push_back(Res); + RSS.Results[2], IsImplicitlyCreated); + if (!IsImplicitlyCreated) + Coroutine->CoroutineStmts.push_back(Res); return Res; } -static ExprResult buildPromiseCall(Sema &S, FunctionScopeInfo *Coroutine, - SourceLocation Loc, StringRef Name, - MutableArrayRef<Expr *> Args) { - assert(Coroutine->CoroutinePromise && "no promise for coroutine"); - - // Form a reference to the promise. - auto *Promise = Coroutine->CoroutinePromise; - ExprResult PromiseRef = S.BuildDeclRefExpr( - Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc); - if (PromiseRef.isInvalid()) - return ExprError(); - - // Call 'yield_value', passing in E. 
- return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args); -} - ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield"); - if (!Coroutine) { + if (!actOnCoroutineBodyStart(*this, S, Loc, "co_yield")) { CorrectDelayedTyposInExpr(E); return ExprError(); } // Build yield_value call. - ExprResult Awaitable = - buildPromiseCall(*this, Coroutine, Loc, "yield_value", E); + ExprResult Awaitable = buildPromiseCall( + *this, getCurFunction()->CoroutinePromise, Loc, "yield_value", E); if (Awaitable.isInvalid()) return ExprError(); @@ -388,18 +538,18 @@ return Res; } -StmtResult Sema::ActOnCoreturnStmt(SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return"); - if (!Coroutine) { +StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) { + if (!actOnCoroutineBodyStart(*this, S, Loc, "co_return")) { CorrectDelayedTyposInExpr(E); return StmtError(); } return BuildCoreturnStmt(Loc, E); } -StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return"); - if (!Coroutine) +StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E, + bool IsImplicitlyCreated) { + auto *FSI = checkCoroutineContext(*this, Loc, "co_return"); + if (!FSI) return StmtError(); if (E && E->getType()->isPlaceholderType() && @@ -412,20 +562,22 @@ // FIXME: If the operand is a reference to a variable that's about to go out // of scope, we should treat the operand as an xvalue for this overload // resolution. 
+ VarDecl *Promise = FSI->CoroutinePromise; ExprResult PC; if (E && (isa<InitListExpr>(E) || !E->getType()->isVoidType())) { - PC = buildPromiseCall(*this, Coroutine, Loc, "return_value", E); + PC = buildPromiseCall(*this, Promise, Loc, "return_value", E); } else { E = MakeFullDiscardedValueExpr(E).get(); - PC = buildPromiseCall(*this, Coroutine, Loc, "return_void", None); + PC = buildPromiseCall(*this, Promise, Loc, "return_void", None); } if (PC.isInvalid()) return StmtError(); Expr *PCE = ActOnFinishFullExpr(PC.get()).get(); - Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE); - Coroutine->CoroutineStmts.push_back(Res); + Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE, IsImplicitlyCreated); + if (!IsImplicitlyCreated) + FSI->CoroutineStmts.push_back(Res); return Res; } @@ -482,14 +634,82 @@ return OperatorDelete; } +static bool buildFallthrough(Sema &S, SourceLocation Loc, + FunctionDecl *FD, + FunctionScopeInfo *FTI, + Stmt *&OnFallthrough) +{ + assert(!OnFallthrough && "rebuilding existing OnFallthrough"); + auto *Promise = FTI->CoroutinePromise; + if (Promise->getType()->isDependentType()) + return true; + + CXXRecordDecl *RD = Promise->getType()->getAsCXXRecordDecl(); + + // [dcl.fct.def.coroutine]/4 + // The unqualified-ids 'return_void' and 'return_value' are looked up in + // the scope of class P. If both are found, the program is ill-formed. + const bool HasRVoid = lookupMember(S, "return_void", RD, Loc); + const bool HasRValue = lookupMember(S, "return_value", RD, Loc); + if (HasRVoid && HasRValue) { + // FIXME Improve this diagnostic + S.Diag(FD->getLocation(), diag::err_coroutine_promise_return_ill_formed) + << RD; + return false; + } else if (HasRVoid) { + // If the unqualified-id return_void is found, flowing off the end of a + // coroutine is equivalent to a co_return with no operand. Otherwise, + // flowing off the end of a coroutine results in undefined behavior. 
+ StmtResult Fallthrough = S.BuildCoreturnStmt(FD->getLocation(), nullptr, + /*IsImplicitlyCreated*/ true); + if (!Fallthrough.isInvalid()) + Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get()); + if (Fallthrough.isInvalid()) + return false; + OnFallthrough = Fallthrough.get(); + } + return true; +} + +static bool buildSetException(Sema &S, SourceLocation Loc, + FunctionDecl *FD, + FunctionScopeInfo *FTI, + Stmt *&OnException) +{ + assert(!OnException && "rebuilding existing set_exception"); + auto *Promise = FTI->CoroutinePromise; + if (Promise->getType()->isDependentType()) + return true; + + CXXRecordDecl *RD = Promise->getType()->getAsCXXRecordDecl(); + + // [dcl.fct.def.coroutine]/3 + // The unqualified-id set_exception is found in the scope of P by class + // member access lookup (3.4.5). + if (lookupMember(S, "set_exception", RD, Loc)) { + // Form the call 'p.set_exception(std::current_exception())' + ExprResult SetException = buildStdCurrentExceptionCall(S, Loc); + if (SetException.isInvalid()) + return false; + Expr *E = SetException.get(); + SetException = buildPromiseCall(S, Promise, Loc, "set_exception", E); + SetException = S.ActOnFinishFullExpr(SetException.get(), Loc); + if (SetException.isInvalid()) + return false; + OnException = SetException.get(); + } + return true; +} + + // Builds allocation and deallocation for the coroutine. Returns false on // failure. static bool buildAllocationAndDeallocation(Sema &S, SourceLocation Loc, FunctionScopeInfo *Fn, Expr *&Allocation, Stmt *&Deallocation) { - TypeSourceInfo *TInfo = Fn->CoroutinePromise->getTypeSourceInfo(); - QualType PromiseType = TInfo->getType(); + assert(!Allocation && !Deallocation && "alloc/dealloc statements have already been built"); + QualType PromiseType = Fn->CoroutinePromise->getType(); if (PromiseType->isDependentType()) return true; @@ -532,8 +752,6 @@ if (NewExpr.isInvalid()) return false; - Allocation = NewExpr.get(); - // Make delete call. 
QualType OpDeleteQualType = OperatorDelete->getType(); @@ -559,149 +777,137 @@ if (DeleteExpr.isInvalid()) return false; + Allocation = NewExpr.get(); Deallocation = DeleteExpr.get(); return true; } -void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) { - FunctionScopeInfo *Fn = getCurFunction(); - assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine"); - - // Coroutines [stmt.return]p1: - // A return statement shall not appear in a coroutine. - if (Fn->FirstReturnLoc.isValid()) { - Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine); - auto *First = Fn->CoroutineStmts[0]; - Diag(First->getLocStart(), diag::note_declared_coroutine_here) - << (isa<CoawaitExpr>(First) ? 0 : - isa<CoyieldExpr>(First) ? 1 : 2); - } - - bool AnyCoawaits = false; - bool AnyCoyields = false; - for (auto *CoroutineStmt : Fn->CoroutineStmts) { - AnyCoawaits |= isa<CoawaitExpr>(CoroutineStmt); - AnyCoyields |= isa<CoyieldExpr>(CoroutineStmt); - } - - if (!AnyCoawaits && !AnyCoyields) - Diag(Fn->CoroutineStmts.front()->getLocStart(), - diag::ext_coroutine_without_co_await_co_yield); - - SourceLocation Loc = FD->getLocation(); - +StmtResult Sema::BuildCoroutineBodyStmt(Stmt *Body, VarDecl *Promise, Stmt *InitSuspend, + Stmt *FinalSuspend, Stmt *SetException, + Stmt *OnFallthrough, Expr *Allocation, + Stmt *Deallocation, Expr *ReturnObjectInit) { + assert(Promise && InitSuspend && FinalSuspend && "these nodes must already be built"); // Form a declaration statement for the promise declaration, so that AST // visitors can more easily find it. 
+ // FIXME Get real location + auto *FSI = getCurFunction(); + assert(FSI->CoroutinePromise); + auto *FD = cast<FunctionDecl>(CurContext); + auto Loc = FD->getLocation(); + + auto checkPlaceholders = [&](Stmt *&S) mutable { + Expr *E = cast_or_null<Expr>(S); + if (E && E->getType()->isPlaceholderType() && + !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) { + ExprResult R = CheckPlaceholderExpr(E); + if (R.isInvalid()) + return false; + S = cast<Stmt>(R.get()); + } + return true; + }; + if (!checkPlaceholders(InitSuspend) || !checkPlaceholders(FinalSuspend)) + return StmtError(); + StmtResult PromiseStmt = - ActOnDeclStmt(ConvertDeclToDeclGroup(Fn->CoroutinePromise), Loc, Loc); + ActOnDeclStmt(ConvertDeclToDeclGroup(Promise), Promise->getLocStart(), + Promise->getLocEnd()); if (PromiseStmt.isInvalid()) - return FD->setInvalidDecl(); - - // Form and check implicit 'co_await p.initial_suspend();' statement. - ExprResult InitialSuspend = - buildPromiseCall(*this, Fn, Loc, "initial_suspend", None); - // FIXME: Support operator co_await here. - if (!InitialSuspend.isInvalid()) - InitialSuspend = BuildCoawaitExpr(Loc, InitialSuspend.get()); - InitialSuspend = ActOnFinishFullExpr(InitialSuspend.get()); - if (InitialSuspend.isInvalid()) - return FD->setInvalidDecl(); - - // Form and check implicit 'co_await p.final_suspend();' statement. - ExprResult FinalSuspend = - buildPromiseCall(*this, Fn, Loc, "final_suspend", None); - // FIXME: Support operator co_await here. - if (!FinalSuspend.isInvalid()) - FinalSuspend = BuildCoawaitExpr(Loc, FinalSuspend.get()); - FinalSuspend = ActOnFinishFullExpr(FinalSuspend.get()); - if (FinalSuspend.isInvalid()) - return FD->setInvalidDecl(); + return StmtError(); - // Form and check allocation and deallocation calls. 
- Expr *Allocation = nullptr; - Stmt *Deallocation = nullptr; - if (!buildAllocationAndDeallocation(*this, Loc, Fn, Allocation, Deallocation)) - return FD->setInvalidDecl(); + if (!OnFallthrough && !buildFallthrough(*this, Loc, FD, FSI, OnFallthrough)) + return StmtError(); - // control flowing off the end of the coroutine. - // Also try to form 'p.set_exception(std::current_exception());' to handle - // uncaught exceptions. - ExprResult SetException; - StmtResult Fallthrough; - if (Fn->CoroutinePromise && - !Fn->CoroutinePromise->getType()->isDependentType()) { - CXXRecordDecl *RD = Fn->CoroutinePromise->getType()->getAsCXXRecordDecl(); - assert(RD && "Type should have already been checked"); - // [dcl.fct.def.coroutine]/4 - // The unqualified-ids 'return_void' and 'return_value' are looked up in - // the scope of class P. If both are found, the program is ill-formed. - DeclarationName RVoidDN = PP.getIdentifierInfo("return_void"); - LookupResult RVoidResult(*this, RVoidDN, Loc, Sema::LookupMemberName); - const bool HasRVoid = LookupQualifiedName(RVoidResult, RD); - - DeclarationName RValueDN = PP.getIdentifierInfo("return_value"); - LookupResult RValueResult(*this, RValueDN, Loc, Sema::LookupMemberName); - const bool HasRValue = LookupQualifiedName(RValueResult, RD); - - if (HasRVoid && HasRValue) { - // FIXME Improve this diagnostic - Diag(FD->getLocation(), diag::err_coroutine_promise_return_ill_formed) - << RD; - return FD->setInvalidDecl(); - } else if (HasRVoid) { - // If the unqualified-id return_void is found, flowing off the end of a - // coroutine is equivalent to a co_return with no operand. Otherwise, - // flowing off the end of a coroutine results in undefined behavior. 
- Fallthrough = BuildCoreturnStmt(FD->getLocation(), nullptr); - Fallthrough = ActOnFinishFullStmt(Fallthrough.get()); - if (Fallthrough.isInvalid()) - return FD->setInvalidDecl(); - } + if (!SetException && !buildSetException(*this, Loc, FD, FSI, SetException)) + return StmtError(); - // [dcl.fct.def.coroutine]/3 - // The unqualified-id set_exception is found in the scope of P by class - // member access lookup (3.4.5). - DeclarationName SetExDN = PP.getIdentifierInfo("set_exception"); - LookupResult SetExResult(*this, SetExDN, Loc, Sema::LookupMemberName); - if (LookupQualifiedName(SetExResult, RD)) { - // Form the call 'p.set_exception(std::current_exception())' - SetException = buildStdCurrentExceptionCall(*this, Loc); - if (SetException.isInvalid()) - return FD->setInvalidDecl(); - Expr *E = SetException.get(); - SetException = buildPromiseCall(*this, Fn, Loc, "set_exception", E); - SetException = ActOnFinishFullExpr(SetException.get(), Loc); - if (SetException.isInvalid()) - return FD->setInvalidDecl(); - } + if (!Allocation || !Deallocation) { + assert(!Allocation && !Deallocation && "These should be a package deal"); + if (!buildAllocationAndDeallocation(*this, Loc, FSI, Allocation, + Deallocation)) + return StmtError(); } // Build implicit 'p.get_return_object()' expression and form initialization // of return type from it. 
- ExprResult ReturnObject = - buildPromiseCall(*this, Fn, Loc, "get_return_object", None); - if (ReturnObject.isInvalid()) - return FD->setInvalidDecl(); - QualType RetType = FD->getReturnType(); - if (!RetType->isDependentType()) { - InitializedEntity Entity = - InitializedEntity::InitializeResult(Loc, RetType, false); - ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType, - ReturnObject.get()); + if (!ReturnObjectInit) { + ExprResult ReturnObject = + buildPromiseCall(*this, Promise, Loc, "get_return_object", None); if (ReturnObject.isInvalid()) - return FD->setInvalidDecl(); + return StmtError(); + QualType RetType = FD->getReturnType(); + if (!RetType->isDependentType()) { + InitializedEntity Entity = + InitializedEntity::InitializeResult(Loc, RetType, false); + ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType, + ReturnObject.get()); + if (ReturnObject.isInvalid()) + return StmtError(); + } + ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc); + if (ReturnObject.isInvalid()) + return StmtError(); + ReturnObjectInit = ReturnObject.get(); } - ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc); - if (ReturnObject.isInvalid()) + + return new (Context) CoroutineBodyStmt( + Body, PromiseStmt.get(), InitSuspend, FinalSuspend, + SetException, OnFallthrough, Allocation, Deallocation, ReturnObjectInit, None); +} + +void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) { + FunctionScopeInfo *FSI = getCurFunction(); + assert(FSI && FSI->CoroutinePromise && FSI->HasCoroutineSuspends && + "not a coroutine"); + VarDecl *Promise = FSI->CoroutinePromise; + + // Check if we failed to build the initial/final suspend points during the + // initial parse. + if (FSI->hasInvalidCoroutineSuspends()) return FD->setInvalidDecl(); // FIXME: Perform move-initialization of parameters into frame-local copies. 
SmallVector<Expr*, 16> ParamMoves; + if (Body && !isa<CoroutineBodyStmt>(Body)) { + StmtResult BodyRes = BuildCoroutineBodyStmt( + Body, FSI->CoroutinePromise, FSI->CoroutineSuspends.first, + FSI->CoroutineSuspends.second, nullptr, nullptr, nullptr, nullptr, + nullptr); + if (BodyRes.isInvalid()) + return FD->setInvalidDecl(); + Body = BodyRes.get(); + } + + + // Coroutines [stmt.return]p1: + // A return statement shall not appear in a coroutine. + if (FSI->FirstReturnLoc.isValid()) { + Diag(FSI->FirstReturnLoc, diag::err_return_in_coroutine); + auto *First = FSI->CoroutineStmts[0]; + Diag(First->getLocStart(), diag::note_declared_coroutine_here) + << ((isa<CoawaitExpr>(First) || isa<DependentCoawaitExpr>(First)) + ? "co_await" + : isa<CoyieldExpr>(First) ? "co_yield" : "co_return"); + } + + bool AnyCoawaits = false; + bool AnyDependentCoawaits = false; + bool AnyCoyields = false; + for (auto *CoroutineStmt : FSI->CoroutineStmts) { + // Don't count the implicitly generated initial/final suspend points + if (auto *CA = dyn_cast<CoawaitExpr>(CoroutineStmt)) + AnyCoawaits |= !CA->isImplicitlyCreated(); + AnyDependentCoawaits |= isa<DependentCoawaitExpr>(CoroutineStmt); + AnyCoyields |= isa<CoyieldExpr>(CoroutineStmt); + } + + if (!FD->isInvalidDecl() && !FSI->CoroutineStmts.empty() && !AnyCoawaits && + !AnyCoyields && !AnyDependentCoawaits) + Diag(FSI->CoroutineStmts.front()->getLocStart(), + diag::ext_coroutine_without_co_await_co_yield); + + assert((!AnyDependentCoawaits || Promise->getType()->isDependentType()) && + "All dependent coawait expressions should already be resolved"); - // Build body for the coroutine wrapper statement. 
- Body = new (Context) CoroutineBodyStmt( - Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(), - SetException.get(), Fallthrough.get(), Allocation, Deallocation, - ReturnObject.get(), ParamMoves); } Index: lib/Sema/ScopeInfo.cpp =================================================================== --- lib/Sema/ScopeInfo.cpp +++ lib/Sema/ScopeInfo.cpp @@ -42,6 +42,9 @@ SwitchStack.clear(); Returns.clear(); CoroutinePromise = nullptr; + HasCoroutineSuspends = false; + CoroutineSuspends.first = nullptr; + CoroutineSuspends.second = nullptr; CoroutineStmts.clear(); ErrorTrap.reset(); PossiblyUnreachableDiags.clear(); Index: lib/Parse/ParseStmt.cpp =================================================================== --- lib/Parse/ParseStmt.cpp +++ lib/Parse/ParseStmt.cpp @@ -1898,7 +1898,7 @@ } } if (IsCoreturn) - return Actions.ActOnCoreturnStmt(ReturnLoc, R.get()); + return Actions.ActOnCoreturnStmt(getCurScope(), ReturnLoc, R.get()); return Actions.ActOnReturnStmt(ReturnLoc, R.get(), getCurScope()); } Index: lib/AST/StmtProfile.cpp =================================================================== --- lib/AST/StmtProfile.cpp +++ lib/AST/StmtProfile.cpp @@ -1552,6 +1552,10 @@ VisitExpr(S); } +void StmtProfiler::VisitDependentCoawaitExpr(const DependentCoawaitExpr *S) { + VisitExpr(S); +} + void StmtProfiler::VisitCoyieldExpr(const CoyieldExpr *S) { VisitExpr(S); } Index: lib/AST/StmtPrinter.cpp =================================================================== --- lib/AST/StmtPrinter.cpp +++ lib/AST/StmtPrinter.cpp @@ -2422,6 +2422,11 @@ PrintExpr(S->getOperand()); } +void StmtPrinter::VisitDependentCoawaitExpr(DependentCoawaitExpr *S) { + OS << "co_await "; + PrintExpr(S->getOperand()); +} + void StmtPrinter::VisitCoyieldExpr(CoyieldExpr *S) { OS << "co_yield "; PrintExpr(S->getOperand()); Index: lib/AST/ItaniumMangle.cpp =================================================================== --- lib/AST/ItaniumMangle.cpp +++ lib/AST/ItaniumMangle.cpp @@ 
-3299,6 +3299,8 @@ // These all can only appear in local or variable-initialization // contexts and so should never appear in a mangling. case Expr::AddrLabelExprClass: + // This should no longer exist in the AST by now + case Expr::DependentCoawaitExprClass: case Expr::DesignatedInitUpdateExprClass: case Expr::ImplicitValueInitExprClass: case Expr::NoInitExprClass: Index: lib/AST/ExprConstant.cpp =================================================================== --- lib/AST/ExprConstant.cpp +++ lib/AST/ExprConstant.cpp @@ -9442,6 +9442,7 @@ case Expr::LambdaExprClass: case Expr::CXXFoldExprClass: case Expr::CoawaitExprClass: + case Expr::DependentCoawaitExprClass: case Expr::CoyieldExprClass: return ICEDiag(IK_NotICE, E->getLocStart()); Index: lib/AST/ExprClassification.cpp =================================================================== --- lib/AST/ExprClassification.cpp +++ lib/AST/ExprClassification.cpp @@ -129,6 +129,7 @@ case Expr::UnresolvedLookupExprClass: case Expr::UnresolvedMemberExprClass: case Expr::TypoExprClass: + case Expr::DependentCoawaitExprClass: case Expr::CXXDependentScopeMemberExprClass: case Expr::DependentScopeDeclRefExprClass: // ObjC instance variables are lvalues Index: lib/AST/Expr.cpp =================================================================== --- lib/AST/Expr.cpp +++ lib/AST/Expr.cpp @@ -2923,6 +2923,7 @@ case CXXNewExprClass: case CXXDeleteExprClass: case CoawaitExprClass: + case DependentCoawaitExprClass: case CoyieldExprClass: // These always have a side-effect. 
return true; Index: include/clang/Sema/Sema.h =================================================================== --- include/clang/Sema/Sema.h +++ include/clang/Sema/Sema.h @@ -8032,12 +8032,21 @@ // ExprResult ActOnCoawaitExpr(Scope *S, SourceLocation KwLoc, Expr *E); ExprResult ActOnCoyieldExpr(Scope *S, SourceLocation KwLoc, Expr *E); - StmtResult ActOnCoreturnStmt(SourceLocation KwLoc, Expr *E); + StmtResult ActOnCoreturnStmt(Scope *S, SourceLocation KwLoc, Expr *E); - ExprResult BuildCoawaitExpr(SourceLocation KwLoc, Expr *E); + ExprResult BuildCoawaitExpr(SourceLocation KwLoc, Expr *E, + bool IsImplicitlyCreated = false); + ExprResult BuildDependentCoawaitExpr(SourceLocation KwLoc, Expr *E, + UnresolvedLookupExpr *Lookup); ExprResult BuildCoyieldExpr(SourceLocation KwLoc, Expr *E); - StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E); - + StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E, + bool IsImplicitlyCreated = false); + StmtResult BuildCoroutineBodyStmt(Stmt *Body, VarDecl *Promise, Stmt *InitSuspend, + Stmt *FinalSuspend, Stmt *OnException, + Stmt *OnFallthrough, Expr *Allocation, + Stmt *Deallocation, Expr *ReturnValue); + + VarDecl *buildCoroutinePromise(SourceLocation Loc); void CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body); //===--------------------------------------------------------------------===// Index: include/clang/Sema/ScopeInfo.h =================================================================== --- include/clang/Sema/ScopeInfo.h +++ include/clang/Sema/ScopeInfo.h @@ -16,12 +16,14 @@ #define LLVM_CLANG_SEMA_SCOPEINFO_H #include "clang/AST/Expr.h" +#include "clang/AST/StmtCXX.h" #include "clang/AST/Type.h" #include "clang/Basic/CapturedStmt.h" #include "clang/Basic/PartialDiagnostic.h" #include "clang/Sema/CleanupInfo.h" #include "clang/Sema/Ownership.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include <algorithm> @@ 
-135,6 +137,10 @@ /// false if there is an invocation of an initializer on 'self'. bool ObjCWarnForNoInitDelegation : 1; + /// \brief Whether this function has already built, or tried to build, the + /// initial and final coroutine suspend points. + bool HasCoroutineSuspends : 1; + /// First 'return' statement in the current function. SourceLocation FirstReturnLoc; @@ -159,6 +165,9 @@ /// \brief The promise object for this coroutine, if any. VarDecl *CoroutinePromise; + /// \brief The initial and final coroutine suspend points. + std::pair<Stmt *, Stmt *> CoroutineSuspends; + /// \brief The list of coroutine control flow constructs (co_await, co_yield, /// co_return) that occur within the function or block. Empty if and only if /// this function or block is not (yet known to be) a coroutine. @@ -376,22 +385,34 @@ (HasIndirectGoto || (HasBranchProtectedScope && HasBranchIntoScope)); } - + + void setCoroutineSuspendsInvalid() { + assert(!HasCoroutineSuspends && CoroutineSuspends.first == nullptr && + "we already have valid suspend points"); + HasCoroutineSuspends = true; + } + + bool hasInvalidCoroutineSuspends() const { + return HasCoroutineSuspends && CoroutineSuspends.first == nullptr; + } + + void setCoroutineSuspends(Stmt *Initial, Stmt *Final) { + assert(Initial && Final && "suspend points cannot be null"); + assert(CoroutineSuspends.first == nullptr && "suspend points already set"); + HasCoroutineSuspends = true; + CoroutineSuspends.first = Initial; + CoroutineSuspends.second = Final; + } + FunctionScopeInfo(DiagnosticsEngine &Diag) - : Kind(SK_Function), - HasBranchProtectedScope(false), - HasBranchIntoScope(false), - HasIndirectGoto(false), - HasDroppedStmt(false), - HasOMPDeclareReductionCombiner(false), - HasFallthroughStmt(false), - HasPotentialAvailabilityViolations(false), - ObjCShouldCallSuper(false), - ObjCIsDesignatedInit(false), - ObjCWarnForNoDesignatedInitChain(false), - ObjCIsSecondaryInit(false), - ObjCWarnForNoInitDelegation(false), -
ErrorTrap(Diag) { } + : Kind(SK_Function), HasBranchProtectedScope(false), + HasBranchIntoScope(false), HasIndirectGoto(false), + HasDroppedStmt(false), HasOMPDeclareReductionCombiner(false), + HasFallthroughStmt(false), HasPotentialAvailabilityViolations(false), + ObjCShouldCallSuper(false), ObjCIsDesignatedInit(false), + ObjCWarnForNoDesignatedInitChain(false), ObjCIsSecondaryInit(false), + ObjCWarnForNoInitDelegation(false), HasCoroutineSuspends(false), + CoroutinePromise(nullptr), ErrorTrap(Diag) {} virtual ~FunctionScopeInfo(); Index: include/clang/Basic/StmtNodes.td =================================================================== --- include/clang/Basic/StmtNodes.td +++ include/clang/Basic/StmtNodes.td @@ -148,6 +148,7 @@ // C++ Coroutines TS expressions def CoroutineSuspendExpr : DStmt<Expr, 1>; def CoawaitExpr : DStmt<CoroutineSuspendExpr>; +def DependentCoawaitExpr : DStmt<Expr>; def CoyieldExpr : DStmt<CoroutineSuspendExpr>; // Obj-C Expressions. Index: include/clang/Basic/DiagnosticSemaKinds.td =================================================================== --- include/clang/Basic/DiagnosticSemaKinds.td +++ include/clang/Basic/DiagnosticSemaKinds.td @@ -8634,8 +8634,7 @@ def err_return_in_coroutine : Error< "return statement not allowed in coroutine; did you mean 'co_return'?">; def note_declared_coroutine_here : Note< - "function is a coroutine due to use of " - "'%select{co_await|co_yield|co_return}0' here">; + "function is a coroutine due to use of '%0' here">; def err_coroutine_objc_method : Error< "Objective-C methods as coroutines are not yet supported">; def err_coroutine_unevaluated_context : Error< @@ -8659,6 +8658,8 @@ "this function cannot be a coroutine: %q0 has no member named 'promise_type'">; def err_implied_std_coroutine_traits_promise_type_not_class : Error< "this function cannot be a coroutine: %0 is not a class">; +def err_coroutine_promise_type_incomplete : Error< + "this function cannot be a coroutine: %0 is an incomplete 
type">; def err_coroutine_traits_missing_specialization : Error< "this function cannot be a coroutine: missing definition of " "specialization %q0">; @@ -8669,6 +8670,11 @@ "'std::current_exception' must be a function">; def err_coroutine_promise_return_ill_formed : Error< "%0 declares both 'return_value' and 'return_void'">; +def note_coroutine_promise_implicit_await_transform_required_here : Note< + "call to 'await_transform' implicitly required by 'co_await' here">; +def note_coroutine_promise_call_implicitly_required : Note< + "call to '%select{initial_suspend|final_suspend}0' implicitly " + "required by the %select{initial suspend point|final suspend point}0">; } let CategoryName = "Documentation Issue" in { Index: include/clang/AST/StmtCXX.h =================================================================== --- include/clang/AST/StmtCXX.h +++ include/clang/AST/StmtCXX.h @@ -327,6 +327,8 @@ SubStmts[CoroutineBodyStmt::Allocate] = Allocate; SubStmts[CoroutineBodyStmt::Deallocate] = Deallocate; SubStmts[CoroutineBodyStmt::ReturnValue] = ReturnValue; + assert(Promise && InitSuspend && FinalSuspend && + "these members must never be null"); // FIXME: Tail-allocate space for parameter move expressions and store them. 
assert(ParamMoves.empty() && "not implemented yet"); } @@ -336,7 +338,10 @@ Stmt *getBody() const { return SubStmts[SubStmt::Body]; } - + void setBody(Stmt *B) { + assert(!B || !isa<CoroutineBodyStmt>(B)); + SubStmts[SubStmt::Body] = B; + } Stmt *getPromiseDeclStmt() const { return SubStmts[SubStmt::Promise]; } VarDecl *getPromiseDecl() const { return cast<VarDecl>(cast<DeclStmt>(getPromiseDeclStmt())->getSingleDecl()); @@ -349,19 +354,18 @@ Stmt *getFallthroughHandler() const { return SubStmts[SubStmt::OnFallthrough]; } - - Expr *getAllocate() const { return cast<Expr>(SubStmts[SubStmt::Allocate]); } + Expr *getAllocate() const { return cast_or_null<Expr>(SubStmts[SubStmt::Allocate]); } Stmt *getDeallocate() const { return SubStmts[SubStmt::Deallocate]; } Expr *getReturnValueInit() const { - return cast<Expr>(SubStmts[SubStmt::ReturnValue]); + return cast_or_null<Expr>(SubStmts[SubStmt::ReturnValue]); } SourceLocation getLocStart() const LLVM_READONLY { - return getBody()->getLocStart(); + return getBody() ? getBody()->getLocStart() : getPromiseDecl()->getLocStart(); } SourceLocation getLocEnd() const LLVM_READONLY { - return getBody()->getLocEnd(); + return getBody() ?
getBody()->getLocEnd() : getPromiseDecl()->getLocEnd(); } child_range children() { @@ -390,10 +394,14 @@ enum SubStmt { Operand, PromiseCall, Count }; Stmt *SubStmts[SubStmt::Count]; + bool IsImplicitlyCreated : 1; + friend class ASTStmtReader; public: - CoreturnStmt(SourceLocation CoreturnLoc, Stmt *Operand, Stmt *PromiseCall) - : Stmt(CoreturnStmtClass), CoreturnLoc(CoreturnLoc) { + CoreturnStmt(SourceLocation CoreturnLoc, Stmt *Operand, Stmt *PromiseCall, + bool IsImplicit = false) + : Stmt(CoreturnStmtClass), CoreturnLoc(CoreturnLoc), + IsImplicitlyCreated(IsImplicit) { SubStmts[SubStmt::Operand] = Operand; SubStmts[SubStmt::PromiseCall] = PromiseCall; } @@ -410,6 +418,8 @@ Expr *getPromiseCall() const { return static_cast<Expr*>(SubStmts[PromiseCall]); } + bool isImplicitlyCreated() const { return IsImplicitlyCreated; } + void setImplicitlyCreated(bool value = true) { IsImplicitlyCreated = value; } SourceLocation getLocStart() const LLVM_READONLY { return CoreturnLoc; } SourceLocation getLocEnd() const LLVM_READONLY { Index: include/clang/AST/RecursiveASTVisitor.h =================================================================== --- include/clang/AST/RecursiveASTVisitor.h +++ include/clang/AST/RecursiveASTVisitor.h @@ -2471,6 +2471,12 @@ ShouldVisitChildren = false; } }) +DEF_TRAVERSE_STMT(DependentCoawaitExpr, { + if (!getDerived().shouldVisitImplicitCode()) { + TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getOperand()); + ShouldVisitChildren = false; + } +}) DEF_TRAVERSE_STMT(CoyieldExpr, { if (!getDerived().shouldVisitImplicitCode()) { TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getOperand()); Index: include/clang/AST/ExprCXX.h =================================================================== --- include/clang/AST/ExprCXX.h +++ include/clang/AST/ExprCXX.h @@ -4231,26 +4231,82 @@ /// \brief Represents a 'co_await' expression.
class CoawaitExpr : public CoroutineSuspendExpr { friend class ASTStmtReader; + + /// \brief True if this co_await expression was implicitly generated by the + /// compiler. + bool IsImplicitlyCreated : 1; + public: CoawaitExpr(SourceLocation CoawaitLoc, Expr *Operand, Expr *Ready, - Expr *Suspend, Expr *Resume) + Expr *Suspend, Expr *Resume, bool IsImplicit = false) : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Operand, Ready, - Suspend, Resume) {} - CoawaitExpr(SourceLocation CoawaitLoc, QualType Ty, Expr *Operand) - : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Ty, Operand) {} + Suspend, Resume), + IsImplicitlyCreated(IsImplicit) {} + CoawaitExpr(SourceLocation CoawaitLoc, QualType Ty, Expr *Operand, + bool IsImplicit = false) + : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Ty, Operand), + IsImplicitlyCreated(IsImplicit) {} CoawaitExpr(EmptyShell Empty) : CoroutineSuspendExpr(CoawaitExprClass, Empty) {} Expr *getOperand() const { // FIXME: Dig out the actual operand or store it. return getCommonExpr(); } + bool isImplicitlyCreated() const { return IsImplicitlyCreated; } + void setIsImplicitlyCreated(bool value = true) { + IsImplicitlyCreated = value; + } + static bool classof(const Stmt *T) { return T->getStmtClass() == CoawaitExprClass; } }; +/// \brief Represents a 'co_await' expression while the type of the promise +/// is dependent. 
+class DependentCoawaitExpr : public Expr { + SourceLocation KeywordLoc; + Stmt *SubExprs[2]; + + friend class ASTStmtReader; + +public: + DependentCoawaitExpr(SourceLocation KeywordLoc, QualType Ty, Expr *Op, + UnresolvedLookupExpr *OpCoawait) + : Expr(DependentCoawaitExprClass, Ty, VK_RValue, OK_Ordinary, + /*TypeDependent*/ true, /*ValueDependent*/ true, + /*InstantiationDependent*/ true, + Op->containsUnexpandedParameterPack()), + KeywordLoc(KeywordLoc) { + assert(Op->isTypeDependent() && Ty->isDependentType() && + "wrong constructor for non-dependent co_await/co_yield expression"); + SubExprs[0] = Op; + SubExprs[1] = OpCoawait; + } + + DependentCoawaitExpr(EmptyShell Empty) + : Expr(DependentCoawaitExprClass, Empty) {} + + Expr *getOperand() const { return cast<Expr>(SubExprs[0]); } + UnresolvedLookupExpr *getOperatorCoawaitLookup() const { + return cast<UnresolvedLookupExpr>(SubExprs[1]); + } + SourceLocation getKeywordLoc() const { return KeywordLoc; } + + SourceLocation getLocStart() const LLVM_READONLY { return KeywordLoc; } + SourceLocation getLocEnd() const LLVM_READONLY { + return getOperand()->getLocEnd(); + } + + child_range children() { return child_range(SubExprs, SubExprs + 2); } + + static bool classof(const Stmt *T) { + return T->getStmtClass() == DependentCoawaitExprClass; + } +}; + /// \brief Represents a 'co_yield' expression. class CoyieldExpr : public CoroutineSuspendExpr { friend class ASTStmtReader;
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits