llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: None (Andres-Salamanca)

<details>
<summary>Changes</summary>

This PR partially upstreams support for the `co_return` keyword. It does not yet
handle the case where a `co_return` returns a value produced by a `co_await`.
The change centers on `emitBodyAndFallthrough`, which emits the user-written body
(including its `co_await`s) and, depending on whether control can fall off the end
of the body, the fallthrough handler. Another difference from classic CodeGen:
classic CodeGen decides whether the body can fall through by checking whether
`GetInsertBlock()` still returns a block. In CIRGen, when a `co_return` is emitted,
we call `setCoreturn()` on the current lexical scope to record that it contains a
`co_return`.


---
Full diff: https://github.com/llvm/llvm-project/pull/171755.diff


5 Files Affected:

- (modified) clang/include/clang/CIR/MissingFeatures.h (+1) 
- (modified) clang/lib/CIR/CodeGen/CIRGenCoroutine.cpp (+91-2) 
- (modified) clang/lib/CIR/CodeGen/CIRGenFunction.h (+14) 
- (modified) clang/lib/CIR/CodeGen/CIRGenStmt.cpp (+2-1) 
- (modified) clang/test/CIR/CodeGen/coro-task.cpp (+62) 


``````````diff
diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h
index 9975ee0142d77..302e405587520 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -155,6 +155,7 @@ struct MissingFeatures {
   static bool coroOutsideFrameMD() { return false; }
   static bool coroCoReturn() { return false; }
   static bool coroCoYield() { return false; }
+  static bool unhandledException() { return false; }
 
   // Various handling of deferred processing in CIRGenModule.
   static bool cgmRelease() { return false; }
diff --git a/clang/lib/CIR/CodeGen/CIRGenCoroutine.cpp b/clang/lib/CIR/CodeGen/CIRGenCoroutine.cpp
index b4f185d0b2e3e..7d95023e01c3e 100644
--- a/clang/lib/CIR/CodeGen/CIRGenCoroutine.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenCoroutine.cpp
@@ -33,6 +33,15 @@ struct clang::CIRGen::CGCoroData {
   // Stores the result of __builtin_coro_begin call.
   mlir::Value coroBegin = nullptr;
 
+  // The cir.br emitted by the most recent co_return, right after its
+  // return_xxx promise call. The final suspend is inserted just before this
+  // branch to the return block. Null until a co_return has been emitted.
+  mlir::Operation *finalSuspendInsPoint = nullptr;
+
+  // How many co_return statements are in the coroutine. emitCoroutineBody
+  // uses it to decide whether the final suspend must be emitted.
+  unsigned coreturnCount = 0;
+
   // The promise type's 'unhandled_exception' handler, if it defines one.
   Stmt *exceptionHandler = nullptr;
 };
@@ -118,6 +127,29 @@ static void createCoroData(CIRGenFunction &cgf,
   curCoro.data->coroId = coroId;
 }
 
+static mlir::LogicalResult
+emitBodyAndFallthrough(CIRGenFunction &cgf, const CoroutineBodyStmt &s,
+                       Stmt *body,
+                       const CIRGenFunction::LexicalScope *currLexScope) {
+  if (cgf.emitStmt(body, /*useCurrentScope=*/true).failed())
+    return mlir::failure();
+
+  // Emit the fallthrough handler only if control can fall off the body.
+  // Classic CodeGen tests this via Builder.GetInsertBlock(), which depends
+  // on how unreachable code and landing pads are handled. CIRGen instead
+  // checks whether a co_return was already emitted while this lexical
+  // scope was current (see emitCoreturnStmt / setCoreturn).
+  // NOTE(review): a co_return in a nested scope presumably marks only that
+  // scope, so the handler would still be emitted -- please confirm.
+  const bool canFallthrough = !currLexScope->hasCoreturn();
+  if (canFallthrough)
+    if (Stmt *onFallthrough = s.getFallthroughHandler())
+      if (cgf.emitStmt(onFallthrough, /*useCurrentScope=*/true).failed())
+        return mlir::failure();
+
+  return mlir::success();
+}
+
 cir::CallOp CIRGenFunction::emitCoroIDBuiltinCall(mlir::Location loc,
                                                   mlir::Value nullPtr) {
   cir::IntType int32Ty = builder.getUInt32Ty();
@@ -267,11 +299,39 @@ CIRGenFunction::emitCoroutineBody(const CoroutineBodyStmt &s) {
                        /*isInit*/ true);
 
     assert(!cir::MissingFeatures::ehCleanupScope());
-    // FIXME(cir): EHStack.pushCleanup<CallCoroEnd>(EHCleanup);
+
     curCoro.data->currentAwaitKind = cir::AwaitKind::Init;
     if (emitStmt(s.getInitSuspendStmt(), /*useCurrentScope=*/true).failed())
       return mlir::failure();
-    assert(!cir::MissingFeatures::emitBodyAndFallthrough());
+
+    curCoro.data->currentAwaitKind = cir::AwaitKind::User;
+
+    // FIXME(cir): wrap emitBodyAndFallthrough with try/catch bits.
+    if (s.getExceptionHandler())
+      assert(!cir::MissingFeatures::unhandledException());
+    if (emitBodyAndFallthrough(*this, s, s.getBody(), curLexScope).failed())
+      return mlir::failure();
+
+    // Decide whether the final suspend is needed. Classic CodeGen uses
+    // `CanFallthrough || HasCoreturns`; after emitBodyAndFallthrough the
+    // scope flag only records that a co_return (user-written or the
+    // fallthrough handler) was emitted with this scope current, so it is
+    // not a "can fall through" test and it implies coreturnCount > 0. The
+    // condition thus reduces to "some co_return was emitted", which is what
+    // guarantees finalSuspendInsPoint has been set.
+    const bool scopeHasCoreturn = curLexScope->hasCoreturn();
+    const bool hasCoreturns = curCoro.data->coreturnCount > 0;
+    if (scopeHasCoreturn || hasCoreturns) {
+      curCoro.data->currentAwaitKind = cir::AwaitKind::Final;
+      assert(curCoro.data->finalSuspendInsPoint && "no co_return emitted");
+      {
+        mlir::OpBuilder::InsertionGuard guard(builder);
+        builder.setInsertionPoint(curCoro.data->finalSuspendInsPoint);
+        if (emitStmt(s.getFinalSuspendStmt(), /*useCurrentScope=*/true)
+                .failed())
+          return mlir::failure();
+      }
+    }
   }
   return mlir::success();
 }
@@ -425,3 +485,32 @@ RValue CIRGenFunction::emitCoawaitExpr(const CoawaitExpr &e,
   return emitSuspendExpr(*this, e, curCoro.data->currentAwaitKind, aggSlot,
                          ignoreResult);
 }
+
+mlir::LogicalResult CIRGenFunction::emitCoreturnStmt(const CoreturnStmt &s) {
+  ++curCoro.data->coreturnCount;
+  curLexScope->setCoreturn();
+
+  const Expr *rv = s.getOperand();
+  if (rv && rv->getType()->isVoidType() && !isa<InitListExpr>(rv)) {
+    // Make sure to evaluate the non initlist expression of a co_return
+    // with a void expression for side effects.
+    assert(!cir::MissingFeatures::ehCleanupScope());
+    emitIgnoredExpr(rv);
+  }
+
+  if (emitStmt(s.getPromiseCall(), /*useCurrentScope=*/true).failed())
+    return mlir::failure();
+  // Branch to the scope's return block (created on demand). Only the most
+  // recent co_return's branch is kept: emitCoroutineBody inserts the final
+  // suspend right before it. The return op itself comes from scope cleanup.
+  mlir::Location loc = getLoc(s.getSourceRange());
+  mlir::Block *retBlock = curLexScope->getOrCreateRetBlock(*this, loc);
+  curCoro.data->finalSuspendInsPoint =
+      cir::BrOp::create(builder, loc, retBlock);
+
+  // Start a new block so codegen can continue after the branch; nothing
+  // branches to it from here, so it will likely stay empty.
+  builder.createBlock(builder.getBlock()->getParent());
+
+  return mlir::success();
+}
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h
index 15322ee72a1b0..db9e7847363e3 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.h
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h
@@ -1072,6 +1072,12 @@ class CIRGenFunction : public CIRGenTypeCache {
     // Holds the actual value for ScopeKind::Try
     cir::TryOp tryOp = nullptr;
 
+    // In a coroutine body, the OnFallthrough sub-statement holds the handler
+    // (a CoreturnStmt) for control flowing off the end of the body. Track
+    // whether a co_return was emitted in this scope so that OnFallthrough
+    // can be skipped.
+    bool hasCoreturnStmt = false;
+
     // Only Regular is used at the moment. Support for other kinds will be
     // added as the relevant statements/expressions are upstreamed.
     enum Kind {
@@ -1119,6 +1125,12 @@ class CIRGenFunction : public CIRGenTypeCache {
       restore();
     }
 
+    // ---
+    // Coroutine tracking: was a co_return emitted in this scope?
+    // ---
+    bool hasCoreturn() const { return hasCoreturnStmt; }
+    void setCoreturn() { hasCoreturnStmt = true; } // See emitCoreturnStmt.
+
     // ---
     // Kind
     // ---
@@ -1473,6 +1485,8 @@ class CIRGenFunction : public CIRGenTypeCache {
 
   mlir::LogicalResult emitContinueStmt(const clang::ContinueStmt &s);
 
+  mlir::LogicalResult emitCoreturnStmt(const clang::CoreturnStmt &s);
+
   void emitCXXConstructExpr(const clang::CXXConstructExpr *e,
                             AggValueSlot dest);
 
diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
index f13e7cb32c71e..5d4de5dac6d4c 100644
--- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
@@ -163,6 +163,8 @@ mlir::LogicalResult CIRGenFunction::emitStmt(const Stmt *s,
     return emitCoroutineBody(cast<CoroutineBodyStmt>(*s));
   case Stmt::IndirectGotoStmtClass:
     return emitIndirectGotoStmt(cast<IndirectGotoStmt>(*s));
+  case Stmt::CoreturnStmtClass:
+    return emitCoreturnStmt(cast<CoreturnStmt>(*s));
   case Stmt::OpenACCComputeConstructClass:
     return emitOpenACCComputeConstruct(cast<OpenACCComputeConstruct>(*s));
   case Stmt::OpenACCLoopConstructClass:
@@ -203,7 +205,6 @@ mlir::LogicalResult CIRGenFunction::emitStmt(const Stmt *s,
   case Stmt::CaseStmtClass:
   case Stmt::SEHLeaveStmtClass:
   case Stmt::SYCLKernelCallStmtClass:
-  case Stmt::CoreturnStmtClass:
   case Stmt::OMPParallelDirectiveClass:
   case Stmt::OMPTaskwaitDirectiveClass:
   case Stmt::OMPTaskyieldDirectiveClass:
diff --git a/clang/test/CIR/CodeGen/coro-task.cpp b/clang/test/CIR/CodeGen/coro-task.cpp
index c6e21c993b64f..6cd494317f2d8 100644
--- a/clang/test/CIR/CodeGen/coro-task.cpp
+++ b/clang/test/CIR/CodeGen/coro-task.cpp
@@ -42,6 +42,15 @@ struct string {
   string(char const *s);
 };
 
+template<typename T>
+struct optional { // Declaration-only std::optional stand-in.
+  optional();
+  optional(const T&);
+  T &operator*() &;
+  T &&operator*() &&;
+  T &value() &;
+  T &&value() &&;
+};
 } // namespace std
 
 namespace folly {
@@ -94,6 +103,10 @@ struct Task<void> {
 
 // inline constexpr blocking_wait_fn blocking_wait{};
 // static constexpr blocking_wait_fn const& blockingWait = blocking_wait;
+template <typename T>
+T blockingWait(Task<T>&& awaitable) {
+  return T(); // Stub: the task is never awaited.
+}
 
 struct co_invoke_fn {
   template <typename F, typename... A>
@@ -218,6 +231,25 @@ VoidTask silly_task() {
 // - The final suspend co_await
 // - Return
 
+// The actual user written co_await
+// CIR: cir.scope {
+// CIR:   cir.await(user, ready : {
+// CIR:   }, suspend : {
+// CIR:   }, resume : {
+// CIR:   },)
+// CIR: }
+
+// The promise call
+// CIR: cir.call @_ZN5folly4coro4TaskIvE12promise_type11return_voidEv(%[[VoidPromisseAddr]])
+
+// The final suspend co_await
+// CIR: cir.scope {
+// CIR:   cir.await(final, ready : {
+// CIR:   }, suspend : {
+// CIR:   }, resume : {
+// CIR:   },)
+// CIR: }
+
 folly::coro::Task<int> byRef(const std::string& s) {
   co_return s.size();
 }
@@ -260,3 +292,33 @@ folly::coro::Task<int> byRef(const std::string& s) {
 // CIR:         cir.yield
 // CIR:       },)
 // CIR:     }
+
+// can't fall through
+// CIR-NOT:   cir.await(user
+
+// The final suspend co_await
+// CIR: cir.scope {
+// CIR:   cir.await(final, ready : {
+// CIR:   }, suspend : {
+// CIR:   }, resume : {
+// CIR:   },)
+// CIR: }
+
+folly::coro::Task<void> silly_coro() {
+  std::optional<folly::coro::Task<int>> task;
+  {
+    std::string s = "yolo";
+    task = byRef(s);
+  }
+  folly::coro::blockingWait(std::move(task.value()));
+  co_return; // Explicit co_return: the fallthrough handler must be skipped.
+}
+
+// Make sure we properly handle the OnFallthrough coroutine body sub-stmt
+// and check that multiple co_returns are not emitted.
+
+// CIR: cir.func coroutine {{.*}} @_Z10silly_corov() {{.*}} ![[VoidTask]]
+// CIR: cir.await(init, ready : {
+// CIR: cir.call @_ZN5folly4coro4TaskIvE12promise_type11return_voidEv
+// CIR-NOT: cir.call @_ZN5folly4coro4TaskIvE12promise_type11return_voidEv
+// CIR: cir.await(final, ready : {

``````````

</details>


https://github.com/llvm/llvm-project/pull/171755
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to