haberman updated this revision to Diff 336310.
haberman marked 3 inline comments as done.
haberman added a comment.

- Refined the implicit constructor skipping code.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D99517/new/

https://reviews.llvm.org/D99517

Files:
  clang/include/clang/AST/IgnoreExpr.h
  clang/include/clang/Basic/Attr.td
  clang/include/clang/Basic/AttrDocs.td
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Sema/ScopeInfo.h
  clang/include/clang/Sema/Sema.h
  clang/lib/CodeGen/CGCall.cpp
  clang/lib/CodeGen/CGClass.cpp
  clang/lib/CodeGen/CGDecl.cpp
  clang/lib/CodeGen/CGExpr.cpp
  clang/lib/CodeGen/CGExprCXX.cpp
  clang/lib/CodeGen/CGStmt.cpp
  clang/lib/CodeGen/CodeGenFunction.h
  clang/lib/CodeGen/EHScopeStack.h
  clang/lib/Sema/JumpDiagnostics.cpp
  clang/lib/Sema/Sema.cpp
  clang/lib/Sema/SemaStmt.cpp
  clang/lib/Sema/SemaStmtAttr.cpp
  clang/test/CodeGen/attr-musttail.cpp
  clang/test/Sema/attr-musttail.c
  clang/test/Sema/attr-musttail.cpp
  clang/test/Sema/attr-musttail.m

Index: clang/test/Sema/attr-musttail.m
===================================================================
--- /dev/null
+++ clang/test/Sema/attr-musttail.m
@@ -0,0 +1,26 @@
+// RUN: %clang_cc1 -fsyntax-only -fblocks -Wno-objc-root-class -verify %s
+
+void TestObjcBlock(void) {
+  void (^x)(void) = ^(void) {
+    __attribute__((musttail)) return TestObjcBlock(); // expected-error{{'musttail' attribute cannot be used from a block}}
+  };
+  __attribute__((musttail)) return x();
+}
+
+void ReturnsVoid(void);
+void TestObjcBlockVar(void) {
+  __block int i = 0;                              // expected-note{{jump exits scope of __block variable}}
+  __attribute__((musttail)) return ReturnsVoid(); // expected-error{{cannot perform a tail call from this return statement}}
+}
+
+__attribute__((objc_root_class))
+@interface TestObjcClass
+@end
+
+@implementation TestObjcClass
+
+- (void)testObjCMethod {
+  __attribute__((musttail)) return ReturnsVoid(); // expected-error{{'musttail' attribute cannot be used from an Objective-C function}}
+}
+
+@end
Index: clang/test/Sema/attr-musttail.cpp
===================================================================
--- /dev/null
+++ clang/test/Sema/attr-musttail.cpp
@@ -0,0 +1,201 @@
+// RUN: %clang_cc1 -verify -fsyntax-only -fms-extensions -fcxx-exceptions -fopenmp %s
+
+int ReturnsInt1();
+int Func1() {
+  [[clang::musttail]] ReturnsInt1();              // expected-error {{'musttail' attribute only applies to return statements}}
+  [[clang::musttail(1, 2)]] return ReturnsInt1(); // expected-error {{'musttail' attribute takes no arguments}}
+  [[clang::musttail]] return 5;                   // expected-error {{'musttail' attribute requires that the return value is the result of a function call}}
+  [[clang::musttail]] return ReturnsInt1();
+}
+
+void NoFunctionCall() {
+  [[clang::musttail]] return; // expected-error {{'musttail' attribute requires that the return value is the result of a function call}}
+}
+
+[[clang::musttail]] static int int_val = ReturnsInt1(); // expected-error {{'musttail' attribute cannot be applied to a declaration}}
+
+void NoParams(); // expected-note {{target function has different number of parameters (expected 1 but has 0)}}
+void TestParamArityMismatch(int x) {
+  [[clang::musttail]] return NoParams(); // expected-error {{'musttail' attribute requires that caller and callee have compatible function signatures}}
+}
+
+void LongParam(long x); // expected-note {{target function has type mismatch at 1st parameter (expected 'long' but has 'int')}}
+void TestParamTypeMismatch(int x) {
+  [[clang::musttail]] return LongParam(x); // expected-error {{'musttail' attribute requires that caller and callee have compatible function signatures}}
+}
+
+long ReturnsLong(); // expected-note {{target function has different return type ('int' expected but has 'long')}}
+int TestReturnTypeMismatch() {
+  [[clang::musttail]] return ReturnsLong(); // expected-error {{'musttail' attribute requires that caller and callee have compatible function signatures}}
+}
+
+struct Struct1 {
+  void MemberFunction(); // expected-note {{target function is a member of different class (expected 'void' but has 'Struct1')}}
+};
+void TestNonMemberToMember() {
+  Struct1 st;
+  [[clang::musttail]] return st.MemberFunction(); // expected-error {{'musttail' attribute requires that caller and callee have compatible function signatures}}
+}
+
+void ReturnsVoid(); // expected-note {{target function is a member of different class (expected 'Struct2' but has 'void')}}
+struct Struct2 {
+  void TestMemberToNonMember() {
+    [[clang::musttail]] return ReturnsVoid(); // expected-error{{'musttail' attribute requires that caller and callee have compatible function signatures}}
+  }
+};
+
+class HasNonTrivialDestructor {
+public:
+  ~HasNonTrivialDestructor() {}
+  int ReturnsInt();
+};
+
+void ReturnsVoid2();
+void TestNonTrivialDestructorInScope() {
+  HasNonTrivialDestructor foo;              // expected-note {{jump exits scope of variable with non-trivial destructor}}
+  [[clang::musttail]] return ReturnsVoid(); // expected-error {{cannot perform a tail call from this return statement}}
+}
+
+int NonTrivialParam(HasNonTrivialDestructor x);
+int TestNonTrivialParam(HasNonTrivialDestructor x) {
+  [[clang::musttail]] return NonTrivialParam(x); // expected-error {{'musttail' attribute requires that the return value, all parameters, and any temporaries created by the expression are trivially destructible}}
+}
+
+HasNonTrivialDestructor ReturnsNonTrivialValue();
+HasNonTrivialDestructor TestReturnsNonTrivialValue() {
+  [[clang::musttail]] return (ReturnsNonTrivialValue()); // expected-error {{'musttail' attribute requires that the return value, all parameters, and any temporaries created by the expression are trivially destructible}}
+}
+
+HasNonTrivialDestructor TestReturnsNonTrivialNonFunctionCall() {
+  [[clang::musttail]] return HasNonTrivialDestructor(); // expected-error {{'musttail' attribute requires that the return value is the result of a function call}}
+}
+
+struct UsesPointerToMember {
+  void (UsesPointerToMember::*p_mem)();
+};
+void TestUsesPointerToMember(UsesPointerToMember *foo) {
+  // "this" pointer cannot double as first parameter.
+  [[clang::musttail]] return (foo->*(foo->p_mem))(); // expected-error {{'musttail' attribute requires that caller and callee have compatible function signatures}} expected-note{{target function is a member of different class (expected 'void' but has 'UsesPointerToMember')}}
+}
+
+void ReturnsVoid2();
+void TestNestedClass() {
+  HasNonTrivialDestructor foo;
+  class Nested {
+    __attribute__((noinline)) static void NestedMethod() {
+      // Outer non-trivial destructor does not affect nested class.
+      [[clang::musttail]] return ReturnsVoid2();
+    }
+  };
+}
+
+template <class T>
+T TemplateFunc(T x) { // expected-note{{target function has different return type ('long' expected but has 'int')}}
+  return x ? 5 : 10;
+}
+int OkTemplateFunc(int x) {
+  [[clang::musttail]] return TemplateFunc<int>(x);
+}
+template <class T>
+T BadTemplateFunc(T x) {
+  [[clang::musttail]] return TemplateFunc<int>(x); // expected-error {{'musttail' attribute requires that caller and callee have compatible function signatures}}
+}
+long TestBadTemplateFunc(long x) {
+  return BadTemplateFunc<long>(x); // expected-note {{in instantiation of}}
+}
+
+void IntParam(int x);
+void TestVLA(int x) {
+  HasNonTrivialDestructor vla[x];         // expected-note {{jump exits scope of variable with non-trivial destructor}}
+  [[clang::musttail]] return IntParam(x); // expected-error {{cannot perform a tail call from this return statement}}
+}
+
+void TestNonTrivialDestructorSubArg(int x) {
+  [[clang::musttail]] return IntParam(NonTrivialParam(HasNonTrivialDestructor())); // expected-error {{'musttail' attribute requires that the return value, all parameters, and any temporaries created by the expression are trivially destructible}}
+}
+
+void VariadicFunction(int x, ...);
+void TestVariadicFunction(int x, ...) {
+  [[clang::musttail]] return VariadicFunction(x); // expected-error {{'musttail' attribute may not be used with variadic functions}}
+}
+
+int TakesIntParam(int x);     // expected-note {{target function has type mismatch at 1st parameter (expected 'int' but has 'short')}}
+int TakesShortParam(short x); // expected-note {{target function has type mismatch at 1st parameter (expected 'short' but has 'int')}}
+int TestIntParamMismatch(int x) {
+  [[clang::musttail]] return TakesShortParam(x); // expected-error {{'musttail' attribute requires that caller and callee have compatible function signatures}}
+}
+int TestIntParamMismatch2(short x) {
+  [[clang::musttail]] return TakesIntParam(x); // expected-error {{'musttail' attribute requires that caller and callee have compatible function signatures}}
+}
+
+__regcall int RegCallReturnsInt(); // expected-note {{target function has calling convention regcall (expected cdecl)}}
+int TestMismatchCallingConvention() {
+  [[clang::musttail]] return RegCallReturnsInt(); // expected-error {{'musttail' attribute requires that caller and callee use the same calling convention}}
+}
+
+int TestLambda() {
+  auto lambda = []() { return 12; };   // expected-note {{target function is a member of different class (expected 'void' but has 'const (lambda}}
+  [[clang::musttail]] return lambda(); // expected-error {{'musttail' attribute requires that caller and callee have compatible function signatures}}
+}
+
+// These tests are merely verifying that we don't crash with incomplete or
+// erroneous ASTs. These cases crashed the compiler in early iterations.
+
+struct TestBadPMF {
+  int (TestBadPMF::*pmf)();
+  void BadPMF() {
+    [[clang::musttail]] return ((*this)->*pmf)(); // expected-error {{left hand operand to ->* must be a pointer to class compatible with the right hand operand, but is 'TestBadPMF'}}
+  }
+};
+
+namespace ns {}
+void TestCallNonValue() {
+  [[clang::musttail]] return ns; // expected-error {{unexpected namespace name 'ns': expected expression}}
+}
+
+int TestNonTrivialTemporary(int) {
+  [[clang::musttail]] return TakesIntParam(HasNonTrivialDestructor().ReturnsInt()); // expected-error {{'musttail' attribute requires that the return value, all parameters, and any temporaries created by the expression are trivially destructible}}
+}
+
+void ReturnsVoid();
+struct TestDestructor {
+  ~TestDestructor() {                         // expected-note {{caller is a destructor}}
+    [[clang::musttail]] return ReturnsVoid(); // expected-error {{destructor '~TestDestructor' must not return void expression}}  // expected-error {{'musttail' attribute cannot be used when caller or callee is a constructor or destructor, as these can have unusual calling conventions}}
+  }
+};
+
+struct ClassWithDestructor { // expected-note {{callee is a destructor}}
+  void TestExplicitDestructorCall() {
+    [[clang::musttail]] return this->~ClassWithDestructor(); // expected-error {{'musttail' attribute cannot be used when caller or callee is a constructor or destructor, as these can have unusual calling conventions}}
+  }
+};
+
+struct HasNonTrivialCopyConstructor {
+  HasNonTrivialCopyConstructor(const HasNonTrivialCopyConstructor &);
+};
+HasNonTrivialCopyConstructor ReturnsClassByValue();
+HasNonTrivialCopyConstructor TestNonElidableCopyConstructor() {
+  // This is an elidable constructor, but when it is written explicitly
+  // we decline to elide it.
+  [[clang::musttail]] return HasNonTrivialCopyConstructor(ReturnsClassByValue()); // expected-error{{'musttail' attribute requires that the return value is the result of a function call}}
+}
+
+struct ClassWithConstructor {}; // expected-note {{callee is a constructor}}
+void TestExplicitConstructorCall(ClassWithConstructor a) {
+  [[clang::musttail]] return a.ClassWithConstructor::ClassWithConstructor(); // expected-error{{'musttail' attribute cannot be used when caller or callee is a constructor or destructor, as these can have unusual calling conventions}}  expected-warning{{explicit constructor calls are a Microsoft extension}}
+}
+
+void TestStatementExpression() {
+  ({
+    HasNonTrivialDestructor foo;               // expected-note {{jump exits scope of variable with non-trivial destructor}}
+    [[clang::musttail]] return ReturnsVoid2(); // expected-error {{cannot perform a tail call from this return statement}}
+  });
+}
+
+struct MyException {};
+void TestTryBlock() {
+  try {                                        // expected-note {{jump exits try block}}
+    [[clang::musttail]] return ReturnsVoid2(); // expected-error {{cannot perform a tail call from this return statement}}
+  } catch (MyException &e) {
+  }
+}
Index: clang/test/Sema/attr-musttail.c
===================================================================
--- /dev/null
+++ clang/test/Sema/attr-musttail.c
@@ -0,0 +1,15 @@
+// RUN: %clang_cc1 -verify -fsyntax-only %s
+
+int NotAProtoType(); // expected-note{{add 'void' to the parameter list to turn a old-style K&R function declaration into a prototype}}
+int TestCalleeNotProtoType(void) {
+  __attribute__((musttail)) return NotAProtoType(); // expected-error{{'musttail' attribute requires that both caller and callee functions have a prototype}}
+}
+
+int ProtoType(void);
+int TestCallerNotProtoType() {                  // expected-note{{add 'void' to the parameter list to turn a old-style K&R function declaration into a prototype}}
+  __attribute__((musttail)) return ProtoType(); // expected-error{{'musttail' attribute requires that both caller and callee functions have a prototype}}
+}
+
+int TestProtoType(void) {
+  return ProtoType();
+}
Index: clang/test/CodeGen/attr-musttail.cpp
===================================================================
--- /dev/null
+++ clang/test/CodeGen/attr-musttail.cpp
@@ -0,0 +1,194 @@
+// RUN: %clang_cc1 -fno-elide-constructors -S -emit-llvm %s -triple x86_64-unknown-linux-gnu -o - | FileCheck %s
+// RUN: %clang_cc1 -fno-elide-constructors -S -emit-llvm %s -triple x86_64-unknown-linux-gnu -o - | opt -verify
+// FIXME: remove the call to "opt" once the tests are running the Clang verifier automatically again.
+
+int Bar(int);
+int Baz(int);
+
+int Func1(int x) {
+  if (x) {
+    // CHECK: %call = musttail call i32 @_Z3Bari(i32 %1)
+    // CHECK-NEXT: ret i32 %call
+    [[clang::musttail]] return Bar(x);
+  } else {
+    [[clang::musttail]] return Baz(x); // CHECK: %call1 = musttail call i32 @_Z3Bazi(i32 %3)
+  }
+}
+
+int Func2(int x) {
+  {
+    [[clang::musttail]] return Bar(Bar(x));
+  }
+}
+
+// CHECK: %call1 = musttail call i32 @_Z3Bari(i32 %call)
+
+class Foo {
+public:
+  static int StaticMethod(int x);
+  int MemberFunction(int x);
+  int TailFrom(int x);
+  int TailFrom2(int x);
+  int TailFrom3(int x);
+};
+
+int Foo::TailFrom(int x) {
+  [[clang::musttail]] return MemberFunction(x);
+}
+
+// CHECK: %call = musttail call i32 @_ZN3Foo14MemberFunctionEi(%class.Foo* nonnull align 1 dereferenceable(1) %this1, i32 %0)
+
+int Func3(int x) {
+  [[clang::musttail]] return Foo::StaticMethod(x);
+}
+
+// CHECK: %call = musttail call i32 @_ZN3Foo12StaticMethodEi(i32 %0)
+
+int Func4(int x) {
+  Foo foo; // Object with trivial destructor.
+  [[clang::musttail]] return foo.StaticMethod(x);
+}
+
+// CHECK: %call = musttail call i32 @_ZN3Foo12StaticMethodEi(i32 %0)
+
+int (Foo::*pmf)(int);
+
+int Foo::TailFrom2(int x) {
+  [[clang::musttail]] return ((*this).*pmf)(x);
+}
+
+// CHECK: %call = musttail call i32 %8(%class.Foo* nonnull align 1 dereferenceable(1) %this.adjusted, i32 %9)
+
+int Foo::TailFrom3(int x) {
+  [[clang::musttail]] return (this->*pmf)(x);
+}
+
+// CHECK: %call = musttail call i32 %8(%class.Foo* nonnull align 1 dereferenceable(1) %this.adjusted, i32 %9)
+
+void ReturnsVoid();
+
+void Func5() {
+  [[clang::musttail]] return ReturnsVoid();
+}
+
+// CHECK: musttail call void @_Z11ReturnsVoidv()
+
+class HasTrivialDestructor {};
+
+int ReturnsInt(int x);
+
+int Func6(int x) {
+  HasTrivialDestructor foo;
+  [[clang::musttail]] return ReturnsInt(x);
+}
+
+// CHECK: %call = musttail call i32 @_Z10ReturnsInti(i32 %0)
+
+struct Data {
+  int (*fptr)(Data *);
+};
+
+int Func7(Data *data) {
+  [[clang::musttail]] return data->fptr(data);
+}
+
+// CHECK: %call = musttail call i32 %1(%struct.Data* %2)
+
+template <class T>
+T TemplateFunc(T) {
+  return 5;
+}
+
+int Func9(int x) {
+  [[clang::musttail]] return TemplateFunc<int>(x);
+}
+
+// CHECK: %call = musttail call i32 @_Z12TemplateFuncIiET_S0_(i32 %0)
+
+template <class T>
+int Func10(int x) {
+  T t;
+  [[clang::musttail]] return Bar(x);
+}
+
+int Func11(int x) {
+  return Func10<int>(x);
+}
+
+// CHECK: %call = musttail call i32 @_Z3Bari(i32 %0)
+
+template <class T>
+T Func12(T x) {
+  [[clang::musttail]] return ::Bar(x);
+}
+
+int Func13(int x) {
+  return Func12<int>(x);
+}
+
+// CHECK: %call = musttail call i32 @_Z3Bari(i32 %0)
+
+int Func14(int x) {
+  int vla[x];
+  [[clang::musttail]] return Bar(x);
+}
+
+// CHECK: %call = musttail call i32 @_Z3Bari(i32 %3)
+
+void TrivialDestructorParam(HasTrivialDestructor obj);
+
+void Func14(HasTrivialDestructor obj) {
+  [[clang::musttail]] return TrivialDestructorParam(obj);
+}
+
+// CHECK: musttail call void @_Z22TrivialDestructorParam20HasTrivialDestructor()
+
+struct Struct3 {
+  void ConstMemberFunction(const int *) const;
+  void NonConstMemberFunction(int *i);
+};
+void Struct3::NonConstMemberFunction(int *i) {
+  // The parameters are not identical, but they are compatible.
+  [[clang::musttail]] return ConstMemberFunction(i);
+}
+
+// CHECK: musttail call void @_ZNK7Struct319ConstMemberFunctionEPKi(%struct.Struct3* nonnull align 1 dereferenceable(1) %this1, i32* %0)
+
+struct HasNonTrivialCopyConstructor {
+  HasNonTrivialCopyConstructor(const HasNonTrivialCopyConstructor &);
+};
+HasNonTrivialCopyConstructor ReturnsClassByValue();
+HasNonTrivialCopyConstructor TestNonElidableCopyConstructor() {
+  [[clang::musttail]] return (((ReturnsClassByValue())));
+}
+
+// CHECK: musttail call void @_Z19ReturnsClassByValuev(%struct.HasNonTrivialCopyConstructor* sret(%struct.HasNonTrivialCopyConstructor) align 1 %agg.result)
+
+struct HasNonTrivialCopyConstructor2 {
+  // Copy constructor works even if it has extra default params.
+  HasNonTrivialCopyConstructor2(const HasNonTrivialCopyConstructor &, int DefaultParam = 5);
+};
+HasNonTrivialCopyConstructor2 ReturnsClassByValue2();
+HasNonTrivialCopyConstructor2 TestNonElidableCopyConstructor2() {
+  [[clang::musttail]] return (((ReturnsClassByValue2())));
+}
+
+// CHECK: musttail call void @_Z20ReturnsClassByValue2v()
+
+void NoCalleeDecl(int x) {
+  void (*p)(int) = nullptr;
+  [[clang::musttail]] return p(x);
+}
+
+// CHECK: musttail call void %0(i32 %1)
+
+struct LargeWithCopyConstructor {
+  LargeWithCopyConstructor(const LargeWithCopyConstructor &);
+  char data[32];
+};
+LargeWithCopyConstructor ReturnsLarge();
+LargeWithCopyConstructor TestLargeWithCopyConstructor() {
+  [[clang::musttail]] return ReturnsLarge();
+}
+
+// CHECK: musttail call void @_Z12ReturnsLargev(%struct.LargeWithCopyConstructor* sret(%struct.LargeWithCopyConstructor) align 1 %agg.result)
Index: clang/lib/Sema/SemaStmtAttr.cpp
===================================================================
--- clang/lib/Sema/SemaStmtAttr.cpp
+++ clang/lib/Sema/SemaStmtAttr.cpp
@@ -209,6 +209,14 @@
   return ::new (S.Context) NoMergeAttr(S.Context, A);
 }
 
+/// Creates the semantic attribute for a statement-level 'musttail'.
+static Attr *handleMustTailAttr(Sema &S, Stmt *St, const ParsedAttr &A,
+                                SourceRange Range) {
+  // No checks happen here: the return statement is validated later, when
+  // Sema::BuildAttributedStmt() calls checkAndRewriteMustTailAttr().
+  return ::new (S.Context) MustTailAttr(S.Context, A);
+}
+
 static Attr *handleLikely(Sema &S, Stmt *St, const ParsedAttr &A,
                           SourceRange Range) {
 
@@ -399,6 +407,8 @@
     return handleSuppressAttr(S, St, A, Range);
   case ParsedAttr::AT_NoMerge:
     return handleNoMergeAttr(S, St, A, Range);
+  case ParsedAttr::AT_MustTail:
+    return handleMustTailAttr(S, St, A, Range);
   case ParsedAttr::AT_Likely:
     return handleLikely(S, St, A, Range);
   case ParsedAttr::AT_Unlikely:
Index: clang/lib/Sema/SemaStmt.cpp
===================================================================
--- clang/lib/Sema/SemaStmt.cpp
+++ clang/lib/Sema/SemaStmt.cpp
@@ -10,17 +10,16 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "clang/Sema/Ownership.h"
-#include "clang/Sema/SemaInternal.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/ASTDiagnostic.h"
 #include "clang/AST/ASTLambda.h"
-#include "clang/AST/CharUnits.h"
 #include "clang/AST/CXXInheritance.h"
+#include "clang/AST/CharUnits.h"
 #include "clang/AST/DeclObjC.h"
 #include "clang/AST/EvaluatedExprVisitor.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/ExprObjC.h"
+#include "clang/AST/IgnoreExpr.h"
 #include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/StmtCXX.h"
 #include "clang/AST/StmtObjC.h"
@@ -30,8 +29,10 @@
 #include "clang/Lex/Preprocessor.h"
 #include "clang/Sema/Initialization.h"
 #include "clang/Sema/Lookup.h"
+#include "clang/Sema/Ownership.h"
 #include "clang/Sema/Scope.h"
 #include "clang/Sema/ScopeInfo.h"
+#include "clang/Sema/SemaInternal.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/STLExtras.h"
@@ -558,6 +559,17 @@
 StmtResult Sema::BuildAttributedStmt(SourceLocation AttrsLoc,
                                      ArrayRef<const Attr *> Attrs,
                                      Stmt *SubStmt) {
+  // FIXME: this code should move when a planned refactoring around statement
+  // attributes lands.
+  for (const auto *A : Attrs) {
+    if (A->getKind() == attr::MustTail) {
+      if (!checkAndRewriteMustTailAttr(SubStmt, *A)) {
+        return SubStmt;
+      }
+      setFunctionHasMustTail();
+    }
+  }
+
   return AttributedStmt::Create(Context, AttrsLoc, Attrs, SubStmt);
 }
 
@@ -573,6 +585,219 @@
   return SubStmt;
 }
 
+bool Sema::checkAndRewriteMustTailAttr(Stmt *St, const Attr &MTA) {
+  if (CurContext->isDependentContext())
+    return true; // Nothing is verified yet; wait for template instantiation.
+  if (!checkMustTailAttr(St, MTA))
+    return false;
+
+  // FIXME: Replace Expr::IgnoreImplicitAsWritten(), which currently does not
+  // skip implicit constructors in initialization contexts, with this function.
+  auto IgnoreImplicitAsWritten = [](Expr *Ex) -> Expr * {
+    return IgnoreExprNodes(Ex, IgnoreImplicitAsWrittenSingleStep,
+                           IgnoreImplicitConstructorSingleStep);
+  };
+
+  // 'musttail' is valid here: strip implicit nodes but retain parentheses.
+  ReturnStmt *R = cast<ReturnStmt>(St);
+  R->setRetValue(IgnoreImplicitAsWritten(R->getRetValue()));
+  Expr *Ex = R->getRetValue();
+  while (!isa<CallExpr>(Ex)) {
+    auto *PE = cast<ParenExpr>(Ex);
+    Ex = IgnoreImplicitAsWritten(PE->getSubExpr());
+    PE->setSubExpr(Ex);
+  }
+  return true;
+}
+
+/// Checks the 'musttail' return statement \p St (attribute \p MTA): it must
+/// return a call whose prototype, 'this' type and calling convention are
+/// compatible with the caller's. Emits diagnostics and returns false otherwise.
+bool Sema::checkMustTailAttr(const Stmt *St, const Attr &MTA) {
+  if (CurContext->isDependentContext())
+    // We have to suspend our check until template instantiation time.
+    return true;
+
+  const ReturnStmt *R = cast<const ReturnStmt>(St);
+  const Expr *Ex = R->getRetValue();
+
+  // FIXME: Add Expr::IgnoreParenImplicitAsWritten() with this definition.
+  auto IgnoreParenImplicitAsWritten = [](const Expr *Ex) -> const Expr * {
+    return IgnoreExprNodes(const_cast<Expr *>(Ex), IgnoreParensSingleStep,
+                           IgnoreImplicitAsWrittenSingleStep,
+                           IgnoreImplicitConstructorSingleStep);
+  };
+
+  const CallExpr *CE =
+      dyn_cast_or_null<CallExpr>(IgnoreParenImplicitAsWritten(Ex));
+
+  if (!CE) {
+    Diag(St->getBeginLoc(), diag::err_musttail_needs_call) << &MTA;
+    return false;
+  }
+
+  if (!CE->getCalleeDecl()) {
+    assert(hasUncompilableErrorOccurred() && "expected previous error");
+    return false;
+  }
+
+  if (const ExprWithCleanups *EWC = dyn_cast<ExprWithCleanups>(Ex)) {
+    if (EWC->cleanupsHaveSideEffects()) {
+      Diag(St->getBeginLoc(), diag::err_musttail_needs_trivial_args) << &MTA;
+      return false;
+    }
+  }
+
+  // We need to determine the full function type (including "this" type, if any)
+  // for both caller and callee.
+  struct FuncType {
+    QualType This;
+    const FunctionProtoType *Func;
+  } CallerType, CalleeType;
+
+  const FunctionDecl *CallerDecl = dyn_cast<FunctionDecl>(CurContext);
+
+  auto GetMethodType = [this, St, &MTA](const CXXMethodDecl *CMD,
+                                        FuncType &Type, bool IsCallee) -> bool {
+    if (isa<CXXConstructorDecl>(CMD) || isa<CXXDestructorDecl>(CMD)) {
+      Diag(St->getBeginLoc(), diag::err_musttail_structors_forbidden) << &MTA;
+      Diag(CMD->getBeginLoc(), diag::note_musttail_structors_forbidden)
+          << IsCallee << isa<CXXDestructorDecl>(CMD);
+      return false;
+    }
+    if (!CMD->isStatic())
+      Type.This = CMD->getThisType()->getPointeeType();
+    Type.Func = CMD->getType()->castAs<FunctionProtoType>();
+    return true;
+  };
+
+  if (!CallerDecl || isa<CapturedDecl>(CurContext)) {
+    int ContextType;
+    if (isa<BlockDecl>(CurContext))
+      ContextType = 0;
+    else if (isa<ObjCMethodDecl>(CurContext))
+      ContextType = 1;
+    else
+      ContextType = 2;
+    Diag(St->getBeginLoc(), diag::err_musttail_forbidden_from_this_context)
+        << &MTA << ContextType;
+    return false;
+  } else if (const CXXMethodDecl *CMD = dyn_cast<CXXMethodDecl>(CurContext)) {
+    // Caller is a class/struct method.
+    if (!GetMethodType(CMD, CallerType, false))
+      return false;
+  } else {
+    // Caller is a non-method function.
+    CallerType.Func = dyn_cast<FunctionProtoType>(CallerDecl->getType());
+  }
+
+  const Decl *CalleeDecl = CE->getCalleeDecl();
+  // Use dyn_cast so the null check on VD below is meaningful.
+  const ValueDecl *VD = dyn_cast<ValueDecl>(CalleeDecl);
+
+  if (const CXXMethodDecl *CMD = dyn_cast<CXXMethodDecl>(CalleeDecl)) {
+    // Call is: obj.method(), obj->method(), functor(), etc.
+    if (!GetMethodType(CMD, CalleeType, true))
+      return false;
+  } else if (VD && dyn_cast<MemberPointerType>(VD->getType())) {
+    // Call is: obj->*method_ptr or obj.*method_ptr
+    const MemberPointerType *MPT = VD->getType()->castAs<MemberPointerType>();
+    CalleeType.This = QualType(MPT->getClass(), 0);
+    CalleeType.Func = MPT->getPointeeType()->castAs<FunctionProtoType>();
+  } else {
+    // Non-method function.
+    CalleeType.Func =
+        dyn_cast<FunctionProtoType>(CE->getCallee()
+                                        ->IgnoreParens()
+                                        ->getType()
+                                        ->getPointeeType()
+                                        ->getUnqualifiedDesugaredType());
+  }
+
+  // Both caller and callee must have a prototype (no K&R declarations).
+  if (!CalleeType.Func || !CallerType.Func) {
+    Diag(St->getBeginLoc(), diag::err_musttail_needs_prototype) << &MTA;
+    if (!CalleeType.Func && CE->getDirectCallee())
+      Diag(CE->getDirectCallee()->getBeginLoc(),
+           diag::note_musttail_fix_non_prototype);
+    if (!CallerType.Func)
+      Diag(CallerDecl->getBeginLoc(), diag::note_musttail_fix_non_prototype);
+    return false;
+  }
+
+  SourceLocation CalleeLoc = CE->getDirectCallee()
+                                 ? CE->getDirectCallee()->getBeginLoc()
+                                 : St->getBeginLoc();
+  if (CallerType.Func->getCallConv() != CalleeType.Func->getCallConv()) {
+    Diag(St->getBeginLoc(), diag::err_musttail_callconv_mismatch) << &MTA;
+    Diag(CalleeLoc, diag::note_musttail_callconv_mismatch)
+        << FunctionType::getNameForCallConv(CallerType.Func->getCallConv())
+        << FunctionType::getNameForCallConv(CalleeType.Func->getCallConv());
+    return false;
+  }
+
+  if (CalleeType.Func->isVariadic() || CallerType.Func->isVariadic()) {
+    Diag(St->getBeginLoc(), diag::err_musttail_no_variadic) << &MTA;
+    return false;
+  }
+
+  auto CheckTypesMatch = [this](FuncType CallerType, FuncType CalleeType,
+                                PartialDiagnostic &PD) -> bool {
+    enum {
+      ft_different_class,
+      ft_parameter_arity,
+      ft_parameter_mismatch,
+      ft_return_type,
+    };
+
+    auto DoTypesMatch = [this, &PD](QualType A, QualType B,
+                                    unsigned Select) -> bool {
+      if (A.isNull())
+        A = Context.VoidTy;
+      if (B.isNull())
+        B = Context.VoidTy;
+      if (!Context.hasSimilarType(A, B)) {
+        PD << Select << A << B;
+        return false;
+      }
+      return true;
+    };
+
+    if (!DoTypesMatch(CallerType.Func->getReturnType(),
+                      CalleeType.Func->getReturnType(), ft_return_type) ||
+        !DoTypesMatch(CallerType.This, CalleeType.This, ft_different_class))
+      return false;
+
+    if (CallerType.Func->getNumParams() != CalleeType.Func->getNumParams()) {
+      PD << ft_parameter_arity << CallerType.Func->getNumParams()
+         << CalleeType.Func->getNumParams();
+      return false;
+    }
+
+    ArrayRef<QualType> CalleeParams = CalleeType.Func->getParamTypes();
+    ArrayRef<QualType> CallerParams = CallerType.Func->getParamTypes();
+    size_t N = CallerType.Func->getNumParams();
+    for (size_t I = 0; I < N; I++) {
+      if (!DoTypesMatch(CalleeParams[I], CallerParams[I],
+                        ft_parameter_mismatch)) {
+        PD << static_cast<int>(I) + 1;
+        return false;
+      }
+    }
+
+    return true;
+  };
+
+  PartialDiagnostic PD = PDiag(diag::note_musttail_mismatch);
+  if (!CheckTypesMatch(CallerType, CalleeType, PD)) {
+    Diag(St->getBeginLoc(), diag::err_musttail_mismatch) << &MTA;
+    Diag(CalleeLoc, PD);
+    return false;
+  }
+
+  return true;
+}
+
 namespace {
 class CommaVisitor : public EvaluatedExprVisitor<CommaVisitor> {
   typedef EvaluatedExprVisitor<CommaVisitor> Inherited;
Index: clang/lib/Sema/Sema.cpp
===================================================================
--- clang/lib/Sema/Sema.cpp
+++ clang/lib/Sema/Sema.cpp
@@ -2079,6 +2079,11 @@
     FunctionScopes.back()->setHasIndirectGoto();
 }
 
+void Sema::setFunctionHasMustTail() {
+  if (!FunctionScopes.empty())
+    FunctionScopes.back()->setHasMustTail();
+}
+
 BlockScopeInfo *Sema::getCurBlock() {
   if (FunctionScopes.empty())
     return nullptr;
Index: clang/lib/Sema/JumpDiagnostics.cpp
===================================================================
--- clang/lib/Sema/JumpDiagnostics.cpp
+++ clang/lib/Sema/JumpDiagnostics.cpp
@@ -11,13 +11,14 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "clang/Sema/SemaInternal.h"
 #include "clang/AST/DeclCXX.h"
 #include "clang/AST/Expr.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/StmtCXX.h"
 #include "clang/AST/StmtObjC.h"
 #include "clang/AST/StmtOpenMP.h"
+#include "clang/Basic/SourceLocation.h"
+#include "clang/Sema/SemaInternal.h"
 #include "llvm/ADT/BitVector.h"
 using namespace clang;
 
@@ -29,6 +30,10 @@
 ///    int a[n];
 ///  L:
 ///
+/// We also detect jumps out of protected scopes when it's not possible to do
+/// cleanups properly. Indirect jumps and ASM jumps can't do cleanups because
+/// the target is unknown. Return statements with \c [[clang::musttail]] cannot
+/// handle any cleanups due to the nature of a tail call.
 class JumpScopeChecker {
   Sema &S;
 
@@ -68,6 +73,7 @@
 
   SmallVector<Stmt*, 4> IndirectJumps;
   SmallVector<Stmt*, 4> AsmJumps;
+  SmallVector<AttributedStmt *, 4> MustTailStmts;
   SmallVector<LabelDecl*, 4> IndirectJumpTargets;
   SmallVector<LabelDecl*, 4> AsmJumpTargets;
 public:
@@ -81,6 +87,7 @@
 
   void VerifyJumps();
   void VerifyIndirectOrAsmJumps(bool IsAsmGoto);
+  void VerifyMustTailStmts();
   void NoteJumpIntoScopes(ArrayRef<unsigned> ToScopes);
   void DiagnoseIndirectOrAsmJump(Stmt *IG, unsigned IGScope, LabelDecl *Target,
                                  unsigned TargetScope);
@@ -88,6 +95,7 @@
                  unsigned JumpDiag, unsigned JumpDiagWarning,
                  unsigned JumpDiagCXX98Compat);
   void CheckGotoStmt(GotoStmt *GS);
+  const Attr *GetMustTailAttr(AttributedStmt *AS);
 
   unsigned GetDeepestCommonScope(unsigned A, unsigned B);
 };
@@ -109,6 +117,7 @@
   VerifyJumps();
   VerifyIndirectOrAsmJumps(false);
   VerifyIndirectOrAsmJumps(true);
+  VerifyMustTailStmts();
 }
 
 /// GetDeepestCommonScope - Finds the innermost scope enclosing the
@@ -580,6 +589,15 @@
     LabelAndGotoScopes[S] = ParentScope;
     break;
 
+  case Stmt::AttributedStmtClass: {
+    AttributedStmt *AS = cast<AttributedStmt>(S);
+    if (GetMustTailAttr(AS)) {
+      LabelAndGotoScopes[AS] = ParentScope;
+      MustTailStmts.push_back(AS);
+    }
+    break;
+  }
+
   default:
     if (auto *ED = dyn_cast<OMPExecutableDirective>(S)) {
       if (!ED->isStandaloneDirective()) {
@@ -971,6 +989,24 @@
   }
 }
 
+void JumpScopeChecker::VerifyMustTailStmts() {
+  // Report every enclosing scope that carries a jump-out diagnostic (OutDiag).
+  for (AttributedStmt *AS : MustTailStmts)
+    for (unsigned I = LabelAndGotoScopes[AS]; I; I = Scopes[I].ParentScope) {
+      if (!Scopes[I].OutDiag)
+        continue;
+      S.Diag(AS->getBeginLoc(), diag::err_musttail_scope);
+      S.Diag(Scopes[I].Loc, Scopes[I].OutDiag);
+    }
+}
+
+const Attr *JumpScopeChecker::GetMustTailAttr(AttributedStmt *AS) {
+  for (const Attr *Candidate : AS->getAttrs())
+    if (isa<MustTailAttr>(Candidate))
+      return Candidate; // First musttail attribute found.
+  return nullptr; // AS carries no musttail attribute.
+}
+
 void Sema::DiagnoseInvalidJumps(Stmt *Body) {
   (void)JumpScopeChecker(Body, *this);
 }
Index: clang/lib/CodeGen/EHScopeStack.h
===================================================================
--- clang/lib/CodeGen/EHScopeStack.h
+++ clang/lib/CodeGen/EHScopeStack.h
@@ -150,6 +150,8 @@
     Cleanup(Cleanup &&) {}
     Cleanup() = default;
 
+    virtual bool isRedundantBeforeReturn() { return false; }
+
     /// Generation flags.
     class Flags {
       enum {
Index: clang/lib/CodeGen/CodeGenFunction.h
===================================================================
--- clang/lib/CodeGen/CodeGenFunction.h
+++ clang/lib/CodeGen/CodeGenFunction.h
@@ -517,6 +517,10 @@
   /// True if the current statement has nomerge attribute.
   bool InNoMergeAttributedStmt = false;
 
+  /// The CallExpr within the current statement that the musttail attribute
+  /// applies to; nullptr if there is no 'musttail' on the current statement.
+  const CallExpr *MustTailCall = nullptr;
+
   /// True if the current function should be marked mustprogress.
   bool FnIsMustProgress = false;
 
@@ -565,6 +569,8 @@
   llvm::Instruction *CurrentFuncletPad = nullptr;
 
   class CallLifetimeEnd final : public EHScopeStack::Cleanup {
+    bool isRedundantBeforeReturn() override { return true; }
+
     llvm::Value *Addr;
     llvm::Value *Size;
 
@@ -3909,12 +3915,14 @@
   /// LLVM arguments and the types they were derived from.
   RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee,
                   ReturnValueSlot ReturnValue, const CallArgList &Args,
-                  llvm::CallBase **callOrInvoke, SourceLocation Loc);
+                  llvm::CallBase **callOrInvoke, bool IsMustTail,
+                  SourceLocation Loc);
   RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee,
                   ReturnValueSlot ReturnValue, const CallArgList &Args,
-                  llvm::CallBase **callOrInvoke = nullptr) {
+                  llvm::CallBase **callOrInvoke = nullptr,
+                  bool IsMustTail = false) {
     return EmitCall(CallInfo, Callee, ReturnValue, Args, callOrInvoke,
-                    SourceLocation());
+                    IsMustTail, SourceLocation());
   }
   RValue EmitCall(QualType FnType, const CGCallee &Callee, const CallExpr *E,
                   ReturnValueSlot ReturnValue, llvm::Value *Chain = nullptr);
Index: clang/lib/CodeGen/CGStmt.cpp
===================================================================
--- clang/lib/CodeGen/CGStmt.cpp
+++ clang/lib/CodeGen/CGStmt.cpp
@@ -16,6 +16,8 @@
 #include "CodeGenModule.h"
 #include "TargetInfo.h"
 #include "clang/AST/Attr.h"
+#include "clang/AST/Expr.h"
+#include "clang/AST/Stmt.h"
 #include "clang/AST/StmtVisitor.h"
 #include "clang/Basic/Builtins.h"
 #include "clang/Basic/DiagnosticSema.h"
@@ -646,12 +648,20 @@
 
 void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
   bool nomerge = false;
-  for (const auto *A : S.getAttrs())
+  const CallExpr *musttail = nullptr; // Call marked 'musttail', if any.
+
+  for (const auto *A : S.getAttrs()) {
     if (A->getKind() == attr::NoMerge) {
       nomerge = true;
-      break;
     }
+    if (A->getKind() == attr::MustTail) { // Casts assume `return <call>`.
+      const Stmt *Sub = S.getSubStmt();
+      const ReturnStmt *R = cast<ReturnStmt>(Sub);
+      musttail = cast<CallExpr>(R->getRetValue()->IgnoreParens());
+    }
+  }
   SaveAndRestore<bool> save_nomerge(InNoMergeAttributedStmt, nomerge);
+  SaveAndRestore<const CallExpr *> save_musttail(MustTailCall, musttail);
   EmitStmt(S.getSubStmt(), S.getAttrs());
 }
 
Index: clang/lib/CodeGen/CGExprCXX.cpp
===================================================================
--- clang/lib/CodeGen/CGExprCXX.cpp
+++ clang/lib/CodeGen/CGExprCXX.cpp
@@ -87,6 +87,7 @@
   auto &FnInfo = CGM.getTypes().arrangeCXXMethodCall(
       Args, FPT, CallInfo.ReqArgs, CallInfo.PrefixSize);
   return EmitCall(FnInfo, Callee, ReturnValue, Args, nullptr,
+                  CE && CE == MustTailCall,
                   CE ? CE->getExprLoc() : SourceLocation());
 }
 
@@ -112,7 +113,7 @@
   commonEmitCXXMemberOrOperatorCall(*this, DtorDecl, This, ImplicitParam,
                                     ImplicitParamTy, CE, Args, nullptr);
   return EmitCall(CGM.getTypes().arrangeCXXStructorDeclaration(Dtor), Callee,
-                  ReturnValueSlot(), Args, nullptr,
+                  ReturnValueSlot(), Args, nullptr, CE && CE == MustTailCall,
                   CE ? CE->getExprLoc() : SourceLocation{});
 }
 
@@ -472,7 +473,8 @@
   EmitCallArgs(Args, FPT, E->arguments());
   return EmitCall(CGM.getTypes().arrangeCXXMethodCall(Args, FPT, required,
                                                       /*PrefixSize=*/0),
-                  Callee, ReturnValue, Args, nullptr, E->getExprLoc());
+                  Callee, ReturnValue, Args, nullptr, E == MustTailCall,
+                  E->getExprLoc());
 }
 
 RValue
Index: clang/lib/CodeGen/CGExpr.cpp
===================================================================
--- clang/lib/CodeGen/CGExpr.cpp
+++ clang/lib/CodeGen/CGExpr.cpp
@@ -38,6 +38,7 @@
 #include "llvm/Support/ConvertUTF.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/Path.h"
+#include "llvm/Support/SaveAndRestore.h"
 #include "llvm/Transforms/Utils/SanitizerStats.h"
 
 #include <string>
@@ -5286,7 +5287,7 @@
   }
   llvm::CallBase *CallOrInvoke = nullptr;
   RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &CallOrInvoke,
-                         E->getExprLoc());
+                         E == MustTailCall, E->getExprLoc());
 
   // Generate function declaration DISuprogram in order to be used
   // in debug info about call sites.
Index: clang/lib/CodeGen/CGDecl.cpp
===================================================================
--- clang/lib/CodeGen/CGDecl.cpp
+++ clang/lib/CodeGen/CGDecl.cpp
@@ -550,6 +550,7 @@
   struct CallStackRestore final : EHScopeStack::Cleanup {
     Address Stack;
     CallStackRestore(Address Stack) : Stack(Stack) {}
+    bool isRedundantBeforeReturn() override { return true; }
     void Emit(CodeGenFunction &CGF, Flags flags) override {
       llvm::Value *V = CGF.Builder.CreateLoad(Stack);
       llvm::Function *F = CGF.CGM.getIntrinsic(llvm::Intrinsic::stackrestore);
Index: clang/lib/CodeGen/CGClass.cpp
===================================================================
--- clang/lib/CodeGen/CGClass.cpp
+++ clang/lib/CodeGen/CGClass.cpp
@@ -2182,7 +2182,7 @@
   const CGFunctionInfo &Info = CGM.getTypes().arrangeCXXConstructorCall(
       Args, D, Type, ExtraArgs.Prefix, ExtraArgs.Suffix, PassPrototypeArgs);
   CGCallee Callee = CGCallee::forDirect(CalleePtr, GlobalDecl(D, Type));
-  EmitCall(Info, Callee, ReturnValueSlot(), Args, nullptr, Loc);
+  EmitCall(Info, Callee, ReturnValueSlot(), Args, nullptr, false, Loc);
 
   // Generate vtable assumptions if we're constructing a complete object
   // with a vtable.  We don't do this for base subobjects for two reasons:
Index: clang/lib/CodeGen/CGCall.cpp
===================================================================
--- clang/lib/CodeGen/CGCall.cpp
+++ clang/lib/CodeGen/CGCall.cpp
@@ -4565,7 +4565,7 @@
                                  const CGCallee &Callee,
                                  ReturnValueSlot ReturnValue,
                                  const CallArgList &CallArgs,
-                                 llvm::CallBase **callOrInvoke,
+                                 llvm::CallBase **callOrInvoke, bool IsMustTail,
                                  SourceLocation Loc) {
   // FIXME: We no longer need the types from CallArgs; lift up and simplify.
 
@@ -5258,10 +5258,12 @@
   if (CGM.getLangOpts().ObjCAutoRefCount)
     AddObjCARCExceptionMetadata(CI);
 
-  // Suppress tail calls if requested.
+  // Set tail call kind if necessary.
   if (llvm::CallInst *Call = dyn_cast<llvm::CallInst>(CI)) {
     if (TargetDecl && TargetDecl->hasAttr<NotTailCalledAttr>())
       Call->setTailCallKind(llvm::CallInst::TCK_NoTail);
+    else if (IsMustTail)
+      Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
   }
 
   // Add metadata for calls to MSAllocator functions
@@ -5313,6 +5315,24 @@
     return GetUndefRValue(RetTy);
   }
 
+  // If this is a musttail call, return immediately. We do not branch to the
+  // epilogue in this case.
+  if (IsMustTail) {
+    for (auto it = EHStack.find(CurrentCleanupScopeDepth); it != EHStack.end();
+         ++it) {
+      EHCleanupScope *Cleanup = dyn_cast<EHCleanupScope>(&*it);
+      if (!(Cleanup && Cleanup->getCleanup()->isRedundantBeforeReturn()))
+        CGM.ErrorUnsupported(MustTailCall, "tail call skipping over cleanups");
+    }
+    if (CI->getType()->isVoidTy())
+      Builder.CreateRetVoid();
+    else
+      Builder.CreateRet(CI);
+    Builder.ClearInsertionPoint();
+    EnsureInsertPoint();
+    return GetUndefRValue(RetTy);
+  }
+
   // Perform the swifterror writeback.
   if (swiftErrorTemp.isValid()) {
     llvm::Value *errorResult = Builder.CreateLoad(swiftErrorTemp);
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -1848,6 +1848,7 @@
   void setFunctionHasBranchIntoScope();
   void setFunctionHasBranchProtectedScope();
   void setFunctionHasIndirectGoto();
+  void setFunctionHasMustTail();
 
   void PushCompoundScope(bool IsStmtExpr);
   void PopCompoundScope();
@@ -11347,6 +11348,18 @@
   /// function, issuing a diagnostic if not.
   void checkVariadicArgument(const Expr *E, VariadicCallType CT);
 
+  /// Check whether the given statement can have musttail applied to it,
+  /// issuing a diagnostic and returning false if not. In the success case,
+  /// the statement is rewritten to remove implicit nodes from the return
+  /// value.
+  bool checkAndRewriteMustTailAttr(Stmt *St, const Attr &MTA);
+
+private:
+  /// Check whether the given statement can have musttail applied to it,
+  /// issuing a diagnostic and returning false if not.
+  bool checkMustTailAttr(const Stmt *St, const Attr &MTA);
+
+public:
   /// Check to see if a given expression could have '.c_str()' called on it.
   bool hasCStrMethod(const Expr *E);
 
Index: clang/include/clang/Sema/ScopeInfo.h
===================================================================
--- clang/include/clang/Sema/ScopeInfo.h
+++ clang/include/clang/Sema/ScopeInfo.h
@@ -118,6 +118,10 @@
   /// Whether this function contains any indirect gotos.
   bool HasIndirectGoto : 1;
 
+  /// Whether this function contains any statement marked with
+  /// \c [[clang::musttail]].
+  bool HasMustTail : 1;
+
   /// Whether a statement was dropped because it was invalid.
   bool HasDroppedStmt : 1;
 
@@ -370,14 +374,13 @@
 public:
   FunctionScopeInfo(DiagnosticsEngine &Diag)
       : Kind(SK_Function), HasBranchProtectedScope(false),
-        HasBranchIntoScope(false), HasIndirectGoto(false),
+        HasBranchIntoScope(false), HasIndirectGoto(false), HasMustTail(false),
         HasDroppedStmt(false), HasOMPDeclareReductionCombiner(false),
         HasFallthroughStmt(false), UsesFPIntrin(false),
-        HasPotentialAvailabilityViolations(false),
-        ObjCShouldCallSuper(false), ObjCIsDesignatedInit(false),
-        ObjCWarnForNoDesignatedInitChain(false), ObjCIsSecondaryInit(false),
-        ObjCWarnForNoInitDelegation(false), NeedsCoroutineSuspends(true),
-        ErrorTrap(Diag) {}
+        HasPotentialAvailabilityViolations(false), ObjCShouldCallSuper(false),
+        ObjCIsDesignatedInit(false), ObjCWarnForNoDesignatedInitChain(false),
+        ObjCIsSecondaryInit(false), ObjCWarnForNoInitDelegation(false),
+        NeedsCoroutineSuspends(true), ErrorTrap(Diag) {}
 
   virtual ~FunctionScopeInfo();
 
@@ -423,6 +426,8 @@
     HasIndirectGoto = true;
   }
 
+  void setHasMustTail() { HasMustTail = true; }
+
   void setHasDroppedStmt() {
     HasDroppedStmt = true;
   }
@@ -450,9 +455,8 @@
   }
 
   bool NeedsScopeChecking() const {
-    return !HasDroppedStmt &&
-        (HasIndirectGoto ||
-          (HasBranchProtectedScope && HasBranchIntoScope));
+    return !HasDroppedStmt && (HasIndirectGoto || HasMustTail ||
+                               (HasBranchProtectedScope && HasBranchIntoScope));
   }
 
   // Add a block introduced in this function.
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -2823,6 +2823,45 @@
   "%0 attribute is ignored because there exists no call expression inside the "
   "statement">,
   InGroup<IgnoredAttributes>;
+def err_musttail_needs_trivial_args : Error<
+  "%0 attribute requires that the return value, all parameters, and any "
+  "temporaries created by the expression are trivially destructible">;
+def err_musttail_needs_call : Error<
+  "%0 attribute requires that the return value is the result of a function call"
+  >;
+def err_musttail_needs_prototype : Error<
+  "%0 attribute requires that both caller and callee functions have a "
+  "prototype">;
+def note_musttail_fix_non_prototype : Note<
+  "add 'void' to the parameter list to turn an old-style K&R function "
+  "declaration into a prototype">;
+def err_musttail_forbidden_from_this_context : Error<
+  "%0 attribute cannot be used from "
+  "%select{a block|an Objective-C function|this context}1">;
+def err_musttail_structors_forbidden : Error<
+  "%0 attribute cannot be used when caller or callee is a constructor or "
+  "destructor, as these can have unusual calling conventions">;
+def note_musttail_structors_forbidden : Note<"%select{caller|callee}0 is a "
+  "%select{constructor|destructor}1">;
+def err_musttail_mismatch : Error<
+  "%0 attribute requires that caller and callee have compatible function "
+  "signatures">;
+def note_musttail_mismatch : Note<
+    "target function "
+    "%select{is a member of different class%diff{ (expected $ but has $)|}1,2"
+    "|has different number of parameters (expected %1 but has %2)"
+    "|has type mismatch at %ordinal3 parameter"
+    "%diff{ (expected $ but has $)|}1,2"
+    "|has different return type%diff{ ($ expected but has $)|}1,2}0">;
+def err_musttail_callconv_mismatch : Error<
+  "%0 attribute requires that caller and callee use the same calling convention"
+  >;
+def note_musttail_callconv_mismatch : Note<
+  "target function has calling convention %1 (expected %0)">;
+def err_musttail_scope : Error<
+  "cannot perform a tail call from this return statement">;
+def err_musttail_no_variadic : Error<
+  "%0 attribute may not be used with variadic functions">;
 def err_nsobject_attribute : Error<
   "'NSObject' attribute is for pointer types only">;
 def err_attributes_are_not_compatible : Error<
Index: clang/include/clang/Basic/AttrDocs.td
===================================================================
--- clang/include/clang/Basic/AttrDocs.td
+++ clang/include/clang/Basic/AttrDocs.td
@@ -443,6 +443,32 @@
   }];
 }
 
+def MustTailDocs : Documentation {
+  let Category = DocCatStmt;
+  let Content = [{
+If a ``return`` statement is marked ``musttail``, this indicates that the
+compiler must generate a tail call for the program to be correct, even when
+optimizations are disabled. This guarantees that the call will not cause
+unbounded stack growth if it is part of a recursive cycle in the call graph.
+
+If the callee is a virtual function that is implemented by a thunk, there is
+no guarantee in general that the thunk tail-calls the implementation of the
+virtual function, so such a call in a recursive cycle can still result in
+unbounded stack growth.
+
+``clang::musttail`` can only be applied to a ``return`` statement whose value
+is the result of a function call (even functions returning void must use
+``return``, although no value is returned). The target function must have the
+same number of arguments as the caller. The types of the return value and all
+arguments must be similar according to C++ rules (differing only in cv
+qualifiers or array size), including the implicit "this" argument, if any.
+Any variables in scope, including all arguments to the function and the
+return value, must be trivially destructible. The calling convention of the
+caller and callee must match, and they must not be variadic functions or have
+old-style K&R C function declarations.
+  }];
+}
+
 def AssertCapabilityDocs : Documentation {
   let Category = DocCatFunction;
   let Heading = "assert_capability, assert_shared_capability";
Index: clang/include/clang/Basic/Attr.td
===================================================================
--- clang/include/clang/Basic/Attr.td
+++ clang/include/clang/Basic/Attr.td
@@ -1370,6 +1370,12 @@
   let SimpleHandler = 1;
 }
 
+def MustTail : StmtAttr {
+  let Spellings = [Clang<"musttail">];
+  let Documentation = [MustTailDocs];
+  let Subjects = SubjectList<[ReturnStmt], ErrorDiag, "return statements">;
+}
+
 def FastCall : DeclOrTypeAttr {
   let Spellings = [GCC<"fastcall">, Keyword<"__fastcall">,
                    Keyword<"_fastcall">];
Index: clang/include/clang/AST/IgnoreExpr.h
===================================================================
--- clang/include/clang/AST/IgnoreExpr.h
+++ clang/include/clang/AST/IgnoreExpr.h
@@ -121,6 +121,17 @@
   return E;
 }
 
+inline Expr *IgnoreImplicitConstructorSingleStep(Expr *E) {
+  auto *Ctor = dyn_cast<CXXConstructExpr>(E);
+  if (!Ctor || isa<CXXTemporaryObjectExpr>(Ctor) || Ctor->getNumArgs() == 0)
+    return E;
+  if (Ctor->getNumArgs() > 1 && !Ctor->getArg(1)->isDefaultArgument())
+    return E; // Second argument is not a default argument.
+  if (Ctor->getArg(0)->isDefaultArgument() || Ctor->isListInitialization())
+    return E;
+  return Ctor->getArg(0);
+}
+
 inline Expr *IgnoreImplicitAsWrittenSingleStep(Expr *E) {
   if (auto *ICE = dyn_cast<ImplicitCastExpr>(E))
     return ICE->getSubExprAsWritten();
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to