ChuanqiXu updated this revision to Diff 549795.

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D157833/new/

https://reviews.llvm.org/D157833

Files:
  clang/docs/ReleaseNotes.rst
  clang/lib/CodeGen/CGCall.cpp
  clang/lib/CodeGen/CGCoroutine.cpp
  clang/lib/CodeGen/CodeGenFunction.h
  clang/test/CodeGenCoroutines/coro-awaiter-noinline-suspend.cpp
  clang/test/CodeGenCoroutines/pr56301.cpp

Index: clang/test/CodeGenCoroutines/pr56301.cpp
===================================================================
--- /dev/null
+++ clang/test/CodeGenCoroutines/pr56301.cpp
@@ -0,0 +1,85 @@
+// An end-to-end (-O3) regression test for issue #56301.
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -emit-llvm -o - %s -O3 | \
+// RUN:     FileCheck %s
+
+#include "Inputs/coroutine.h"
+
+struct SomeAwaitable {
+  // Resume the supplied handle once the awaitable becomes ready,
+  // returning a handle that should be resumed now for the sake of symmetric transfer.
+  // If the awaitable is already ready, return an empty handle without doing anything.
+  //
+  // Defined in another translation unit. Note that this may contain
+  // code that synchronizes with another thread.
+  std::coroutine_handle<> Register(std::coroutine_handle<>);
+};
+
+// Defined in another translation unit.
+void DidntSuspend();
+
+struct Awaiter {
+  SomeAwaitable&& awaitable;
+  bool suspended;
+
+  bool await_ready() { return false; }
+
+  std::coroutine_handle<> await_suspend(const std::coroutine_handle<> h) {
+    // Assume we will suspend unless proven otherwise below. We must do
+    // this *before* calling Register, since we may be destroyed by another
+    // thread asynchronously as soon as we have registered.
+    suspended = true;
+
+    // Attempt to hand off responsibility for resuming/destroying the coroutine.
+    const auto to_resume = awaitable.Register(h);
+
+    if (!to_resume) {
+      // The awaitable is already ready. In this case we know that Register didn't
+      // hand off responsibility for the coroutine. So record the fact that we didn't
+      // actually suspend, and tell the compiler to resume us inline.
+      suspended = false;
+      return h;
+    }
+
+    // Resume whatever Register wants us to resume.
+    return to_resume;
+  }
+
+  void await_resume() {
+    // If we didn't suspend, make note of that fact.
+    if (!suspended) {
+      DidntSuspend();
+    }
+  }
+};
+
+struct MyTask{
+  struct promise_type {
+    MyTask get_return_object() { return {}; }
+    std::suspend_never initial_suspend() { return {}; }
+    std::suspend_always final_suspend() noexcept { return {}; }
+    void unhandled_exception();
+
+    Awaiter await_transform(SomeAwaitable&& awaitable) {
+      return Awaiter{static_cast<SomeAwaitable&&>(awaitable)};
+    }
+  };
+};
+
+MyTask FooBar() {
+  co_await SomeAwaitable();
+}
+
+// CHECK-LABEL: @_Z6FooBarv
+// CHECK: %[[to_resume:.*]] = {{.*}}call ptr @_ZN13SomeAwaitable8RegisterESt16coroutine_handleIvE
+// CHECK-NEXT: %[[to_bool:.*]] = icmp eq ptr %[[to_resume]], null
+// CHECK-NEXT: br i1 %[[to_bool]], label %[[then:.*]], label %[[else:.*]]
+
+// CHECK: [[then]]:
+// We only access the coroutine frame conditionally, just as the source does.
+// CHECK:   store i8 0,
+// CHECK-NEXT: br label %[[else]]
+
+// CHECK: [[else]]:
+// No more accesses to the coroutine frame until the coroutine suspends.
+// CHECK-NOT: store
+// CHECK: }
Index: clang/test/CodeGenCoroutines/coro-awaiter-noinline-suspend.cpp
===================================================================
--- /dev/null
+++ clang/test/CodeGenCoroutines/coro-awaiter-noinline-suspend.cpp
@@ -0,0 +1,207 @@
+// Tests that we can mark await-suspend as noinline correctly.
+//
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -emit-llvm -o - %s \
+// RUN:     -disable-llvm-passes | FileCheck %s
+
+#include "Inputs/coroutine.h"
+
+struct Task {
+  struct promise_type {
+    struct FinalAwaiter {
+      bool await_ready() const noexcept { return false; }
+      template <typename PromiseType>
+      std::coroutine_handle<> await_suspend(std::coroutine_handle<PromiseType> h) noexcept {
+        return h.promise().continuation;
+      }
+      void await_resume() noexcept {}
+    };
+
+    Task get_return_object() noexcept {
+      return std::coroutine_handle<promise_type>::from_promise(*this);
+    }
+
+    std::suspend_always initial_suspend() noexcept { return {}; }
+    FinalAwaiter final_suspend() noexcept { return {}; }
+    void unhandled_exception() noexcept {}
+    void return_void() noexcept {}
+
+    std::coroutine_handle<> continuation;
+  };
+
+  Task(std::coroutine_handle<promise_type> handle);
+  ~Task();
+
+private:
+  std::coroutine_handle<promise_type> handle;
+};
+
+struct StatefulAwaiter {
+    int value;
+    bool await_ready() const noexcept { return false; }
+    template <typename PromiseType>
+    void await_suspend(std::coroutine_handle<PromiseType> h) noexcept {}
+    void await_resume() noexcept {}
+};
+
+typedef std::suspend_always NoStateAwaiter;
+using AnotherStatefulAwaiter = StatefulAwaiter;
+
+template <class T>
+struct TemplatedAwaiter {
+    T value;
+    bool await_ready() const noexcept { return false; }
+    template <typename PromiseType>
+    void await_suspend(std::coroutine_handle<PromiseType> h) noexcept {}
+    void await_resume() noexcept {}
+};
+
+
+class Awaitable {};
+StatefulAwaiter operator co_await(Awaitable) {
+  return StatefulAwaiter{};
+}
+
+StatefulAwaiter GlobalAwaiter;
+class Awaitable2 {};
+StatefulAwaiter& operator co_await(Awaitable2) {
+  return GlobalAwaiter;
+}
+
+Task testing() {
+    co_await std::suspend_always{};
+    co_await StatefulAwaiter{};
+    co_await AnotherStatefulAwaiter{};
+
+    // Test lvalue case.
+    StatefulAwaiter awaiter;
+    co_await awaiter;
+
+    // An explicit call to await_suspend is not part of a suspend point.
+    awaiter.await_suspend(std::coroutine_handle<void>::from_address(nullptr));
+
+    co_await TemplatedAwaiter<int>{};
+    TemplatedAwaiter<int> TemplatedAwaiterInstance;
+    co_await TemplatedAwaiterInstance;
+
+    co_await Awaitable{};
+    co_await Awaitable2{};
+}
+
+// CHECK-LABEL: @_Z7testingv
+
+// Check `co_await __promise__.initial_suspend();`. Since it returns std::suspend_always,
+// which is an empty class, we shouldn't mark the await_suspend call noinline.
+// CHECK: call token @llvm.coro.save
+// CHECK: call void @_ZNSt14suspend_always13await_suspendESt16coroutine_handleIvE{{.*}}#[[NORMAL_ATTR:[0-9]+]]
+
+// Check the `co_await std::suspend_always{};` expression. We shouldn't mark the call
+// noinline since std::suspend_always is an empty class.
+// CHECK: call token @llvm.coro.save
+// CHECK: call void @_ZNSt14suspend_always13await_suspendESt16coroutine_handleIvE{{.*}}#[[NORMAL_ATTR]]
+
+// Check `co_await StatefulAwaiter{};`. We need to mark the call noinline since
+// the awaiter is not empty.
+// CHECK: call token @llvm.coro.save
+// CHECK: call void @_ZN15StatefulAwaiter13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR:[0-9]+]]
+
+// Check `co_await AnotherStatefulAwaiter{};` to make sure that we can handle TypedefTypes.
+// CHECK: call token @llvm.coro.save
+// CHECK: call void @_ZN15StatefulAwaiter13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR]]
+
+// Check `co_await awaiter;` to make sure we can handle lvalue cases.
+// CHECK: call token @llvm.coro.save
+// CHECK: call void @_ZN15StatefulAwaiter13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR]]
+
+// Check `awaiter.await_suspend(...)` to make sure the explicit call to await_suspend won't be marked noinline.
+// CHECK: call void @_ZN15StatefulAwaiter13await_suspendIvEEvSt16coroutine_handleIT_E{{.*}}#[[NORMAL_ATTR]]
+
+// Check `co_await TemplatedAwaiter<int>{};` to make sure we can handle specialized template
+// type.
+// CHECK: call token @llvm.coro.save
+// CHECK: call void @_ZN16TemplatedAwaiterIiE13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR]]
+
+// Check `co_await TemplatedAwaiterInstance;` to make sure we can handle the lvalue from
+// specialized template type.
+// CHECK: call token @llvm.coro.save
+// CHECK: call void @_ZN16TemplatedAwaiterIiE13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR]]
+
+// Check `co_await Awaitable{};` to make sure we can handle awaiter returned by
+// `operator co_await`.
+// CHECK: call token @llvm.coro.save
+// CHECK: call void @_ZN15StatefulAwaiter13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR]]
+
+// Check `co_await Awaitable2{};` to make sure we can handle awaiter returned by
+// `operator co_await` which returns a reference.
+// CHECK: call token @llvm.coro.save
+// CHECK: call void @_ZN15StatefulAwaiter13await_suspendIN4Task12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NOINLINE_ATTR]]
+
+// Check `co_await __promise__.final_suspend();`. We don't mark the call noinline since
+// FinalAwaiter is empty.
+// CHECK: call token @llvm.coro.save
+// CHECK: call ptr @_ZN4Task12promise_type12FinalAwaiter13await_suspendIS0_EESt16coroutine_handleIvES3_IT_E{{.*}}#[[NORMAL_ATTR]]
+
+struct AwaitTransformTask {
+  struct promise_type {
+    struct FinalAwaiter {
+      bool await_ready() const noexcept { return false; }
+      template <typename PromiseType>
+      std::coroutine_handle<> await_suspend(std::coroutine_handle<PromiseType> h) noexcept {
+        return h.promise().continuation;
+      }
+      void await_resume() noexcept {}
+    };
+
+    AwaitTransformTask get_return_object() noexcept {
+      return std::coroutine_handle<promise_type>::from_promise(*this);
+    }
+
+    std::suspend_always initial_suspend() noexcept { return {}; }
+    FinalAwaiter final_suspend() noexcept { return {}; }
+    void unhandled_exception() noexcept {}
+    void return_void() noexcept {}
+
+    template <typename Awaitable>
+    auto await_transform(Awaitable &&awaitable) {
+      return awaitable;
+    }
+
+    std::coroutine_handle<> continuation;
+  };
+
+  AwaitTransformTask(std::coroutine_handle<promise_type> handle);
+  ~AwaitTransformTask();
+
+private:
+  std::coroutine_handle<promise_type> handle;
+};
+
+struct awaitableWithGetAwaiter {
+  bool await_ready() const noexcept { return false; }
+  template <typename PromiseType>
+  void await_suspend(std::coroutine_handle<PromiseType> h) noexcept {}
+  void await_resume() noexcept {}
+};
+
+AwaitTransformTask testingWithAwaitTransform() {
+  co_await awaitableWithGetAwaiter{};
+}
+
+// CHECK-LABEL: @_Z25testingWithAwaitTransformv
+
+// Init suspend
+// CHECK: call token @llvm.coro.save
+// CHECK-NOT: call void @llvm.coro.opt.blocker(
+// CHECK: call void @_ZNSt14suspend_always13await_suspendESt16coroutine_handleIvE{{.*}}#[[NORMAL_ATTR]]
+
+// Check `co_await awaitableWithGetAwaiter{};`.
+// CHECK: call token @llvm.coro.save
+// CHECK-NOT: call void @llvm.coro.opt.blocker(
+// CHECK: call void @_ZN23awaitableWithGetAwaiter13await_suspendIN18AwaitTransformTask12promise_typeEEEvSt16coroutine_handleIT_E{{.*}}#[[NORMAL_ATTR]]
+
+// Final suspend
+// CHECK: call token @llvm.coro.save
+// CHECK-NOT: call void @llvm.coro.opt.blocker(
+// CHECK: call ptr @_ZN18AwaitTransformTask12promise_type12FinalAwaiter13await_suspendIS0_EESt16coroutine_handleIvES3_IT_E{{.*}}#[[NORMAL_ATTR]]
+
+// CHECK-NOT: attributes #[[NORMAL_ATTR]] = {{.*}}noinline
+// CHECK: attributes #[[NOINLINE_ATTR]] = {{.*}}noinline
Index: clang/lib/CodeGen/CodeGenFunction.h
===================================================================
--- clang/lib/CodeGen/CodeGenFunction.h
+++ clang/lib/CodeGen/CodeGenFunction.h
@@ -334,6 +334,7 @@
   struct CGCoroInfo {
     std::unique_ptr<CGCoroData> Data;
     bool InSuspendBlock = false;
+    bool MaySuspendLeak = false; // await_suspend may leak the handle.
     CGCoroInfo();
     ~CGCoroInfo();
   };
@@ -347,6 +348,10 @@
     return isCoroutine() && CurCoro.InSuspendBlock;
   }
 
+  bool maySuspendLeakCoroutineHandle() const {
+    return CurCoro.MaySuspendLeak && isCoroutine();
+  }
+
   /// CurGD - The GlobalDecl for the current function being compiled.
   GlobalDecl CurGD;
 
Index: clang/lib/CodeGen/CGCoroutine.cpp
===================================================================
--- clang/lib/CodeGen/CGCoroutine.cpp
+++ clang/lib/CodeGen/CGCoroutine.cpp
@@ -12,9 +12,10 @@
 
 #include "CGCleanup.h"
 #include "CodeGenFunction.h"
-#include "llvm/ADT/ScopeExit.h"
 #include "clang/AST/StmtCXX.h"
 #include "clang/AST/StmtVisitor.h"
+#include "clang/AST/TypeVisitor.h"
+#include "llvm/ADT/ScopeExit.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -139,6 +140,164 @@
   return true;
 }
 
+namespace {
+// Finds the awaiter's class declaration by looking through the type sugar
+// (AutoType, ElaboratedType, references, ...) of the common expression.
+class AwaiterTypeFinder : public TypeVisitor<AwaiterTypeFinder> {
+  CXXRecordDecl *Result = nullptr;
+
+public:
+  using Inherited = TypeVisitor<AwaiterTypeFinder>;
+
+  void Visit(const CoroutineSuspendExpr &S) {
+    Visit(S.getCommonExpr()->getType());
+  }
+
+  /// Returns true only if the awaiter class was found and is an empty class.
+  bool IsRecordEmpty() {
+    assert(Result && "Why can't we find the record type from the common "
+                     "expression of a coroutine suspend expression? "
+                     "Maybe we missed some types or the Sema get something "
+                     "incorrect");
+
+    // Without assertions, treat an unknown awaiter as non-empty. That is
+    // always correct; it only costs a missed inlining opportunity.
+    if (!Result)
+      return false;
+
+    // Use isEmpty() rather than field_empty(): an awaiter without fields of
+    // its own may still carry state in a non-empty base class (or a vptr),
+    // which await_suspend could access after leaking the coroutine handle.
+    return Result->isEmpty();
+  }
+
+  // Visitor callbacks; they are public so that Inherited can dispatch to them.
+  void Visit(QualType Type) { Visit(Type.getTypePtr()); }
+
+  void Visit(const Type *T) { Inherited::Visit(T); }
+
+  void VisitDeducedType(const DeducedType *T) { Visit(T->getDeducedType()); }
+
+  void VisitTypedefType(const TypedefType *T) {
+    Visit(T->getDecl()->getUnderlyingType());
+  }
+
+  void VisitElaboratedType(const ElaboratedType *T) {
+    Visit(T->getNamedType());
+  }
+
+  void VisitReferenceType(const ReferenceType *T) {
+    Visit(T->getPointeeType());
+  }
+
+  void VisitTemplateSpecializationType(const TemplateSpecializationType *T) {
+    // In the case the type is sugared, we can only see InjectedClassNameType,
+    // which doesn't contain the definition information we need.
+    if (T->desugar().getTypePtr() != T) {
+      Visit(T->desugar().getTypePtr());
+      return;
+    }
+
+    TemplateName Name = T->getTemplateName();
+    TemplateDecl *TD = Name.getAsTemplateDecl();
+
+    if (!TD)
+      return;
+
+    if (auto *TypedD = dyn_cast<TypeDecl>(TD->getTemplatedDecl()))
+      Visit(TypedD->getTypeForDecl());
+  }
+
+  void VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType *T) {
+    Visit(T->getReplacementType());
+  }
+
+  void VisitInjectedClassNameType(const InjectedClassNameType *T) {
+    VisitCXXRecordDecl(T->getDecl());
+  }
+
+  void VisitCXXRecordDecl(CXXRecordDecl *Candidate) {
+    assert(Candidate);
+
+#ifndef NDEBUG
+    // Double check that the type we found is an awaiter class type. This is
+    // only done with assertions enabled: Sema should already have diagnosed
+    // such cases, so here we merely verify our assumption.
+
+    auto HasMember = [](CXXRecordDecl *Candidate, llvm::StringRef Name,
+                        auto HasMember) {
+      Candidate = Candidate->getDefinition();
+      if (!Candidate)
+        return false;
+
+      ASTContext &Context = Candidate->getASTContext();
+
+      auto IdenIter = Context.Idents.find(Name);
+      if (IdenIter == Context.Idents.end())
+        return false;
+
+      if (!Candidate->lookup(DeclarationName(IdenIter->second)).empty())
+        return true;
+
+      return llvm::any_of(
+          Candidate->bases(), [Name, &HasMember](CXXBaseSpecifier &Specifier) {
+            auto *RD = cast<CXXRecordDecl>(
+                Specifier.getType()->getAs<RecordType>()->getDecl());
+            return HasMember(RD, Name, HasMember);
+          });
+    };
+
+    bool FoundAwaitReady = HasMember(Candidate, "await_ready", HasMember);
+    bool FoundAwaitSuspend = HasMember(Candidate, "await_suspend", HasMember);
+    bool FoundAwaitResume = HasMember(Candidate, "await_resume", HasMember);
+
+    assert(FoundAwaitReady && FoundAwaitSuspend && FoundAwaitResume);
+#endif
+    Result = Candidate;
+  }
+
+  void VisitRecordType(const RecordType *RT) {
+    assert(isa<CXXRecordDecl>(RT->getDecl()));
+    VisitCXXRecordDecl(cast<CXXRecordDecl>(RT->getDecl()));
+  }
+
+  // Any other type sugar (UsingType, DecltypeType, ParenType, ...): look
+  // through it to the underlying class, if there is one.
+  void VisitType(const Type *T) {
+    if (CXXRecordDecl *RD = T->getAsCXXRecordDecl())
+      VisitCXXRecordDecl(RD);
+  }
+};
+} // namespace
+
+/// Return true when the await-suspend
+/// (`awaiter.await_suspend(std::coroutine_handle)` expression) may leak the
+/// coroutine handle. Return false only when the await-suspend is known not to
+/// leak the coroutine handle.
+///
+/// Returning true is always safe; returning false lets the await_suspend call
+/// be inlined early, which gives the optimizer more opportunities.
+///
+/// The middle end does not understand how local variables relate to the
+/// coroutine handle (the locals may live in the coroutine frame) until the
+/// CoroSplit pass, yet many optimizations run before CoroSplit. This is
+/// manageable because C++ only lets programmers access the coroutine handle
+/// inside await_suspend.
+///
+/// See https://github.com/llvm/llvm-project/issues/56301 and
+/// https://reviews.llvm.org/D157070 for the example and the full discussion.
+static bool MaySuspendLeak(CoroutineSuspendExpr const &S) {
+  AwaiterTypeFinder Finder;
+  Finder.Visit(S);
+  // If the awaiter type is empty, treat the suspend as not leaking the
+  // coroutine handle.
+  //
+  // TODO: We can improve this by looking into the implementation of
+  // await-suspend and checking whether the coroutine handle is passed to
+  // foreign functions.
+  return !Finder.IsRecordEmpty();
+}
+
 // Emit suspend expression which roughly looks like:
 //
 //   auto && x = CommonExpr();
@@ -199,8 +358,11 @@
   auto *SaveCall = Builder.CreateCall(CoroSave, {NullPtr});
 
   CGF.CurCoro.InSuspendBlock = true;
+  CGF.CurCoro.MaySuspendLeak = MaySuspendLeak(S);
   auto *SuspendRet = CGF.EmitScalarExpr(S.getSuspendExpr());
   CGF.CurCoro.InSuspendBlock = false;
+  CGF.CurCoro.MaySuspendLeak = false;
+
   if (SuspendRet != nullptr && SuspendRet->getType()->isIntegerTy(1)) {
     // Veto suspension if requested by bool returning await_suspend.
     BasicBlock *RealSuspendBlock =
Index: clang/lib/CodeGen/CGCall.cpp
===================================================================
--- clang/lib/CodeGen/CGCall.cpp
+++ clang/lib/CodeGen/CGCall.cpp
@@ -5484,6 +5484,26 @@
         Attrs.addFnAttribute(getLLVMContext(), llvm::Attribute::AlwaysInline);
   }
 
+  // When we're emitting suspend block for C++20 coroutines, we need to be sure
+  // that the call to the `await_suspend()` may not get inlined until the
+  // coroutine got splitted in case the `await_suspend` may leak the coroutine
+  // handle.
+  //
+  // This is necessary since the standards specifies that the coroutine is
+  // considered to be suspended after we enter the await_suspend block. So that
+  // we need to make sure we don't update the coroutine handle during the
+  // execution of the await_suspend. To achieve this, we need to prevent the
+  // await_suspend get inlined before CoroSplit pass.
+  //
+  // We can omit the `NoInline` attribute in case we are sure the await_suspend
+  // call won't leak the coroutine handle so that the middle end can get more
+  // optimization opportunities.
+  //
+  // TODO: We should try to remove the `NoInline` attribute after CoroSplit
+  // pass.
+  if (inSuspendBlock() && maySuspendLeakCoroutineHandle())
+    Attrs = Attrs.addFnAttribute(getLLVMContext(), llvm::Attribute::NoInline);
+
   // Disable inlining inside SEH __try blocks.
   if (isSEHTryScope()) {
     Attrs = Attrs.addFnAttribute(getLLVMContext(), llvm::Attribute::NoInline);
Index: clang/docs/ReleaseNotes.rst
===================================================================
--- clang/docs/ReleaseNotes.rst
+++ clang/docs/ReleaseNotes.rst
@@ -138,6 +138,10 @@
   class, which can result in miscompiles in some cases.
 - Fix crash on use of a variadic overloaded operator.
   (`#42535 <https://github.com/llvm/llvm-project/issues/42535>_`)
+- Fixed an issue where, after ``await_suspend`` leaked the coroutine handle, a
+  conditional access to members of the awaiter could be incorrectly converted
+  into an unconditional access.
+  (`#56301 <https://github.com/llvm/llvm-project/issues/56301>`_)
 
 Bug Fixes to Compiler Builtins
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to