haberman updated this revision to Diff 333914.
haberman added a comment.

- Updated formatting.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D99517/new/

https://reviews.llvm.org/D99517

Files:
  clang/include/clang/Basic/Attr.td
  clang/include/clang/Basic/AttrDocs.td
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Sema/ScopeInfo.h
  clang/include/clang/Sema/Sema.h
  clang/lib/CodeGen/CGCall.cpp
  clang/lib/CodeGen/CGExpr.cpp
  clang/lib/CodeGen/CGStmt.cpp
  clang/lib/CodeGen/CodeGenFunction.h
  clang/lib/Sema/JumpDiagnostics.cpp
  clang/lib/Sema/Sema.cpp
  clang/lib/Sema/SemaStmt.cpp
  clang/lib/Sema/SemaStmtAttr.cpp
  clang/test/CodeGen/attr-musttail.cpp
  clang/test/Sema/attr-musttail.cpp

Index: clang/test/Sema/attr-musttail.cpp
===================================================================
--- /dev/null
+++ clang/test/Sema/attr-musttail.cpp
@@ -0,0 +1,137 @@
+// RUN: %clang_cc1 -verify -fsyntax-only %s
+
+int Bar();
+
+int Func1() {
+  [[clang::musttail(1, 2)]] Bar(); // expected-error {{'musttail' attribute takes no arguments}}
+  [[clang::musttail]] Bar();       // expected-error {{musttail attribute can only be applied to a return statement}}
+  [[clang::musttail]] return 5;    // expected-error {{musttail attribute requires that the return value is a function call}}
+  [[clang::musttail]] return Bar();
+}
+
+int f();
+
+[[clang::musttail]] static int i = f(); // expected-error {{'musttail' attribute cannot be applied to a declaration}}
+
+long g(int x);
+int h(long x);
+
+class Foo {
+public:
+  int MemberFunction(int x);
+  int MemberFunction2();
+};
+
+int Func2(int x) {
+  // Param arity differs.
+  [[clang::musttail]] return Bar(); // expected-error {{musttail attribute requires that caller and callee have identical parameter types and return types}}
+  // Return type differs.
+  [[clang::musttail]] return g(x); // expected-error {{musttail attribute requires that caller and callee have identical parameter types and return types}}
+  // Param type differs.
+  [[clang::musttail]] return h(x); // expected-error {{musttail attribute requires that caller and callee have identical parameter types and return types}}
+  // "this" pointer differs.
+  Foo foo;
+  [[clang::musttail]] return foo.MemberFunction(x); // expected-error {{musttail attribute requires that caller and callee have identical parameter types and return types}}
+}
+
+int j = 0;
+
+class HasNonTrivialDestructor {
+public:
+  ~HasNonTrivialDestructor() { j--; }
+};
+
+int ReturnsInt(int x);
+
+int Func3(int x) {
+  HasNonTrivialDestructor foo;              // expected-note {{jump exits scope of variable with non-trivial destructor}}
+  [[clang::musttail]] return ReturnsInt(x); // expected-error {{musttail attribute does not allow any variables in scope that require destruction}}
+}
+
+int Func4(int x) {
+  HasNonTrivialDestructor foo; // expected-note {{jump exits scope of variable with non-trivial destructor}}
+  {
+    [[clang::musttail]] return ReturnsInt(x); // expected-error {{musttail attribute does not allow any variables in scope that require destruction}}
+  }
+}
+
+int NonTrivialParam(HasNonTrivialDestructor x);
+
+int Func5(HasNonTrivialDestructor x) {
+  [[clang::musttail]] return NonTrivialParam(x); // expected-error {{musttail attribute requires that the return value is a function call, which must not create or destroy any temporaries}}
+}
+
+HasNonTrivialDestructor ReturnsNonTrivialValue();
+
+HasNonTrivialDestructor Func6() {
+  [[clang::musttail]] return ReturnsNonTrivialValue(); // expected-error {{musttail attribute requires that the return value is a function call, which must not create or destroy any temporaries}}
+}
+
+int Func8(Foo *foo, int (Foo::*p_mem)()) {
+  [[clang::musttail]] return (foo->*p_mem)(); // expected-error {{musttail attribute requires that caller and callee have identical parameter types and return types}}
+}
+
+int Func10(Foo *foo) {
+  [[clang::musttail]] return foo->MemberFunction2(); // expected-error {{musttail attribute requires that caller and callee have identical parameter types and return types}}
+}
+
+void Func7() {
+  HasNonTrivialDestructor foo;
+  class Nested {
+    __attribute__((noinline)) static int NestedMethod(int x) {
+      // OK: foo is a local of Func7, not in scope inside NestedMethod.
+      [[clang::musttail]] return ReturnsInt(x);
+    }
+  };
+}
+
+struct Data {
+  int (Data::*pmf)();
+  typedef int Func(Data *);
+  static void StaticMethod();
+  void NonStaticMethod() {
+    [[clang::musttail]] return StaticMethod(); // expected-error {{musttail attribute requires that caller and callee have identical parameter types and return types}}
+  }
+
+  void BadPMF() {
+    // Ill-formed ->* callee: the musttail check must not crash on it.
+    [[clang::musttail]] return ((*this)->*pmf)(); // expected-error {{left hand operand to ->* must be a pointer to class compatible with the right hand operand, but is 'Data'}}
+  }
+};
+
+Data data_global;
+
+void Data::StaticMethod() {
+  [[clang::musttail]] return data_global.NonStaticMethod(); // expected-error {{musttail attribute requires that caller and callee have identical parameter types and return types}}
+}
+
+template <class T>
+T TemplateFunc(T x) {
+  return x ? 5 : 10;
+}
+
+int Func10(int x) {
+  [[clang::musttail]] return TemplateFunc<int>(x);
+}
+
+template <class T>
+T BadTemplateFunc(T x) {
+  [[clang::musttail]] return TemplateFunc<int>(x); // expected-error {{musttail attribute requires that caller and callee have identical parameter types and return types}}
+}
+
+long Func11(long x) {
+  return BadTemplateFunc<long>(x); // expected-note {{in instantiation of}}
+}
+
+void IntParam(int x);
+void ObjPtrParam(HasNonTrivialDestructor *x);
+
+void Func12(int x) {
+  HasNonTrivialDestructor vla[x]; // expected-note {{jump exits scope of variable with non-trivial destructor}}
+  ObjPtrParam(&vla[0]);
+  [[clang::musttail]] return IntParam(x); // expected-error {{musttail attribute does not allow any variables in scope that require destruction}}
+}
+
+void ObjParam(HasNonTrivialDestructor obj) {
+  [[clang::musttail]] return ObjParam(obj); // expected-error {{musttail attribute requires that the return value is a function call, which must not create or destroy any temporaries}}
+}
Index: clang/test/CodeGen/attr-musttail.cpp
===================================================================
--- /dev/null
+++ clang/test/CodeGen/attr-musttail.cpp
@@ -0,0 +1,140 @@
+// RUN: %clang_cc1 -S -emit-llvm %s -triple x86_64-unknown-linux-gnu -o - | FileCheck %s
+
+int Bar(int);
+int Baz(int);
+
+int Func1(int x) {
+  if (x) {
+    [[clang::musttail]] return Bar(x); // CHECK: %call = musttail call i32 @_Z3Bari(i32 %1)
+  } else {
+    [[clang::musttail]] return Baz(x); // CHECK: %call1 = musttail call i32 @_Z3Bazi(i32 %3)
+  }
+}
+
+int Func2(int x) {
+  {
+    [[clang::musttail]] return Bar(Bar(x));
+  }
+}
+
+// CHECK: %call1 = musttail call i32 @_Z3Bari(i32 %call)
+
+class Foo {
+public:
+  static int StaticMethod(int x);
+  int MemberFunction(int x);
+  int TailFrom(int x);
+  int TailFrom2(int x);
+  int TailFrom3(int x);
+};
+
+int Foo::TailFrom(int x) {
+  [[clang::musttail]] return MemberFunction(x);
+}
+
+// CHECK: %call = musttail call i32 @_ZN3Foo14MemberFunctionEi(%class.Foo* nonnull dereferenceable(1) %this1, i32 %0)
+
+int Func3(int x) {
+  [[clang::musttail]] return Foo::StaticMethod(x);
+}
+
+// CHECK: %call = musttail call i32 @_ZN3Foo12StaticMethodEi(i32 %0)
+
+int Func4(int x) {
+  Foo foo; // Trivially destructible, so musttail is still allowed.
+  [[clang::musttail]] return foo.StaticMethod(x);
+}
+
+// CHECK: %call = musttail call i32 @_ZN3Foo12StaticMethodEi(i32 %0)
+
+int (Foo::*pmf)(int);
+
+int Foo::TailFrom2(int x) {
+  [[clang::musttail]] return ((*this).*pmf)(x);
+}
+
+// CHECK: %call = musttail call i32 %8(%class.Foo* nonnull dereferenceable(1) %this.adjusted, i32 %9)
+
+int Foo::TailFrom3(int x) {
+  [[clang::musttail]] return (this->*pmf)(x);
+}
+
+// CHECK: %call = musttail call i32 %8(%class.Foo* nonnull dereferenceable(1) %this.adjusted, i32 %9)
+
+void ReturnsVoid();
+
+void Func5() {
+  [[clang::musttail]] return ReturnsVoid();
+}
+
+// CHECK: musttail call void @_Z11ReturnsVoidv()
+
+class HasTrivialDestructor {};
+
+int ReturnsInt(int x);
+
+int Func6(int x) {
+  HasTrivialDestructor foo;
+  [[clang::musttail]] return ReturnsInt(x);
+}
+
+// CHECK: %call = musttail call i32 @_Z10ReturnsInti(i32 %0)
+
+struct Data {
+  int (*fptr)(Data *);
+};
+
+int Func7(Data *data) {
+  [[clang::musttail]] return data->fptr(data);
+}
+
+// CHECK: %call = musttail call i32 %1(%struct.Data* %2)
+
+template <class T>
+T TemplateFunc(T) {
+  return 5;
+}
+
+int Func9(int x) {
+  [[clang::musttail]] return TemplateFunc<int>(x);
+}
+
+// CHECK: %call = musttail call i32 @_Z12TemplateFuncIiET_S0_(i32 %0)
+
+template <class T>
+int Func10(int x) {
+  T t;
+  [[clang::musttail]] return Bar(x);
+}
+
+int Func11(int x) {
+  return Func10<int>(x);
+}
+
+// CHECK: %call = musttail call i32 @_Z3Bari(i32 %0)
+
+template <class T>
+T Func12(T x) {
+  [[clang::musttail]] return ::Bar(x);
+}
+
+int Func13(int x) {
+  return Func12<int>(x);
+}
+
+// CHECK: %call = musttail call i32 @_Z3Bari(i32 %0)
+
+int Func14(int x) {
+  int vla[x];
+  [[clang::musttail]] return Bar(x);
+}
+
+// CHECK: %call = musttail call i32 @_Z3Bari(i32 %3)
+
+void TrivialDestructorParam(HasTrivialDestructor obj);
+
+void Func14(HasTrivialDestructor obj) {
+  [[clang::musttail]] return TrivialDestructorParam(obj);
+}
+
+// CHECK: musttail call void @_Z22TrivialDestructorParam20HasTrivialDestructor()
Index: clang/lib/Sema/SemaStmtAttr.cpp
===================================================================
--- clang/lib/Sema/SemaStmtAttr.cpp
+++ clang/lib/Sema/SemaStmtAttr.cpp
@@ -12,6 +12,7 @@
 
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/EvaluatedExprVisitor.h"
+#include "clang/AST/ParentMapContext.h"
 #include "clang/Basic/SourceManager.h"
 #include "clang/Basic/TargetInfo.h"
 #include "clang/Sema/DelayedDiagnostic.h"
@@ -209,6 +210,16 @@
   return ::new (S.Context) NoMergeAttr(S.Context, A);
 }
 
+/// Builds a MustTailAttr, diagnosing any (disallowed) arguments first.
+static Attr *handleMustTailAttr(Sema &S, Stmt *St, const ParsedAttr &A,
+                                SourceRange Range) {
+  if (S.CheckAttrNoArgs(A))
+    return nullptr;
+
+  // Validation is in Sema::ActOnAttributedStmt().
+  return ::new (S.Context) MustTailAttr(S.Context, A);
+}
+
 static Attr *handleLikely(Sema &S, Stmt *St, const ParsedAttr &A,
                           SourceRange Range) {
 
@@ -425,6 +436,8 @@
     return handleSuppressAttr(S, St, A, Range);
   case ParsedAttr::AT_NoMerge:
     return handleNoMergeAttr(S, St, A, Range);
+  case ParsedAttr::AT_MustTail:
+    return handleMustTailAttr(S, St, A, Range);
   case ParsedAttr::AT_Likely:
     return handleLikely(S, St, A, Range);
   case ParsedAttr::AT_Unlikely:
Index: clang/lib/Sema/SemaStmt.cpp
===================================================================
--- clang/lib/Sema/SemaStmt.cpp
+++ clang/lib/Sema/SemaStmt.cpp
@@ -558,11 +558,127 @@
 StmtResult Sema::ActOnAttributedStmt(SourceLocation AttrLoc,
                                      ArrayRef<const Attr*> Attrs,
                                      Stmt *SubStmt) {
+  for (const auto *A : Attrs) {
+    if (A->getKind() == attr::MustTail) {
+      // checkMustTailAttr() has already diagnosed any failure; drop the
+      // attributes and keep only the underlying statement.
+      if (!checkMustTailAttr(SubStmt, *A))
+        return SubStmt;
+      setFunctionHasMustTail();
+    }
+  }
+
   // Fill in the declaration and return it.
   AttributedStmt *LS = AttributedStmt::Create(Context, AttrLoc, Attrs, SubStmt);
   return LS;
 }
 
+bool Sema::checkMustTailAttr(const Stmt *St, const Attr &MTA) {
+  const auto *R = dyn_cast<ReturnStmt>(St);
+  if (!R) {
+    Diag(St->getBeginLoc(), diag::err_musttail_only_on_return)
+        << MTA.getSpelling();
+    return false;
+  }
+
+  // Ex is null for a bare "return;", which is rejected as not being a call.
+  const Expr *Ex = R->getRetValue();
+
+  // We don't actually support tail calling through an implicit cast (we require
+  // the return types to match), but getting the actual function call will let
+  // us give a better error message about the return type mismatch.
+  if (const auto *ICE = dyn_cast_or_null<ImplicitCastExpr>(Ex))
+    Ex = ICE->getSubExpr();
+
+  const auto *CE = dyn_cast_or_null<CallExpr>(Ex);
+  if (!CE) {
+    Diag(St->getBeginLoc(), diag::err_musttail_needs_call) << MTA.getSpelling();
+    return false;
+  }
+
+  const auto *CallerDecl = dyn_cast<FunctionDecl>(CurContext);
+  if (!CallerDecl) {
+    Diag(St->getBeginLoc(), diag::err_musttail_only_from_function)
+        << MTA.getSpelling();
+    return false;
+  }
+  // We have to suspend our check until template instantiation time.
+  if (CallerDecl->isDependentContext())
+    return true;
+
+  const Expr *Callee = CE->getCallee()->IgnoreParens();
+  const FunctionProtoType *CalleeType = nullptr;
+  QualType CalleeThis; // Stays null unless calling a non-static member.
+
+  // Detect member function calls, inspired by Expr::findBoundMemberType().
+  // We can't call Expr::findBoundMemberType() directly because we also need the
+  // type of "this".
+  if (const auto *Mem = dyn_cast<MemberExpr>(Callee)) {
+    // Call is: obj.method() or obj->method()
+    const auto *CMD = cast<CXXMethodDecl>(Mem->getMemberDecl());
+    assert(!CMD->isStatic());
+    CalleeThis = CMD->getThisType()->getPointeeType();
+    CalleeType = CMD->getType()->castAs<FunctionProtoType>();
+  } else if (const auto *Op = dyn_cast<BinaryOperator>(Callee)) {
+    // Call is: obj->*method_ptr or obj.*method_ptr
+    const auto *MPT = Op->getRHS()->getType()->castAs<MemberPointerType>();
+    CalleeThis = QualType(MPT->getClass(), 0);
+    CalleeType = MPT->getPointeeType()->castAs<FunctionProtoType>();
+  } else {
+    // Regular non-member function call.
+    assert(!Callee->isBoundMemberFunction(Context));
+    QualType PointeeType = Callee->getType()->getPointeeType();
+    if (PointeeType.isNull()) {
+      // This call is ill-formed, for example (obj->*method_ptr)() where the
+      // method pointer type doesn't match.  There has already been a
+      // diagnostic so we don't emit one here.
+      assert(hasUncompilableErrorOccurred() && "expected previous error");
+      return false;
+    }
+    // A callee without a prototype (e.g. "int f();" in C) has no parameter
+    // types that could be shown to match the caller's, so reject it.
+    CalleeType = PointeeType->getAs<FunctionProtoType>();
+    if (!CalleeType) {
+      Diag(St->getBeginLoc(), diag::err_musttail_return_type_mismatch)
+          << MTA.getSpelling();
+      return false;
+    }
+  }
+
+  auto GetThisType = [](const FunctionDecl *FD) -> QualType {
+    const auto *CMD = dyn_cast<CXXMethodDecl>(FD);
+    return CMD && !CMD->isStatic() ? CMD->getThisType()->getPointeeType()
+                                   : QualType();
+  };
+
+  // Two null types (no "this" on either side) match; otherwise the types
+  // must be similar, i.e. equal apart from cv-qualifiers at each level.
+  auto TypesMatch = [this](QualType A, QualType B) -> bool {
+    if (A.isNull() || B.isNull())
+      return A.isNull() && B.isNull();
+    return Context.hasSimilarType(A, B);
+  };
+
+  // LLVM also requires the caller to be variadic iff the callee is.
+  bool Match =
+      TypesMatch(CallerDecl->getReturnType(), CalleeType->getReturnType()) &&
+      TypesMatch(GetThisType(CallerDecl), CalleeThis) &&
+      CallerDecl->isVariadic() == CalleeType->isVariadic() &&
+      CallerDecl->param_size() == CalleeType->getNumParams();
+  ArrayRef<QualType> CalleeParams = CalleeType->getParamTypes();
+  ArrayRef<ParmVarDecl *> CallerParams = CallerDecl->parameters();
+  for (size_t I = 0, N = CallerParams.size(); Match && I != N; ++I)
+    Match = TypesMatch(CalleeParams[I], CallerParams[I]->getType());
+
+  if (!Match) {
+    Diag(St->getBeginLoc(), diag::err_musttail_return_type_mismatch)
+        << MTA.getSpelling();
+    return false;
+  }
+
+  return true;
+}
+
 namespace {
 class CommaVisitor : public EvaluatedExprVisitor<CommaVisitor> {
   typedef EvaluatedExprVisitor<CommaVisitor> Inherited;
Index: clang/lib/Sema/Sema.cpp
===================================================================
--- clang/lib/Sema/Sema.cpp
+++ clang/lib/Sema/Sema.cpp
@@ -2079,6 +2079,11 @@
     FunctionScopes.back()->setHasIndirectGoto();
 }
 
+void Sema::setFunctionHasMustTail() {
+  if (!FunctionScopes.empty())
+    FunctionScopes.back()->setHasMustTail();
+}
+
 BlockScopeInfo *Sema::getCurBlock() {
   if (FunctionScopes.empty())
     return nullptr;
Index: clang/lib/Sema/JumpDiagnostics.cpp
===================================================================
--- clang/lib/Sema/JumpDiagnostics.cpp
+++ clang/lib/Sema/JumpDiagnostics.cpp
@@ -11,13 +11,14 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "clang/Sema/SemaInternal.h"
 #include "clang/AST/DeclCXX.h"
 #include "clang/AST/Expr.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/StmtCXX.h"
 #include "clang/AST/StmtObjC.h"
 #include "clang/AST/StmtOpenMP.h"
+#include "clang/Basic/SourceLocation.h"
+#include "clang/Sema/SemaInternal.h"
 #include "llvm/ADT/BitVector.h"
 using namespace clang;
 
@@ -29,6 +30,10 @@
 ///    int a[n];
 ///  L:
 ///
+/// We also detect jumps out of protected scopes when it's not possible to do
+/// cleanups properly.  Indirect jumps and ASM jumps can't do cleanups because
+/// the target is unknown.  Return statements with \c [musttail] cannot handle
+/// any cleanups due to the nature of a tail call.
 class JumpScopeChecker {
   Sema &S;
 
@@ -68,6 +73,7 @@
 
   SmallVector<Stmt*, 4> IndirectJumps;
   SmallVector<Stmt*, 4> AsmJumps;
+  SmallVector<AttributedStmt *, 4> MustTailStmts;
   SmallVector<LabelDecl*, 4> IndirectJumpTargets;
   SmallVector<LabelDecl*, 4> AsmJumpTargets;
 public:
@@ -81,6 +87,7 @@
 
   void VerifyJumps();
   void VerifyIndirectOrAsmJumps(bool IsAsmGoto);
+  void VerifyMustTailStmts();
   void NoteJumpIntoScopes(ArrayRef<unsigned> ToScopes);
   void DiagnoseIndirectOrAsmJump(Stmt *IG, unsigned IGScope, LabelDecl *Target,
                                  unsigned TargetScope);
@@ -88,6 +95,7 @@
                  unsigned JumpDiag, unsigned JumpDiagWarning,
                  unsigned JumpDiagCXX98Compat);
   void CheckGotoStmt(GotoStmt *GS);
+  const Attr *GetMustTailAttr(AttributedStmt *AS);
 
   unsigned GetDeepestCommonScope(unsigned A, unsigned B);
 };
@@ -109,6 +117,7 @@
   VerifyJumps();
   VerifyIndirectOrAsmJumps(false);
   VerifyIndirectOrAsmJumps(true);
+  VerifyMustTailStmts();
 }
 
 /// GetDeepestCommonScope - Finds the innermost scope enclosing the
@@ -580,6 +589,15 @@
     LabelAndGotoScopes[S] = ParentScope;
     break;
 
+  case Stmt::AttributedStmtClass: {
+    AttributedStmt *AS = cast<AttributedStmt>(S);
+    if (GetMustTailAttr(AS)) {
+      LabelAndGotoScopes[AS] = ParentScope;
+      MustTailStmts.push_back(AS);
+    }
+    break;
+  }
+
   default:
     if (auto *ED = dyn_cast<OMPExecutableDirective>(S)) {
       if (!ED->isStandaloneDirective()) {
@@ -971,6 +989,28 @@
   }
 }
 
+void JumpScopeChecker::VerifyMustTailStmts() {
+  for (AttributedStmt *AS : MustTailStmts) {
+    // A scope with an OutDiag needs cleanup on exit, ruling out a tail call.
+    for (unsigned I = LabelAndGotoScopes[AS]; I; I = Scopes[I].ParentScope) {
+      if (Scopes[I].OutDiag) {
+        S.Diag(AS->getBeginLoc(), diag::err_musttail_no_destruction)
+            << GetMustTailAttr(AS)->getSpelling();
+        S.Diag(Scopes[I].Loc, Scopes[I].OutDiag);
+        break; // Diagnose each statement once, at its innermost such scope.
+      }
+    }
+  }
+}
+
+/// Returns the musttail attribute attached to \p AS, or null if it has none.
+const Attr *JumpScopeChecker::GetMustTailAttr(AttributedStmt *AS) {
+  for (const auto *A : AS->getAttrs())
+    if (A->getKind() == attr::MustTail)
+      return A;
+  return nullptr;
+}
+
 void Sema::DiagnoseInvalidJumps(Stmt *Body) {
   (void)JumpScopeChecker(Body, *this);
 }
Index: clang/lib/CodeGen/CodeGenFunction.h
===================================================================
--- clang/lib/CodeGen/CodeGenFunction.h
+++ clang/lib/CodeGen/CodeGenFunction.h
@@ -517,6 +517,11 @@
   /// True if the current statement has nomerge attribute.
   bool InNoMergeAttributedStmt = false;
 
+  /// True while EmitCallExpr() is emitting the call stored in MustTailCall.
+  bool InMustTailCallExpr = false;
+  /// The call returned by the current musttail statement, or null if none.
+  const CallExpr *MustTailCall = nullptr;
+
   /// True if the current function should be marked mustprogress.
   bool FnIsMustProgress = false;
 
Index: clang/lib/CodeGen/CGStmt.cpp
===================================================================
--- clang/lib/CodeGen/CGStmt.cpp
+++ clang/lib/CodeGen/CGStmt.cpp
@@ -16,6 +16,8 @@
 #include "CodeGenModule.h"
 #include "TargetInfo.h"
 #include "clang/AST/Attr.h"
+#include "clang/AST/Expr.h"
+#include "clang/AST/Stmt.h"
 #include "clang/AST/StmtVisitor.h"
 #include "clang/Basic/Builtins.h"
 #include "clang/Basic/DiagnosticSema.h"
@@ -643,12 +645,21 @@
 
 void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
   bool nomerge = false;
-  for (const auto *A : S.getAttrs())
+  const CallExpr *musttail = nullptr;
+  for (const auto *A : S.getAttrs()) {
     if (A->getKind() == attr::NoMerge) {
       nomerge = true;
-      break;
     }
+    if (A->getKind() == attr::MustTail) {
+      const auto *R = cast<ReturnStmt>(S.getSubStmt());
+      // Sema accepts a call wrapped in one implicit cast (e.g. a qualification
+      // conversion between similar return types), so look through it here.
+      musttail = dyn_cast<CallExpr>(R->getRetValue()->IgnoreImpCasts());
+      assert(musttail && "musttail must return CallExpr");
+    }
+  }
   SaveAndRestore<bool> save_nomerge(InNoMergeAttributedStmt, nomerge);
+  SaveAndRestore<const CallExpr *> save_musttail(MustTailCall, musttail);
   EmitStmt(S.getSubStmt(), S.getAttrs());
 }
 
Index: clang/lib/CodeGen/CGExpr.cpp
===================================================================
--- clang/lib/CodeGen/CGExpr.cpp
+++ clang/lib/CodeGen/CGExpr.cpp
@@ -38,6 +38,7 @@
 #include "llvm/Support/ConvertUTF.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/Path.h"
+#include "llvm/Support/SaveAndRestore.h"
 #include "llvm/Transforms/Utils/SanitizerStats.h"
 
 #include <string>
@@ -4824,6 +4825,8 @@
 
 RValue CodeGenFunction::EmitCallExpr(const CallExpr *E,
                                      ReturnValueSlot ReturnValue) {
+  // FIXME: still true while args are emitted; nested EmitCall()s may see it.
+  SaveAndRestore<bool> save_musttail(InMustTailCallExpr, E == MustTailCall);
   // Builtins never have block type.
   if (E->getCallee()->getType()->isBlockPointerType())
     return EmitBlockCallExpr(E, ReturnValue);
Index: clang/lib/CodeGen/CGCall.cpp
===================================================================
--- clang/lib/CodeGen/CGCall.cpp
+++ clang/lib/CodeGen/CGCall.cpp
@@ -5252,10 +5252,12 @@
   if (CGM.getLangOpts().ObjCAutoRefCount)
     AddObjCARCExceptionMetadata(CI);
 
-  // Suppress tail calls if requested.
+  // Set tail call kind if necessary.
   if (llvm::CallInst *Call = dyn_cast<llvm::CallInst>(CI)) {
     if (TargetDecl && TargetDecl->hasAttr<NotTailCalledAttr>())
       Call->setTailCallKind(llvm::CallInst::TCK_NoTail);
+    else if (InMustTailCallExpr)
+      Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
   }
 
   // Add metadata for calls to MSAllocator functions
@@ -5307,6 +5309,22 @@
     return GetUndefRValue(RetTy);
   }
 
+  // If this is a musttail call, return immediately. We do not branch to the
+  // epilogue in this case.
+  if (InMustTailCallExpr) {
+    // Test the IR call's type, not RetTy: a call that returns its result
+    // indirectly (e.g. via sret) produces void even when RetTy is not void.
+    // TODO(haberman): double-check here that this early exit is safe.
+    if (CI->getType()->isVoidTy()) {
+      Builder.CreateRetVoid();
+    } else {
+      Builder.CreateRet(CI);
+    }
+    Builder.ClearInsertionPoint();
+    EnsureInsertPoint();
+    return GetUndefRValue(RetTy);
+  }
+
   // Perform the swifterror writeback.
   if (swiftErrorTemp.isValid()) {
     llvm::Value *errorResult = Builder.CreateLoad(swiftErrorTemp);
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -1848,6 +1848,7 @@
   void setFunctionHasBranchIntoScope();
   void setFunctionHasBranchProtectedScope();
   void setFunctionHasIndirectGoto();
+  void setFunctionHasMustTail();
 
   void PushCompoundScope(bool IsStmtExpr);
   void PopCompoundScope();
@@ -11338,6 +11339,10 @@
   /// function, issuing a diagnostic if not.
   void checkVariadicArgument(const Expr *E, VariadicCallType CT);
 
+  /// Check whether the given statement can have musttail applied to it,
+  /// issuing a diagnostic and returning false if not.
+  bool checkMustTailAttr(const Stmt *St, const Attr &MTA);
+
   /// Check to see if a given expression could have '.c_str()' called on it.
   bool hasCStrMethod(const Expr *E);
 
Index: clang/include/clang/Sema/ScopeInfo.h
===================================================================
--- clang/include/clang/Sema/ScopeInfo.h
+++ clang/include/clang/Sema/ScopeInfo.h
@@ -118,6 +118,9 @@
   /// Whether this function contains any indirect gotos.
   bool HasIndirectGoto : 1;
 
+  /// Whether this function contains any statement marked with \c [musttail].
+  bool HasMustTail : 1;
+
   /// Whether a statement was dropped because it was invalid.
   bool HasDroppedStmt : 1;
 
@@ -370,14 +373,13 @@
 public:
   FunctionScopeInfo(DiagnosticsEngine &Diag)
       : Kind(SK_Function), HasBranchProtectedScope(false),
-        HasBranchIntoScope(false), HasIndirectGoto(false),
+        HasBranchIntoScope(false), HasIndirectGoto(false), HasMustTail(false),
         HasDroppedStmt(false), HasOMPDeclareReductionCombiner(false),
         HasFallthroughStmt(false), UsesFPIntrin(false),
-        HasPotentialAvailabilityViolations(false),
-        ObjCShouldCallSuper(false), ObjCIsDesignatedInit(false),
-        ObjCWarnForNoDesignatedInitChain(false), ObjCIsSecondaryInit(false),
-        ObjCWarnForNoInitDelegation(false), NeedsCoroutineSuspends(true),
-        ErrorTrap(Diag) {}
+        HasPotentialAvailabilityViolations(false), ObjCShouldCallSuper(false),
+        ObjCIsDesignatedInit(false), ObjCWarnForNoDesignatedInitChain(false),
+        ObjCIsSecondaryInit(false), ObjCWarnForNoInitDelegation(false),
+        NeedsCoroutineSuspends(true), ErrorTrap(Diag) {}
 
   virtual ~FunctionScopeInfo();
 
@@ -423,6 +425,8 @@
     HasIndirectGoto = true;
   }
 
+  void setHasMustTail() { HasMustTail = true; }
+
   void setHasDroppedStmt() {
     HasDroppedStmt = true;
   }
@@ -450,9 +454,8 @@
   }
 
   bool NeedsScopeChecking() const {
-    return !HasDroppedStmt &&
-        (HasIndirectGoto ||
-          (HasBranchProtectedScope && HasBranchIntoScope));
+    return !HasDroppedStmt && (HasIndirectGoto || HasMustTail ||
+                               (HasBranchProtectedScope && HasBranchIntoScope));
   }
 
   // Add a block introduced in this function.
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -2823,6 +2823,19 @@
   "%0 attribute is ignored because there exists no call expression inside the "
   "statement">,
   InGroup<IgnoredAttributes>;
+def err_musttail_only_on_return : Error<
+  "%0 attribute can only be applied to a return statement">;
+def err_musttail_needs_call : Error<
+  "%0 attribute requires that the return value is a function call, which must "
+  "not create or destroy any temporaries">;
+def err_musttail_only_from_function : Error<
+  "%0 attribute can only be used from a regular function">;
+def err_musttail_return_type_mismatch : Error<
+  "%0 attribute requires that caller and callee have identical parameter types "
+  "and return types">;
+def err_musttail_no_destruction : Error<
+  "%0 attribute does not allow any variables in scope that require destruction"
+  >;
 def err_nsobject_attribute : Error<
   "'NSObject' attribute is for pointer types only">;
 def err_attributes_are_not_compatible : Error<
Index: clang/include/clang/Basic/AttrDocs.td
===================================================================
--- clang/include/clang/Basic/AttrDocs.td
+++ clang/include/clang/Basic/AttrDocs.td
@@ -443,6 +443,24 @@
   }];
 }
 
+def MustTailDocs : Documentation {
+  let Category = DocCatStmt;
+  let Content = [{
+If a return statement is marked ``musttail``, this indicates that the
+compiler must generate a tail call for the program to be correct, even when
+optimizations are disabled. This guarantees that the call will not cause
+unbounded stack growth if it is part of a recursive cycle in the call graph.
+
+``clang::musttail`` can only be applied to a return statement whose value is a
+function call (even functions returning void must use ``return``, although no
+value is returned). The target function must have the same number of arguments
+as the caller. The types of the return value and all arguments must perfectly
+match, including the implicit ``this`` argument, if any. There may not be any
+variables currently in scope that require destruction. The arguments and return
+type of the function must be trivially destructible.
+  }];
+}
+
 def AssertCapabilityDocs : Documentation {
   let Category = DocCatFunction;
   let Heading = "assert_capability, assert_shared_capability";
Index: clang/include/clang/Basic/Attr.td
===================================================================
--- clang/include/clang/Basic/Attr.td
+++ clang/include/clang/Basic/Attr.td
@@ -1352,6 +1352,11 @@
   let SimpleHandler = 1;
 }
 
+def MustTail : StmtAttr {
+  let Spellings = [Clang<"musttail">];
+  let Documentation = [MustTailDocs];
+}
+
 def FastCall : DeclOrTypeAttr {
   let Spellings = [GCC<"fastcall">, Keyword<"__fastcall">,
                    Keyword<"_fastcall">];
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to