https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/85684
>From b843c2f71a1a43cb897b557f783d60c6bf26f687 Mon Sep 17 00:00:00 2001 From: Yuxuan Chen <y...@meta.com> Date: Mon, 18 Mar 2024 10:45:20 -0700 Subject: [PATCH] Check if Coroutine await_suspend type returns the right type --- clang/docs/ReleaseNotes.rst | 3 + .../clang/Basic/DiagnosticSemaKinds.td | 2 +- clang/lib/Sema/SemaCoroutine.cpp | 119 +++++++++++++----- clang/test/SemaCXX/coroutines.cpp | 28 ++++- 4 files changed, 111 insertions(+), 41 deletions(-) diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst index d6e179ca9d6904..f7b44e5e65641c 100644 --- a/clang/docs/ReleaseNotes.rst +++ b/clang/docs/ReleaseNotes.rst @@ -48,6 +48,9 @@ C++ Specific Potentially Breaking Changes - Clang now diagnoses function/variable templates that shadow their own template parameters, e.g. ``template<class T> void T();``. This error can be disabled via `-Wno-strict-primary-template-shadow` for compatibility with previous versions of clang. +- Clang now emits errors for coroutine ``await_suspend`` functions whose return type is not + one of ``void``, ``bool``, or ``std::coroutine_handle``. 
+ ABI Changes in This Version --------------------------- - Fixed Microsoft name mangling of implicitly defined variables used for thread diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index fc727cef9cd835..796b3d9d5e1190 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -11707,7 +11707,7 @@ def err_coroutine_promise_new_requires_nothrow : Error< def note_coroutine_promise_call_implicitly_required : Note< "call to %0 implicitly required by coroutine function here">; def err_await_suspend_invalid_return_type : Error< - "return type of 'await_suspend' is required to be 'void' or 'bool' (have %0)" + "return type of 'await_suspend' is required to be 'void', 'bool', or 'std::coroutine_handle' (have %0)" >; def note_await_ready_no_bool_conversion : Note< "return type of 'await_ready' is required to be contextually convertible to 'bool'" diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp index 736632857efc36..2e81a83b62df51 100644 --- a/clang/lib/Sema/SemaCoroutine.cpp +++ b/clang/lib/Sema/SemaCoroutine.cpp @@ -137,12 +137,8 @@ static QualType lookupPromiseType(Sema &S, const FunctionDecl *FD, return PromiseType; } -/// Look up the std::coroutine_handle<PromiseType>. 
-static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType, - SourceLocation Loc) { - if (PromiseType.isNull()) - return QualType(); - +static ClassTemplateDecl *lookupCoroutineHandleTemplate(Sema &S, + SourceLocation Loc) { NamespaceDecl *CoroNamespace = S.getStdNamespace(); assert(CoroNamespace && "Should already be diagnosed"); @@ -151,18 +147,32 @@ static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType, if (!S.LookupQualifiedName(Result, CoroNamespace)) { S.Diag(Loc, diag::err_implied_coroutine_type_not_found) << "std::coroutine_handle"; - return QualType(); + return nullptr; } - ClassTemplateDecl *CoroHandle = Result.getAsSingle<ClassTemplateDecl>(); + auto *CoroHandle = Result.getAsSingle<ClassTemplateDecl>(); + if (!CoroHandle) { Result.suppressDiagnostics(); // We found something weird. Complain about the first thing we found. NamedDecl *Found = *Result.begin(); S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_handle); - return QualType(); + return nullptr; } + return CoroHandle; +} + +/// Look up the std::coroutine_handle<PromiseType>. +static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType, + SourceLocation Loc) { + if (PromiseType.isNull()) + return QualType(); + + ClassTemplateDecl *CoroHandle = lookupCoroutineHandleTemplate(S, Loc); + if (!CoroHandle) + return QualType(); + // Form template argument list for coroutine_handle<Promise>. TemplateArgumentListInfo Args(Loc, Loc); Args.addArgument(TemplateArgumentLoc( @@ -331,16 +341,12 @@ static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc, // coroutine. 
static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E, SourceLocation Loc) { - if (RetType->isReferenceType()) - return nullptr; + assert(!RetType->isReferenceType() && + "Should have diagnosed reference types."); Type const *T = RetType.getTypePtr(); if (!T->isClassType() && !T->isStructureType()) return nullptr; - // FIXME: Add convertability check to coroutine_handle<>. Possibly via - // EvaluateBinaryTypeTrait(BTT_IsConvertible, ...) which is at the moment - // a private function in SemaExprCXX.cpp - ExprResult AddressExpr = buildMemberCall(S, E, Loc, "address", std::nullopt); if (AddressExpr.isInvalid()) return nullptr; @@ -358,6 +364,30 @@ static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E, return S.MaybeCreateExprWithCleanups(JustAddress); } +static bool isSpecializationOfCoroutineHandle(Sema &S, QualType Ty, + SourceLocation Loc) { + auto *CoroutineHandleClassTemplateDecl = + lookupCoroutineHandleTemplate(S, Loc); + + if (!CoroutineHandleClassTemplateDecl) + return false; + + auto *RecordTy = Ty->getAs<RecordType>(); + if (!RecordTy) + return false; + + auto *D = RecordTy->getDecl(); + if (!D) + return false; + + auto *SpecializationDecl = dyn_cast<ClassTemplateSpecializationDecl>(D); + if (!SpecializationDecl) + return false; + + return CoroutineHandleClassTemplateDecl->getCanonicalDecl() == + SpecializationDecl->getSpecializedTemplate()->getCanonicalDecl(); +} + /// Build calls to await_ready, await_suspend, and await_resume for a co_await /// expression. 
/// The generated AST tries to clean up temporary objects as early as @@ -418,39 +448,60 @@ static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise, return Calls; } Expr *CoroHandle = CoroHandleRes.get(); - CallExpr *AwaitSuspend = cast_or_null<CallExpr>( - BuildSubExpr(ACT::ACT_Suspend, "await_suspend", CoroHandle)); + auto *AwaitSuspend = [&]() -> CallExpr * { + auto *SubExpr = BuildSubExpr(ACT::ACT_Suspend, "await_suspend", CoroHandle); + if (!SubExpr) + return nullptr; + if (auto *E = dyn_cast<CXXBindTemporaryExpr>(SubExpr)) { + // This happens when await_suspend return type is not trivially + // destructible. This doesn't happen for the permitted return types of + // such function. Diagnose it later. + return cast_or_null<CallExpr>(E->getSubExpr()); + } else { + return cast_or_null<CallExpr>(SubExpr); + } + }(); + if (!AwaitSuspend) return Calls; + if (!AwaitSuspend->getType()->isDependentType()) { + auto InvalidAwaitSuspendReturnType = [&](QualType RetType) { + // non-class prvalues always have cv-unqualified types + S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(), + diag::err_await_suspend_invalid_return_type) + << RetType; + S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) + << AwaitSuspend->getDirectCallee(); + Calls.IsInvalid = true; + }; + // [expr.await]p3 [...] // - await-suspend is the expression e.await_suspend(h), which shall be // a prvalue of type void, bool, or std::coroutine_handle<Z> for some // type Z. QualType RetType = AwaitSuspend->getCallReturnType(S.Context); - // Support for coroutine_handle returning await_suspend. 
- if (Expr *TailCallSuspend = - maybeTailCall(S, RetType, AwaitSuspend, Loc)) + if (RetType->isReferenceType()) { + InvalidAwaitSuspendReturnType(RetType); + } else if (RetType->isBooleanType() || RetType->isVoidType()) { + Calls.Results[ACT::ACT_Suspend] = + S.MaybeCreateExprWithCleanups(AwaitSuspend); + } else if (isSpecializationOfCoroutineHandle(S, RetType, Loc)) { + // Support for coroutine_handle returning await_suspend. + // // Note that we don't wrap the expression with ExprWithCleanups here // because that might interfere with tailcall contract (e.g. inserting // clean up instructions in-between tailcall and return). Instead // ExprWithCleanups is wrapped within maybeTailCall() prior to the resume // call. - Calls.Results[ACT::ACT_Suspend] = TailCallSuspend; - else { - // non-class prvalues always have cv-unqualified types - if (RetType->isReferenceType() || - (!RetType->isBooleanType() && !RetType->isVoidType())) { - S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(), - diag::err_await_suspend_invalid_return_type) - << RetType; - S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) - << AwaitSuspend->getDirectCallee(); - Calls.IsInvalid = true; - } else - Calls.Results[ACT::ACT_Suspend] = - S.MaybeCreateExprWithCleanups(AwaitSuspend); + Expr *TailCallSuspend = maybeTailCall(S, RetType, AwaitSuspend, Loc); + if (TailCallSuspend) + Calls.Results[ACT::ACT_Suspend] = TailCallSuspend; + else + InvalidAwaitSuspendReturnType(RetType); + } else { + InvalidAwaitSuspendReturnType(RetType); } } diff --git a/clang/test/SemaCXX/coroutines.cpp b/clang/test/SemaCXX/coroutines.cpp index 2292932583fff6..14c4a2a8d9b45e 100644 --- a/clang/test/SemaCXX/coroutines.cpp +++ b/clang/test/SemaCXX/coroutines.cpp @@ -1005,12 +1005,24 @@ coro<promise_no_return_func> no_return_value_or_return_void_3() { co_return 43; // expected-error {{no member named 'return_value'}} } -struct bad_await_suspend_return { +struct non_trivial_destruction_type { + 
~non_trivial_destruction_type(); +}; + +struct bad_await_suspend_return_1 { bool await_ready(); - // expected-error@+1 {{return type of 'await_suspend' is required to be 'void' or 'bool' (have 'char')}} + // expected-error@+1 {{return type of 'await_suspend' is required to be 'void', 'bool', or 'std::coroutine_handle' (have 'char')}} char await_suspend(std::coroutine_handle<>); void await_resume(); }; + +struct bad_await_suspend_return_2 { + bool await_ready(); + // expected-error@+1 {{return type of 'await_suspend' is required to be 'void', 'bool', or 'std::coroutine_handle' (have 'non_trivial_destruction_type')}} + non_trivial_destruction_type await_suspend(std::coroutine_handle<>); + void await_resume(); +}; + struct bad_await_ready_return { // expected-note@+1 {{return type of 'await_ready' is required to be contextually convertible to 'bool'}} void await_ready(); @@ -1028,8 +1040,8 @@ struct await_ready_explicit_bool { template <class SuspendTy> struct await_suspend_type_test { bool await_ready(); - // expected-error@+2 {{return type of 'await_suspend' is required to be 'void' or 'bool' (have 'bool &')}} - // expected-error@+1 {{return type of 'await_suspend' is required to be 'void' or 'bool' (have 'bool &&')}} + // expected-error@+2 {{return type of 'await_suspend' is required to be 'void', 'bool', or 'std::coroutine_handle' (have 'bool &')}} + // expected-error@+1 {{return type of 'await_suspend' is required to be 'void', 'bool', or 'std::coroutine_handle' (have 'bool &&')}} SuspendTy await_suspend(std::coroutine_handle<>); // cxx20_23-warning@-1 {{volatile-qualified return type 'const volatile bool' is deprecated}} void await_resume(); @@ -1042,8 +1054,12 @@ void test_bad_suspend() { co_await a; // expected-note {{call to 'await_ready' implicitly required by coroutine function here}} } { - bad_await_suspend_return b; - co_await b; // expected-note {{call to 'await_suspend' implicitly required by coroutine function here}} + bad_await_suspend_return_1 b1; + 
co_await b1; // expected-note {{call to 'await_suspend' implicitly required by coroutine function here}} + } + { + bad_await_suspend_return_2 b2; + co_await b2; // expected-note {{call to 'await_suspend' implicitly required by coroutine function here}} } { await_ready_explicit_bool c; _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits