modocache created this revision.
modocache added reviewers: rsmith, GorNishanov, eric_niebler.
Herald added a subscriber: EricWF.
Use coroutine function arguments to initialize the promise object, but only
if the promise type declares a constructor that accepts those arguments.
Otherwise, fall back to the default constructor.
Test Plan: check-clang
Repository:
rC Clang
https://reviews.llvm.org/D41820
Files:
include/clang/Sema/ScopeInfo.h
include/clang/Sema/Sema.h
lib/Sema/CoroutineStmtBuilder.h
lib/Sema/ScopeInfo.cpp
lib/Sema/SemaCoroutine.cpp
lib/Sema/TreeTransform.h
test/CodeGenCoroutines/coro-alloc.cpp
Index: test/CodeGenCoroutines/coro-alloc.cpp
===================================================================
--- test/CodeGenCoroutines/coro-alloc.cpp
+++ test/CodeGenCoroutines/coro-alloc.cpp
@@ -193,3 +193,26 @@
// CHECK: ret i32 %[[LoadRet]]
co_return;
}
+
+struct promise_matching_constructor {}; // Tag type keying the specialization below.
+
+template<>
+struct std::experimental::coroutine_traits<void, promise_matching_constructor, int, float, double> {
+ struct promise_type {
+ promise_type(promise_matching_constructor, int, float, double) {} // Mirrors f5's parameter list.
+ promise_type() = delete; // Forces use of the argument-taking constructor.
+ void get_return_object() {}
+ suspend_always initial_suspend() { return {}; }
+ suspend_always final_suspend() { return {}; }
+ void return_void() {}
+ };
+};
+
+// CHECK-LABEL: f5(
+extern "C" void f5(promise_matching_constructor, int, float, double) {
+ // CHECK: %[[INT:.+]] = load i32, i32* %.addr, align 4
+ // CHECK: %[[FLOAT:.+]] = load float, float* %.addr1, align 4
+ // CHECK: %[[DOUBLE:.+]] = load double, double* %.addr2, align 8
+ // CHECK: call void @_ZNSt12experimental16coroutine_traitsIJv28promise_matching_constructorifdEE12promise_typeC1ES1_ifd(%"struct.std::experimental::coroutine_traits<void, promise_matching_constructor, int, float, double>::promise_type"* %__promise, i32 %[[INT]], float %[[FLOAT]], double %[[DOUBLE]])
+ co_return; // Makes f5 a coroutine, so a promise object is constructed.
+}
Index: lib/Sema/TreeTransform.h
===================================================================
--- lib/Sema/TreeTransform.h
+++ lib/Sema/TreeTransform.h
@@ -6956,6 +6956,8 @@
// The new CoroutinePromise object needs to be built and put into the current
// FunctionScopeInfo before any transformations or rebuilding occurs.
+ if (!SemaRef.buildCoroutineParameterMoves(FD->getLocation()))
+ return StmtError();
auto *Promise = SemaRef.buildCoroutinePromise(FD->getLocation());
if (!Promise)
return StmtError();
@@ -7046,8 +7048,6 @@
Builder.ReturnStmt = Res.get();
}
}
- if (!Builder.buildParameterMoves())
- return StmtError();
return getDerived().RebuildCoroutineBodyStmt(Builder);
}
Index: lib/Sema/SemaCoroutine.cpp
===================================================================
--- lib/Sema/SemaCoroutine.cpp
+++ lib/Sema/SemaCoroutine.cpp
@@ -472,6 +472,69 @@
return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
}
+// Build static_cast<T&&>(E), i.e. std::move(E); T defaults to E's type.
+static Expr *castForMoving(Sema &S, Expr *E, QualType T = QualType()) {
+ if (T.isNull())
+ T = E->getType();
+ QualType TargetType = S.BuildReferenceType(
+ T, /*SpelledAsLValue*/ false, SourceLocation(), DeclarationName());
+ SourceLocation ExprLoc = E->getLocStart();
+ TypeSourceInfo *TargetLoc =
+ S.Context.getTrivialTypeSourceInfo(TargetType, ExprLoc);
+
+ return S
+ .BuildCXXNamedCast(ExprLoc, tok::kw_static_cast, TargetLoc, E,
+ SourceRange(ExprLoc, ExprLoc), E->getSourceRange())
+ .get(); // Null when the cast could not be built (ExprError).
+}
+
+/// \brief Build an implicit VarDecl named II, of type Type, in S.CurContext.
+static VarDecl *buildVarDecl(Sema &S, SourceLocation Loc, QualType Type,
+ IdentifierInfo *II) {
+ TypeSourceInfo *TInfo = S.Context.getTrivialTypeSourceInfo(Type, Loc);
+ VarDecl *Decl = VarDecl::Create(S.Context, S.CurContext, Loc, Loc, II, Type,
+ TInfo, SC_None);
+ Decl->setImplicit(); // Compiler-generated; not spelled in the source.
+ return Decl;
+}
+
+// Declare an implicit local, initialized from std::move(param), for each
+// non-dependent class-type parameter; record it in CoroutineParameterMoves.
+bool Sema::buildCoroutineParameterMoves(SourceLocation Loc) {
+ assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
+ auto *FD = cast<FunctionDecl>(CurContext);
+
+ auto *ScopeInfo = getCurFunction();
+ assert(ScopeInfo->CoroutineParameterMoves.empty() &&
+ "Should not build parameter moves twice");
+
+ for (auto *PD : FD->parameters()) {
+ if (PD->getType()->isDependentType())
+ continue;
+
+ // No need to copy scalars, LLVM will take care of them.
+ if (PD->getType()->getAsCXXRecordDecl()) {
+ ExprResult PDRefExpr = BuildDeclRefExpr(
+ PD, PD->getType(), ExprValueKind::VK_LValue, Loc); // FIXME: scope?
+ if (PDRefExpr.isInvalid())
+ return false;
+
+ Expr *CExpr = castForMoving(*this, PDRefExpr.get());
+
+ auto D = buildVarDecl(*this, Loc, PD->getType(), PD->getIdentifier());
+ AddInitializerToDecl(D, CExpr, /*DirectInit=*/true);
+
+ // Wrap D in a DeclStmt; bail out (return false) if that fails.
+ StmtResult Stmt = ActOnDeclStmt(ConvertDeclToDeclGroup(D), Loc, Loc);
+ if (Stmt.isInvalid())
+ return false;
+
+ ScopeInfo->CoroutineParameterMoves.insert(std::make_pair(PD, Stmt.get()));
+ }
+ }
+ return true;
+}
+
VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) {
assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
auto *FD = cast<FunctionDecl>(CurContext);
@@ -494,7 +557,81 @@
CheckVariableDeclarationType(VD);
if (VD->isInvalidDecl())
return nullptr;
- ActOnUninitializedDecl(VD);
+
+ auto *ScopeInfo = getCurFunction();
+ // Build a list of arguments, based on the coroutine function's arguments,
+ // that will be passed to the promise type's constructor.
+ llvm::SmallVector<Expr *, 4> CtorArgExprs;
+ for (auto *PD : FD->parameters()) {
+ if (PD->getType()->isDependentType())
+ continue;
+
+ auto RefExpr = ExprEmpty();
+ // Look the parameter up once, through a reference to the map; copying
+ // the map per parameter is wasteful and operator[] inserts on a miss.
+ auto &Moves = ScopeInfo->CoroutineParameterMoves;
+ auto Move = Moves.find(PD);
+ if (Move != Moves.end()) {
+ // If a reference to the function parameter exists in the coroutine
+ // frame, use that reference.
+ auto *MoveDecl =
+ cast<VarDecl>(cast<DeclStmt>(Move->second)->getSingleDecl());
+ RefExpr = BuildDeclRefExpr(MoveDecl, MoveDecl->getType(),
+ ExprValueKind::VK_LValue, FD->getLocation());
+ } else {
+ // If the function parameter doesn't exist in the coroutine frame, it
+ // must be a scalar value. Use it directly.
+ assert(!PD->getType()->getAsCXXRecordDecl() &&
+ "Non-scalar types should have been moved and inserted into the "
+ "parameter moves map");
+ RefExpr =
+ BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(),
+ ExprValueKind::VK_LValue, FD->getLocation());
+ }
+
+ if (RefExpr.isInvalid())
+ return nullptr;
+ CtorArgExprs.push_back(RefExpr.get());
+ }
+
+ if (CtorArgExprs.empty()) {
+ // No arguments to forward: default-initialize the promise. An empty
+ // parenthesized initializer would instead value-initialize it, which
+ // zero-fills promise types lacking a user-provided default constructor.
+ ActOnUninitializedDecl(VD);
+ } else {
+ // Create an initialization sequence for the promise type using the
+ // constructor arguments, wrapped in a parenthesized list expression.
+ Expr *PLE = new (Context) ParenListExpr(Context, FD->getLocation(),
+ CtorArgExprs, FD->getLocation());
+ InitializedEntity Entity = InitializedEntity::InitializeVariable(VD);
+ InitializationKind Kind = InitializationKind::CreateForInit(
+ VD->getLocation(), /*DirectInit=*/true, PLE);
+ InitializationSequence InitSeq(*this, Entity, Kind, CtorArgExprs,
+ /*TopLevelOfInitList=*/false,
+ /*TreatUnavailableAsInvalid=*/false);
+
+ // Attempt to initialize the promise type with the arguments.
+ // If that fails, fall back to the promise type's default constructor.
+ if (InitSeq) {
+ ExprResult Result = InitSeq.Perform(*this, Entity, Kind, CtorArgExprs);
+ if (Result.isInvalid()) {
+ VD->setInvalidDecl();
+ } else if (Result.get()) {
+ VD->setInit(MaybeCreateExprWithCleanups(Result.get()));
+ VD->setInitStyle(VarDecl::CallInit);
+ CheckCompleteVariableDeclaration(VD);
+ }
+ } else {
+ ActOnUninitializedDecl(VD);
+ }
+ }
+
+ // Initialization may have failed (and been diagnosed) on either path;
+ // report that to the caller rather than trip the assertion below.
+ if (VD->isInvalidDecl())
+ return nullptr;
+
FD->addDecl(VD);
assert(!VD->isInvalidDecl());
return VD;
@@ -518,6 +638,9 @@
if (ScopeInfo->CoroutinePromise)
return ScopeInfo;
+ if (!S.buildCoroutineParameterMoves(Loc))
+ return nullptr;
+
ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc);
if (!ScopeInfo->CoroutinePromise)
return nullptr;
@@ -861,6 +984,11 @@
!Fn.CoroutinePromise ||
Fn.CoroutinePromise->getType()->isDependentType()) {
this->Body = Body;
+
+ for (auto KV : Fn.CoroutineParameterMoves)
+ this->ParamMovesVector.push_back(KV.second);
+ this->ParamMoves = this->ParamMovesVector;
+
if (!IsPromiseDependentType) {
PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl();
assert(PromiseRecordDecl && "Type should have already been checked");
@@ -870,7 +998,7 @@
bool CoroutineStmtBuilder::buildStatements() {
assert(this->IsValid && "coroutine already invalid");
- this->IsValid = makeReturnObject() && makeParamMoves();
+ this->IsValid = makeReturnObject();
if (this->IsValid && !IsPromiseDependentType)
buildDependentStatements();
return this->IsValid;
@@ -886,12 +1014,6 @@
return this->IsValid;
}
-bool CoroutineStmtBuilder::buildParameterMoves() {
- assert(this->IsValid && "coroutine already invalid");
- assert(this->ParamMoves.empty() && "param moves already built");
- return this->IsValid = makeParamMoves();
-}
-
bool CoroutineStmtBuilder::makePromiseStmt() {
// Form a declaration statement for the promise declaration, so that AST
// visitors can more easily find it.
@@ -1288,66 +1410,6 @@
return true;
}
-// Create a static_cast\<T&&>(expr).
-static Expr *castForMoving(Sema &S, Expr *E, QualType T = QualType()) {
- if (T.isNull())
- T = E->getType();
- QualType TargetType = S.BuildReferenceType(
- T, /*SpelledAsLValue*/ false, SourceLocation(), DeclarationName());
- SourceLocation ExprLoc = E->getLocStart();
- TypeSourceInfo *TargetLoc =
- S.Context.getTrivialTypeSourceInfo(TargetType, ExprLoc);
-
- return S
- .BuildCXXNamedCast(ExprLoc, tok::kw_static_cast, TargetLoc, E,
- SourceRange(ExprLoc, ExprLoc), E->getSourceRange())
- .get();
-}
-
-
-/// \brief Build a variable declaration for move parameter.
-static VarDecl *buildVarDecl(Sema &S, SourceLocation Loc, QualType Type,
- IdentifierInfo *II) {
- TypeSourceInfo *TInfo = S.Context.getTrivialTypeSourceInfo(Type, Loc);
- VarDecl *Decl =
- VarDecl::Create(S.Context, S.CurContext, Loc, Loc, II, Type, TInfo, SC_None);
- Decl->setImplicit();
- return Decl;
-}
-
-bool CoroutineStmtBuilder::makeParamMoves() {
- for (auto *paramDecl : FD.parameters()) {
- auto Ty = paramDecl->getType();
- if (Ty->isDependentType())
- continue;
-
- // No need to copy scalars, llvm will take care of them.
- if (Ty->getAsCXXRecordDecl()) {
- ExprResult ParamRef =
- S.BuildDeclRefExpr(paramDecl, paramDecl->getType(),
- ExprValueKind::VK_LValue, Loc); // FIXME: scope?
- if (ParamRef.isInvalid())
- return false;
-
- Expr *RCast = castForMoving(S, ParamRef.get());
-
- auto D = buildVarDecl(S, Loc, Ty, paramDecl->getIdentifier());
- S.AddInitializerToDecl(D, RCast, /*DirectInit=*/true);
-
- // Convert decl to a statement.
- StmtResult Stmt = S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(D), Loc, Loc);
- if (Stmt.isInvalid())
- return false;
-
- ParamMovesVector.push_back(Stmt.get());
- }
- }
-
- // Convert to ArrayRef in CtorArgs structure that builder inherits from.
- ParamMoves = ParamMovesVector;
- return true;
-}
-
StmtResult Sema::BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) {
CoroutineBodyStmt *Res = CoroutineBodyStmt::Create(Context, Args);
if (!Res)
Index: lib/Sema/ScopeInfo.cpp
===================================================================
--- lib/Sema/ScopeInfo.cpp
+++ lib/Sema/ScopeInfo.cpp
@@ -43,6 +43,7 @@
// Coroutine state
FirstCoroutineStmtLoc = SourceLocation();
CoroutinePromise = nullptr;
+ CoroutineParameterMoves.clear();
NeedsCoroutineSuspends = true;
CoroutineSuspends.first = nullptr;
CoroutineSuspends.second = nullptr;
Index: lib/Sema/CoroutineStmtBuilder.h
===================================================================
--- lib/Sema/CoroutineStmtBuilder.h
+++ lib/Sema/CoroutineStmtBuilder.h
@@ -51,9 +51,6 @@
/// name lookup.
bool buildDependentStatements();
- /// \brief Build just parameter moves. To use for rebuilding in TreeTransform.
- bool buildParameterMoves();
-
bool isInvalid() const { return !this->IsValid; }
private:
@@ -65,7 +62,6 @@
bool makeReturnObject();
bool makeGroDeclAndReturnStmt();
bool makeReturnOnAllocFailure();
- bool makeParamMoves();
};
} // end namespace clang
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -8473,6 +8473,7 @@
StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E,
bool IsImplicit = false);
StmtResult BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs);
+ bool buildCoroutineParameterMoves(SourceLocation Loc);
VarDecl *buildCoroutinePromise(SourceLocation Loc);
void CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body);
Index: include/clang/Sema/ScopeInfo.h
===================================================================
--- include/clang/Sema/ScopeInfo.h
+++ include/clang/Sema/ScopeInfo.h
@@ -22,6 +22,7 @@
#include "clang/Sema/CleanupInfo.h"
#include "clang/Sema/Ownership.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
@@ -172,6 +173,10 @@
/// \brief The promise object for this coroutine, if any.
VarDecl *CoroutinePromise = nullptr;
+ /// \brief A mapping between the coroutine function parameters that were moved
+ /// to the coroutine frame, and their move statements.
+ llvm::SmallMapVector<ParmVarDecl *, Stmt *, 4> CoroutineParameterMoves;
+
/// \brief The initial and final coroutine suspend points.
std::pair<Stmt *, Stmt *> CoroutineSuspends;
_______________________________________________
cfe-commits mailing list
[email protected]
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits