GorNishanov created this revision.
Herald added a subscriber: mehdi_amini.

Sema::CheckCompletedCoroutineBody was growing unwieldy with building all of the 
substatements. Also, constructors for CoroutineBodyStmt had way too many 
parameters.

Instead, CoroutineBodyStmt now defines a CtorArgs structure with all of the
required construction parameters.
CheckCompletedCoroutineBody delegates construction of the individual
substatements to short functions, one per substatement.

Also, added a drive-by fix of initializing CoroutinePromise to nullptr in 
ScopeInfo.h.
Also addressed the FIXME that wanted to tail-allocate extra room at the end of
the CoroutineBodyStmt to hold parameter move expressions. (The comment was
longer than the code that implemented tail allocation.)


https://reviews.llvm.org/D28835

Files:
  include/clang/AST/StmtCXX.h
  include/clang/Sema/ScopeInfo.h
  lib/AST/StmtCXX.cpp
  lib/Sema/SemaCoroutine.cpp
  test/SemaCXX/coroutines.cpp

Index: test/SemaCXX/coroutines.cpp
===================================================================
--- test/SemaCXX/coroutines.cpp
+++ test/SemaCXX/coroutines.cpp
@@ -1,4 +1,4 @@
-// RUN: %clang_cc1 -std=c++14 -fcoroutines-ts -verify %s
+// RUN: %clang_cc1 -std=c++14 -fcoroutines-ts -verify %s -fcxx-exceptions
 
 void no_coroutine_traits_bad_arg_await() {
   co_await a; // expected-error {{include <experimental/coroutine>}}
Index: lib/Sema/SemaCoroutine.cpp
===================================================================
--- lib/Sema/SemaCoroutine.cpp
+++ lib/Sema/SemaCoroutine.cpp
@@ -487,7 +487,7 @@
 static bool buildAllocationAndDeallocation(Sema &S, SourceLocation Loc,
                                            FunctionScopeInfo *Fn,
                                            Expr *&Allocation,
-                                           Stmt *&Deallocation) {
+                                           Expr *&Deallocation) {
   TypeSourceInfo *TInfo = Fn->CoroutinePromise->getTypeSourceInfo();
   QualType PromiseType = TInfo->getType();
   if (PromiseType->isDependentType())
@@ -564,6 +564,48 @@
   return true;
 }
 
+namespace {
+class SubStmtBuilder : public CoroutineBodyStmt::CtorArgs {
+  Sema &S;
+  FunctionDecl &FD;
+  FunctionScopeInfo &Fn;
+  bool IsValid;
+  SourceLocation Loc;
+  QualType RetType;
+  SmallVector<Stmt *, 4> ParamMovesVector;
+  const bool IsPromiseDependentType;
+  CXXRecordDecl *PromiseRecordDecl;
+
+public:
+  SubStmtBuilder(Sema &S, FunctionDecl &FD, FunctionScopeInfo &Fn, Stmt *Body)
+      : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()),
+        IsPromiseDependentType(
+            !Fn.CoroutinePromise ||
+            Fn.CoroutinePromise->getType()->isDependentType()) {
+    this->Body = Body;
+    if (!IsPromiseDependentType) {
+      PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl();
+      assert(PromiseRecordDecl && "Type should have already been checked");
+    }
+    this->IsValid = makePromiseStmt() && makeInitialSuspend() &&
+                    makeFinalSuspend() && makeOnException() &&
+                    makeOnFallthrough() && makeNewAndDeleteExpr() &&
+                    makeReturnObject() && makeParamMoves();
+  }
+
+  bool isInvalid() const { return !this->IsValid; }  
+
+  bool makePromiseStmt();
+  bool makeInitialSuspend();
+  bool makeFinalSuspend();
+  bool makeNewAndDeleteExpr();
+  bool makeOnFallthrough();
+  bool makeOnException();
+  bool makeReturnObject();
+  bool makeParamMoves();
+};
+}
+
 void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
   FunctionScopeInfo *Fn = getCurFunction();
   assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine");
@@ -577,120 +619,159 @@
         << (isa<CoawaitExpr>(First) ? 0 :
             isa<CoyieldExpr>(First) ? 1 : 2);
   }
+  SubStmtBuilder Builder(*this, *FD, *Fn, Body);
+  if (Builder.isInvalid())
+    return FD->setInvalidDecl();
 
-  SourceLocation Loc = FD->getLocation();
+  // Build body for the coroutine wrapper statement.
+  Body = CoroutineBodyStmt::Create(Context, Builder);
+}
 
+bool SubStmtBuilder::makePromiseStmt() {
   // Form a declaration statement for the promise declaration, so that AST
   // visitors can more easily find it.
   StmtResult PromiseStmt =
-      ActOnDeclStmt(ConvertDeclToDeclGroup(Fn->CoroutinePromise), Loc, Loc);
+      S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc);
   if (PromiseStmt.isInvalid())
-    return FD->setInvalidDecl();
+    return false;
+
+  this->Promise = PromiseStmt.get();
+  return true;
+}
 
+bool SubStmtBuilder::makeInitialSuspend() {
   // Form and check implicit 'co_await p.initial_suspend();' statement.
   ExprResult InitialSuspend =
-      buildPromiseCall(*this, Fn, Loc, "initial_suspend", None);
+      buildPromiseCall(S, &Fn, Loc, "initial_suspend", None);
   // FIXME: Support operator co_await here.
   if (!InitialSuspend.isInvalid())
-    InitialSuspend = BuildCoawaitExpr(Loc, InitialSuspend.get());
-  InitialSuspend = ActOnFinishFullExpr(InitialSuspend.get());
+    InitialSuspend = S.BuildCoawaitExpr(Loc, InitialSuspend.get());
+  InitialSuspend = S.ActOnFinishFullExpr(InitialSuspend.get());
   if (InitialSuspend.isInvalid())
-    return FD->setInvalidDecl();
+    return false;
+
+  this->InitialSuspend = InitialSuspend.get();
+  return true;
+}
 
+bool SubStmtBuilder::makeFinalSuspend() {
   // Form and check implicit 'co_await p.final_suspend();' statement.
   ExprResult FinalSuspend =
-      buildPromiseCall(*this, Fn, Loc, "final_suspend", None);
+      buildPromiseCall(S, &Fn, Loc, "final_suspend", None);
   // FIXME: Support operator co_await here.
   if (!FinalSuspend.isInvalid())
-    FinalSuspend = BuildCoawaitExpr(Loc, FinalSuspend.get());
-  FinalSuspend = ActOnFinishFullExpr(FinalSuspend.get());
+    FinalSuspend = S.BuildCoawaitExpr(Loc, FinalSuspend.get());
+  FinalSuspend = S.ActOnFinishFullExpr(FinalSuspend.get());
   if (FinalSuspend.isInvalid())
-    return FD->setInvalidDecl();
+    return false;
 
+  this->FinalSuspend = FinalSuspend.get();
+  return true;
+}
+
+bool SubStmtBuilder::makeNewAndDeleteExpr() {
   // Form and check allocation and deallocation calls.
-  Expr *Allocation = nullptr;
-  Stmt *Deallocation = nullptr;
-  if (!buildAllocationAndDeallocation(*this, Loc, Fn, Allocation, Deallocation))
-    return FD->setInvalidDecl();
+  return buildAllocationAndDeallocation(S, Loc, &Fn, this->Allocate,
+                                        this->Deallocate);
+}
+
+bool SubStmtBuilder::makeOnFallthrough() {
+  if (!PromiseRecordDecl)
+    return true;
+
+  // [dcl.fct.def.coroutine]/4
+  // The unqualified-ids 'return_void' and 'return_value' are looked up in
+  // the scope of class P. If both are found, the program is ill-formed.
+  DeclarationName RVoidDN = S.PP.getIdentifierInfo("return_void");
+  LookupResult RVoidResult(S, RVoidDN, Loc, Sema::LookupMemberName);
+  const bool HasRVoid = S.LookupQualifiedName(RVoidResult, PromiseRecordDecl);
+
+  DeclarationName RValueDN = S.PP.getIdentifierInfo("return_value");
+  LookupResult RValueResult(S, RValueDN, Loc, Sema::LookupMemberName);
+  const bool HasRValue = S.LookupQualifiedName(RValueResult, PromiseRecordDecl);
 
-  // control flowing off the end of the coroutine.
-  // Also try to form 'p.set_exception(std::current_exception());' to handle
-  // uncaught exceptions.
-  ExprResult SetException;
   StmtResult Fallthrough;
-  if (Fn->CoroutinePromise &&
-      !Fn->CoroutinePromise->getType()->isDependentType()) {
-    CXXRecordDecl *RD = Fn->CoroutinePromise->getType()->getAsCXXRecordDecl();
-    assert(RD && "Type should have already been checked");
-    // [dcl.fct.def.coroutine]/4
-    // The unqualified-ids 'return_void' and 'return_value' are looked up in
-    // the scope of class P. If both are found, the program is ill-formed.
-    DeclarationName RVoidDN = PP.getIdentifierInfo("return_void");
-    LookupResult RVoidResult(*this, RVoidDN, Loc, Sema::LookupMemberName);
-    const bool HasRVoid = LookupQualifiedName(RVoidResult, RD);
-
-    DeclarationName RValueDN = PP.getIdentifierInfo("return_value");
-    LookupResult RValueResult(*this, RValueDN, Loc, Sema::LookupMemberName);
-    const bool HasRValue = LookupQualifiedName(RValueResult, RD);
-
-    if (HasRVoid && HasRValue) {
-      // FIXME Improve this diagnostic
-      Diag(FD->getLocation(), diag::err_coroutine_promise_return_ill_formed)
-          << RD;
-      return FD->setInvalidDecl();
-    } else if (HasRVoid) {
-      // If the unqualified-id return_void is found, flowing off the end of a
-      // coroutine is equivalent to a co_return with no operand. Otherwise,
-      // flowing off the end of a coroutine results in undefined behavior.
-      Fallthrough = BuildCoreturnStmt(FD->getLocation(), nullptr);
-      Fallthrough = ActOnFinishFullStmt(Fallthrough.get());
-      if (Fallthrough.isInvalid())
-        return FD->setInvalidDecl();
-    }
+  if (HasRVoid && HasRValue) {
+    // FIXME Improve this diagnostic
+    S.Diag(FD.getLocation(), diag::err_coroutine_promise_return_ill_formed)
+        << PromiseRecordDecl;
+    return false;
+  } else if (HasRVoid) {
+    // If the unqualified-id return_void is found, flowing off the end of a
+    // coroutine is equivalent to a co_return with no operand. Otherwise,
+    // flowing off the end of a coroutine results in undefined behavior.
+    Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr);
+    Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get());
+    if (Fallthrough.isInvalid())
+      return false;
+  }
 
-    // [dcl.fct.def.coroutine]/3
-    // The unqualified-id set_exception is found in the scope of P by class
-    // member access lookup (3.4.5).
-    DeclarationName SetExDN = PP.getIdentifierInfo("set_exception");
-    LookupResult SetExResult(*this, SetExDN, Loc, Sema::LookupMemberName);
-    if (LookupQualifiedName(SetExResult, RD)) {
-      // Form the call 'p.set_exception(std::current_exception())'
-      SetException = buildStdCurrentExceptionCall(*this, Loc);
-      if (SetException.isInvalid())
-        return FD->setInvalidDecl();
-      Expr *E = SetException.get();
-      SetException = buildPromiseCall(*this, Fn, Loc, "set_exception", E);
-      SetException = ActOnFinishFullExpr(SetException.get(), Loc);
-      if (SetException.isInvalid())
-        return FD->setInvalidDecl();
-    }
+  this->OnFallthrough = Fallthrough.get();
+  return true;
+}
+
+bool SubStmtBuilder::makeOnException() {
+// Try to form 'p.set_exception(std::current_exception());' to handle
+// uncaught exceptions.
+// TODO: Post WG21 Issaquah 2016 renamed set_exception to unhandled_exception
+// TODO: and dropped exception_ptr parameter. Make it so.
+
+  if (!PromiseRecordDecl)
+    return true;
+
+    // If exceptions are disabled, don't try to build OnException.
+  if (!S.getLangOpts().CXXExceptions)
+    return true;
+
+  ExprResult SetException;
+
+  // [dcl.fct.def.coroutine]/3
+  // The unqualified-id set_exception is found in the scope of P by class
+  // member access lookup (3.4.5).
+  DeclarationName SetExDN = S.PP.getIdentifierInfo("set_exception");
+  LookupResult SetExResult(S, SetExDN, Loc, Sema::LookupMemberName);
+  if (S.LookupQualifiedName(SetExResult, PromiseRecordDecl)) {
+    // Form the call 'p.set_exception(std::current_exception())'
+    SetException = buildStdCurrentExceptionCall(S, Loc);
+    if (SetException.isInvalid())
+      return false;
+    Expr *E = SetException.get();
+    SetException = buildPromiseCall(S, &Fn, Loc, "set_exception", E);
+    SetException = S.ActOnFinishFullExpr(SetException.get(), Loc);
+    if (SetException.isInvalid())
+      return false;
   }
 
+  this->OnException = SetException.get();
+  return true;
+}
+
+bool SubStmtBuilder::makeReturnObject() {
+
   // Build implicit 'p.get_return_object()' expression and form initialization
   // of return type from it.
   ExprResult ReturnObject =
-      buildPromiseCall(*this, Fn, Loc, "get_return_object", None);
+      buildPromiseCall(S, &Fn, Loc, "get_return_object", None);
   if (ReturnObject.isInvalid())
-    return FD->setInvalidDecl();
-  QualType RetType = FD->getReturnType();
+    return false;
+  QualType RetType = FD.getReturnType();
   if (!RetType->isDependentType()) {
     InitializedEntity Entity =
         InitializedEntity::InitializeResult(Loc, RetType, false);
-    ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType,
+    ReturnObject = S.PerformMoveOrCopyInitialization(Entity, nullptr, RetType,
                                                    ReturnObject.get());
     if (ReturnObject.isInvalid())
-      return FD->setInvalidDecl();
+      return false;
   }
-  ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc);
+  ReturnObject = S.ActOnFinishFullExpr(ReturnObject.get(), Loc);
   if (ReturnObject.isInvalid())
-    return FD->setInvalidDecl();
+    return false;
 
-  // FIXME: Perform move-initialization of parameters into frame-local copies.
-  SmallVector<Expr*, 16> ParamMoves;
+  this->ReturnValue = ReturnObject.get();
+  return true;
+}
 
-  // Build body for the coroutine wrapper statement.
-  Body = new (Context) CoroutineBodyStmt(
-      Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(),
-      SetException.get(), Fallthrough.get(), Allocation, Deallocation,
-      ReturnObject.get(), ParamMoves);
+bool SubStmtBuilder::makeParamMoves() {
+  // FIXME: Perform move-initialization of parameters into frame-local copies.
+  return true;
 }
Index: lib/AST/StmtCXX.cpp
===================================================================
--- lib/AST/StmtCXX.cpp
+++ lib/AST/StmtCXX.cpp
@@ -86,3 +86,27 @@
 const VarDecl *CXXForRangeStmt::getLoopVariable() const {
   return const_cast<CXXForRangeStmt *>(this)->getLoopVariable();
 }
+
+CoroutineBodyStmt *CoroutineBodyStmt::Create(
+    const ASTContext &C, CoroutineBodyStmt::CtorArgs const& Args) {
+  std::size_t Size =
+      sizeof(CoroutineBodyStmt) + Args.ParamMoves.size() * sizeof(Stmt *);
+
+  void *Mem = C.Allocate(Size, alignof(CoroutineBodyStmt));
+  return new (Mem) CoroutineBodyStmt(Args);
+}
+
+CoroutineBodyStmt::CoroutineBodyStmt(CoroutineBodyStmt::CtorArgs const &Args)
+    : Stmt(CoroutineBodyStmtClass), NumParams(Args.ParamMoves.size()) {
+  SubStmts[CoroutineBodyStmt::Body] = Args.Body;
+  SubStmts[CoroutineBodyStmt::Promise] = Args.Promise;
+  SubStmts[CoroutineBodyStmt::InitSuspend] = Args.InitialSuspend;
+  SubStmts[CoroutineBodyStmt::FinalSuspend] = Args.FinalSuspend;
+  SubStmts[CoroutineBodyStmt::OnException] = Args.OnException;
+  SubStmts[CoroutineBodyStmt::OnFallthrough] = Args.OnFallthrough;
+  SubStmts[CoroutineBodyStmt::Allocate] = Args.Allocate;
+  SubStmts[CoroutineBodyStmt::Deallocate] = Args.Deallocate;
+  SubStmts[CoroutineBodyStmt::ReturnValue] = Args.ReturnValue;
+  std::copy(Args.ParamMoves.begin(), Args.ParamMoves.end(),
+            const_cast<Stmt **>(getParamMoves().data()));
+}
\ No newline at end of file
Index: include/clang/Sema/ScopeInfo.h
===================================================================
--- include/clang/Sema/ScopeInfo.h
+++ include/clang/Sema/ScopeInfo.h
@@ -157,7 +157,7 @@
   SmallVector<ReturnStmt*, 4> Returns;
 
   /// \brief The promise object for this coroutine, if any.
-  VarDecl *CoroutinePromise;
+  VarDecl *CoroutinePromise = nullptr;
 
   /// \brief The list of coroutine control flow constructs (co_await, co_yield,
   /// co_return) that occur within the function or block. Empty if and only if
Index: include/clang/AST/StmtCXX.h
===================================================================
--- include/clang/AST/StmtCXX.h
+++ include/clang/AST/StmtCXX.h
@@ -309,27 +309,31 @@
     ReturnValue,   ///< Return value for thunk function.
     FirstParamMove ///< First offset for move construction of parameter copies.
   };
+  unsigned NumParams;
   Stmt *SubStmts[SubStmt::FirstParamMove];
 
   friend class ASTStmtReader;
 public:
-  CoroutineBodyStmt(Stmt *Body, Stmt *Promise, Stmt *InitSuspend,
-                    Stmt *FinalSuspend, Stmt *OnException, Stmt *OnFallthrough,
-                    Expr *Allocate, Stmt *Deallocate,
-                    Expr *ReturnValue, ArrayRef<Expr *> ParamMoves)
-      : Stmt(CoroutineBodyStmtClass) {
-    SubStmts[CoroutineBodyStmt::Body] = Body;
-    SubStmts[CoroutineBodyStmt::Promise] = Promise;
-    SubStmts[CoroutineBodyStmt::InitSuspend] = InitSuspend;
-    SubStmts[CoroutineBodyStmt::FinalSuspend] = FinalSuspend;
-    SubStmts[CoroutineBodyStmt::OnException] = OnException;
-    SubStmts[CoroutineBodyStmt::OnFallthrough] = OnFallthrough;
-    SubStmts[CoroutineBodyStmt::Allocate] = Allocate;
-    SubStmts[CoroutineBodyStmt::Deallocate] = Deallocate;
-    SubStmts[CoroutineBodyStmt::ReturnValue] = ReturnValue;
-    // FIXME: Tail-allocate space for parameter move expressions and store them.
-    assert(ParamMoves.empty() && "not implemented yet");
-  }
+
+  struct CtorArgs {
+    Stmt *Body = nullptr;
+    Stmt *Promise = nullptr;
+    Expr *InitialSuspend = nullptr;
+    Expr *FinalSuspend = nullptr;
+    Stmt *OnException = nullptr;
+    Stmt *OnFallthrough = nullptr;
+    Expr *Allocate = nullptr;
+    Expr *Deallocate = nullptr;
+    Stmt *ReturnValue = nullptr;
+    ArrayRef<Stmt *> ParamMoves;
+  };
+
+private:
+
+  CoroutineBodyStmt(CtorArgs const& Args);  
+
+public:
+  static CoroutineBodyStmt *Create(const ASTContext &C, CtorArgs const &Args);
 
   /// \brief Retrieve the body of the coroutine as written. This will be either
   /// a CompoundStmt or a TryStmt.
@@ -356,6 +360,9 @@
   Expr *getReturnValueInit() const {
     return cast<Expr>(SubStmts[SubStmt::ReturnValue]);
   }
+  ArrayRef<Stmt const *> getParamMoves() const {
+    return {SubStmts + SubStmt::FirstParamMove, NumParams};
+  }
 
   SourceLocation getLocStart() const LLVM_READONLY {
     return getBody()->getLocStart();
@@ -365,7 +372,8 @@
   }
 
   child_range children() {
-    return child_range(SubStmts, SubStmts + SubStmt::FirstParamMove);
+    return child_range(SubStmts,
+                       SubStmts + SubStmt::FirstParamMove + NumParams);
   }
 
   static bool classof(const Stmt *T) {
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to