avl updated this revision to Diff 272378.
avl added a comment.

1. deleted code doing more strict tailcall marking.
2. left removal of "AllCallsAreTailCalls".
3. added check for non-capturing calls while tracking alloca.
4. re-titled the patch.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D82085/new/

https://reviews.llvm.org/D82085

Files:
  llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
  llvm/test/Transforms/TailCallElim/tre-noncapturing-alloca-calls.ll

Index: llvm/test/Transforms/TailCallElim/tre-noncapturing-alloca-calls.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/TailCallElim/tre-noncapturing-alloca-calls.ll
@@ -0,0 +1,356 @@
+; RUN: opt < %s -tailcallelim -verify-dom-info -S | FileCheck %s
+
+; IR for this test was generated from the following C++ source:
+;
+;int count;
+;__attribute__((noinline)) void globalIncrement(const int* param) { count += *param; }
+;
+;void test(int recurseCount)
+;{
+;    if (recurseCount == 0) return;
+;    int temp = 10;
+;    globalIncrement(&temp);
+;    test(recurseCount - 1);
+;}
+;
+;struct Counter
+;{
+;    int count = 0;
+;
+;    __attribute__((noinline)) void increment(const int param) { count += param;
+;}
+;    virtual __attribute__((noinline)) void increment(const int* param) { count += *param; }
+;
+;    __attribute__((noinline)) int getCount () const { return count; }
+;};
+;
+;void test2(int recurseCount)
+;{
+;    if (recurseCount == 0) return;
+;    Counter counter;
+;    counter.increment(0);
+;    test2(recurseCount - 1);
+;}
+;
+;void test3(int recurseCount)
+;{
+;    if (recurseCount == 0) return;
+;    Counter counter;
+;    counter.increment(counter.getCount());
+;    test3(recurseCount - 1);
+;}
+;
+;void test4(int recurseCount)
+;{
+;    if (recurseCount == 0) return;
+;    int temp = 10;
+;    Counter counter;
+;    counter.increment(&temp);
+;    test4(recurseCount - 1);
+;}
+;
+;struct Counter2 : public Counter
+;{
+;    virtual __attribute__((noinline)) void increment(const int* param) {
+;            ptr = param;
+;            count += *param;
+;    }
+;
+;    const int* ptr;
+;};
+;
+;void test5(int recurseCount)
+;{
+;    if (recurseCount == 0) return;
+;    int temp = 10;
+;    Counter2 counter;
+;    counter.increment(&temp);
+;    test5(recurseCount - 1);
+;}
+;
+
+%struct.Counter = type <{ i32 (...)**, i32, [4 x i8] }>
+%struct.Counter2 = type { %struct.Counter.base, i32* }
+%struct.Counter.base = type <{ i32 (...)**, i32 }>
+
+$_ZN7Counter9incrementEi = comdat any
+
+$_ZNK7Counter8getCountEv = comdat any
+
+$_ZN7Counter9incrementEPKi = comdat any
+
+$_ZN8Counter29incrementEPKi = comdat any
+
+$_ZTV7Counter = comdat any
+
+$_ZTS7Counter = comdat any
+
+$_ZTI7Counter = comdat any
+
+$_ZTV8Counter2 = comdat any
+
+$_ZTS8Counter2 = comdat any
+
+$_ZTI8Counter2 = comdat any
+
+@count = dso_local local_unnamed_addr global i32 0, align 4
+
+@_ZTV7Counter = linkonce_odr dso_local unnamed_addr constant { [3 x i8*] } { [3 x i8*] [i8* null, i8* bitcast ({ i8*, i8* }* @_ZTI7Counter to i8*), i8* bitcast (void (%struct.Counter*, i32*)* @_ZN7Counter9incrementEPKi to i8*)] }, comdat, align 8
+@_ZTVN10__cxxabiv117__class_type_infoE = external dso_local global i8*
+@_ZTS7Counter = linkonce_odr dso_local constant [9 x i8] c"7Counter\00", comdat, align 1
+@_ZTI7Counter = linkonce_odr dso_local constant { i8*, i8* } { i8* bitcast (i8** getelementptr inbounds (i8*, i8** @_ZTVN10__cxxabiv117__class_type_infoE, i64 2) to i8*), i8* getelementptr inbounds ([9 x i8], [9 x i8]* @_ZTS7Counter, i32 0, i32 0) }, comdat, align 8
+@_ZTV8Counter2 = linkonce_odr dso_local unnamed_addr constant { [3 x i8*] } { [3 x i8*] [i8* null, i8* bitcast ({ i8*, i8*, i8* }* @_ZTI8Counter2 to i8*), i8* bitcast (void (%struct.Counter2*, i32*)* @_ZN8Counter29incrementEPKi to i8*)] }, comdat, align 8
+@_ZTVN10__cxxabiv120__si_class_type_infoE = external dso_local global i8*
+@_ZTS8Counter2 = linkonce_odr dso_local constant [10 x i8] c"8Counter2\00", comdat, align 1
+@_ZTI8Counter2 = linkonce_odr dso_local constant { i8*, i8*, i8* } { i8* bitcast (i8** getelementptr inbounds (i8*, i8** @_ZTVN10__cxxabiv120__si_class_type_infoE, i64 2) to i8*), i8* getelementptr inbounds ([10 x i8], [10 x i8]* @_ZTS8Counter2, i32 0, i32 0), i8* bitcast ({ i8*, i8* }* @_ZTI7Counter to i8*) }, comdat, align 8
+
+
+; Function Attrs: nofree noinline norecurse nounwind uwtable
+define dso_local void @_Z15globalIncrementPKi(i32* nocapture readonly %param) local_unnamed_addr #0 {
+entry:
+  %0 = load i32, i32* %param, align 4
+  %1 = load i32, i32* @count, align 4
+  %add = add nsw i32 %1, %0
+  store i32 %add, i32* @count, align 4
+  ret void
+}
+
+; Test that TRE can be done for a recursive tail routine containing a
+; call to a function that receives a pointer to the local stack.
+
+; CHECK: void @_Z4testi
+; CHECK: br label %tailrecurse
+; CHECK: tailrecurse:
+; CHECK-NOT: call void @_Z4testi
+; CHECK: br label %tailrecurse
+; CHECK-NOT: call void @_Z4testi
+; CHECK: ret
+
+; Function Attrs: nounwind uwtable
+define dso_local void @_Z4testi(i32 %recurseCount) local_unnamed_addr #1 {
+entry:
+  %temp = alloca i32, align 4
+  %cmp = icmp eq i32 %recurseCount, 0
+  br i1 %cmp, label %return, label %if.end
+
+if.end:                                           ; preds = %entry
+  %0 = bitcast i32* %temp to i8*
+  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %0) #6
+  store i32 10, i32* %temp, align 4
+  call void @_Z15globalIncrementPKi(i32* nonnull %temp)
+  %sub = add nsw i32 %recurseCount, -1
+  call void @_Z4testi(i32 %sub)
+  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %0) #6
+  br label %return
+
+return:                                           ; preds = %entry, %if.end
+  ret void
+}
+
+; Function Attrs: noinline nounwind uwtable
+define linkonce_odr dso_local void @_ZN7Counter9incrementEi(%struct.Counter* %this, i32 %param) local_unnamed_addr #3 comdat align 2 {
+entry:
+  %count = getelementptr inbounds %struct.Counter, %struct.Counter* %this, i64 0, i32 1
+  %0 = load i32, i32* %count, align 8
+  %add = add nsw i32 %0, %param
+  store i32 %add, i32* %count, align 8
+  ret void
+}
+
+; Test that TRE can be done for a recursive tail routine containing a
+; class object (calls to methods of such a class receive a pointer
+; to the local stack ("this")).
+
+; CHECK: void @_Z5test2i
+; CHECK: br label %tailrecurse
+; CHECK: tailrecurse:
+; CHECK-NOT: call void @_Z5test2i
+; CHECK: br label %tailrecurse
+; CHECK-NOT: call void @_Z5test2i
+; CHECK: ret
+
+
+; Function Attrs: nounwind uwtable
+define dso_local void @_Z5test2i(i32 %recurseCount) local_unnamed_addr #1 {
+entry:
+  %counter = alloca %struct.Counter, align 8
+  %cmp = icmp eq i32 %recurseCount, 0
+  br i1 %cmp, label %return, label %if.end
+
+if.end:                                           ; preds = %entry
+  %0 = bitcast %struct.Counter* %counter to i8*
+  call void @llvm.lifetime.start.p0i8(i64 16, i8* nonnull %0) #6
+  %1 = getelementptr inbounds %struct.Counter, %struct.Counter* %counter, i64 0, i32 0
+  store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV7Counter, i64 0, inrange i32 0, i64 2) to i32 (...)**), i32 (...)*** %1, align 8
+  %count.i = getelementptr inbounds %struct.Counter, %struct.Counter* %counter, i64 0, i32 1
+  store i32 0, i32* %count.i, align 8
+  call void @_ZN7Counter9incrementEi(%struct.Counter* nonnull %counter, i32 0)
+  %sub = add nsw i32 %recurseCount, -1
+  call void @_Z5test2i(i32 %sub)
+  call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %0) #6
+  br label %return
+
+return:                                           ; preds = %entry, %if.end
+  ret void
+}
+
+; Function Attrs: noinline nounwind uwtable
+define linkonce_odr dso_local i32 @_ZNK7Counter8getCountEv(%struct.Counter* %this) local_unnamed_addr #3 comdat align 2 {
+entry:
+  %count = getelementptr inbounds %struct.Counter, %struct.Counter* %this, i64 0, i32 1
+  %0 = load i32, i32* %count, align 8
+  ret i32 %0
+}
+
+; CHECK: void @_Z5test3i
+; CHECK: br label %tailrecurse
+; CHECK: tailrecurse:
+; CHECK-NOT: call void @_Z5test3i
+; CHECK: br label %tailrecurse
+; CHECK-NOT: call void @_Z5test3i
+; CHECK: ret
+
+
+; Function Attrs: nounwind uwtable
+define dso_local void @_Z5test3i(i32 %recurseCount) local_unnamed_addr #1 {
+entry:
+  %counter = alloca %struct.Counter, align 8
+  %cmp = icmp eq i32 %recurseCount, 0
+  br i1 %cmp, label %return, label %if.end
+
+if.end:                                           ; preds = %entry
+  %0 = bitcast %struct.Counter* %counter to i8*
+  call void @llvm.lifetime.start.p0i8(i64 16, i8* nonnull %0) #6
+  %1 = getelementptr inbounds %struct.Counter, %struct.Counter* %counter, i64 0, i32 0
+  store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV7Counter, i64 0, inrange i32 0, i64 2) to i32 (...)**), i32 (...)*** %1, align 8
+  %count.i = getelementptr inbounds %struct.Counter, %struct.Counter* %counter, i64 0, i32 1
+  store i32 0, i32* %count.i, align 8
+  %call = call i32 @_ZNK7Counter8getCountEv(%struct.Counter* nonnull %counter)
+  call void @_ZN7Counter9incrementEi(%struct.Counter* nonnull %counter, i32 %call)
+  %sub = add nsw i32 %recurseCount, -1
+  call void @_Z5test3i(i32 %sub)
+  call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %0) #6
+  br label %return
+
+return:                                           ; preds = %entry, %if.end
+  ret void
+}
+
+
+; Function Attrs: noinline nounwind uwtable
+define linkonce_odr dso_local void @_ZN7Counter9incrementEPKi(%struct.Counter* %this, i32* %param) unnamed_addr #3 comdat align 2 {
+entry:
+  %0 = load i32, i32* %param, align 4
+  %count = getelementptr inbounds %struct.Counter, %struct.Counter* %this, i64 0, i32 1
+  %1 = load i32, i32* %count, align 8
+  %add = add nsw i32 %1, %0
+  store i32 %add, i32* %count, align 8
+  ret void
+}
+
+; Test that TRE can be done for a recursive tail routine containing a
+; class object (calls to methods of such a class receive a pointer
+; to the local stack ("this")) and calling a method of that object
+; that receives a pointer to the local stack.
+
+; CHECK: void @_Z5test4i
+; CHECK: br label %tailrecurse
+; CHECK: tailrecurse:
+; CHECK-NOT: call void @_Z5test4i
+; CHECK: br label %tailrecurse
+; CHECK-NOT: call void @_Z5test4i
+; CHECK: ret
+
+
+; Function Attrs: nounwind uwtable
+define dso_local void @_Z5test4i(i32 %recurseCount) local_unnamed_addr #1 {
+entry:
+  %temp = alloca i32, align 4
+  %counter = alloca %struct.Counter, align 8
+  %cmp = icmp eq i32 %recurseCount, 0
+  br i1 %cmp, label %return, label %if.end
+
+if.end:                                           ; preds = %entry
+  %0 = bitcast i32* %temp to i8*
+  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %0) #6
+  store i32 10, i32* %temp, align 4
+  %1 = bitcast %struct.Counter* %counter to i8*
+  call void @llvm.lifetime.start.p0i8(i64 16, i8* nonnull %1) #6
+  %2 = getelementptr inbounds %struct.Counter, %struct.Counter* %counter, i64 0, i32 0
+  store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV7Counter, i64 0, inrange i32 0, i64 2) to i32 (...)**), i32 (...)*** %2, align 8
+  %count.i = getelementptr inbounds %struct.Counter, %struct.Counter* %counter, i64 0, i32 1
+  store i32 0, i32* %count.i, align 8
+  call void @_ZN7Counter9incrementEPKi(%struct.Counter* nonnull %counter, i32* nonnull %temp)
+  %sub = add nsw i32 %recurseCount, -1
+  call void @_Z5test4i(i32 %sub)
+  call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %1) #6
+  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %0) #6
+  br label %return
+
+return:                                           ; preds = %entry, %if.end
+  ret void
+}
+
+; Function Attrs: noinline nounwind uwtable
+define linkonce_odr dso_local void @_ZN8Counter29incrementEPKi(%struct.Counter2* %this, i32* %param) unnamed_addr #0 comdat align 2 {
+entry:
+  %ptr = getelementptr inbounds %struct.Counter2, %struct.Counter2* %this, i32 0, i32 1
+  store i32* %param, i32** %ptr, align 8
+  %0 = load i32, i32* %param, align 4
+  %1 = bitcast %struct.Counter2* %this to %struct.Counter*
+  %count = getelementptr inbounds %struct.Counter, %struct.Counter* %1, i32 0, i32 1
+  %2 = load i32, i32* %count, align 8
+  %add = add nsw i32 %2, %0
+  store i32 %add, i32* %count, align 8
+  ret void
+}
+
+; Test that TRE can NOT be done for a recursive tail routine calling a
+; method of an object that receives a pointer to the local stack and captures that pointer.
+
+; CHECK: void @_Z5test5i
+; CHECK-NOT: tailrecurse
+; CHECK: call void @_Z5test5i
+; CHECK: return
+
+; Function Attrs: nounwind uwtable
+define dso_local void @_Z5test5i(i32 %recurseCount) local_unnamed_addr #1 {
+entry:
+  %temp = alloca i32, align 4
+  %counter = alloca %struct.Counter2, align 8
+  %cmp = icmp eq i32 %recurseCount, 0
+  br i1 %cmp, label %return, label %if.end
+
+if.end:                                           ; preds = %entry
+  %0 = bitcast i32* %temp to i8*
+  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %0) #5
+  store i32 10, i32* %temp, align 4
+  %1 = bitcast %struct.Counter2* %counter to i8*
+  call void @llvm.lifetime.start.p0i8(i64 24, i8* nonnull %1) #5
+  %2 = getelementptr inbounds %struct.Counter2, %struct.Counter2* %counter, i64 0, i32 0, i32 0
+  %3 = getelementptr inbounds %struct.Counter2, %struct.Counter2* %counter, i64 0, i32 0, i32 1
+  store i32 0, i32* %3, align 8
+  store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV8Counter2, i64 0, inrange i32 0, i64 2) to i32 (...)**), i32 (...)*** %2, align 8
+  call void @_ZN8Counter29incrementEPKi(%struct.Counter2* nonnull %counter, i32* nonnull %temp)
+  %sub = add nsw i32 %recurseCount, -1
+  call void @_Z5test5i(i32 %sub)
+  call void @llvm.lifetime.end.p0i8(i64 24, i8* nonnull %1) #5
+  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %0) #5
+  br label %return
+
+return:                                           ; preds = %entry, %if.end
+  ret void
+}
+
+
+; Function Attrs: argmemonly nounwind willreturn
+declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #2
+
+; Function Attrs: argmemonly nounwind willreturn
+declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #2
+
+attributes #0 = { nofree noinline norecurse nounwind uwtable }
+attributes #1 = { nounwind uwtable }
+attributes #2 = { argmemonly nounwind willreturn }
+attributes #3 = { noinline nounwind uwtable }
+attributes #4 = { nounwind }
Index: llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
+++ llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
@@ -89,16 +89,6 @@
 STATISTIC(NumRetDuped,   "Number of return duplicated");
 STATISTIC(NumAccumAdded, "Number of accumulators introduced");
 
-/// Scan the specified function for alloca instructions.
-/// If it contains any dynamic allocas, returns false.
-static bool canTRE(Function &F) {
-  // Because of PR962, we don't TRE dynamic allocas.
-  return llvm::all_of(instructions(F), [](Instruction &I) {
-    auto *AI = dyn_cast<AllocaInst>(&I);
-    return !AI || AI->isStaticAlloca();
-  });
-}
-
 namespace {
 struct AllocaDerivedValueTracker {
   // Start at a root value and walk its use-def chain to mark calls that use the
@@ -132,8 +122,25 @@
         // beyond the lifetime of the current frame.
         if (CB.isArgOperand(U) && CB.isByValArgument(CB.getArgOperandNo(U)))
           continue;
-        bool IsNocapture =
-            CB.isDataOperand(U) && CB.doesNotCapture(CB.getDataOperandNo(U));
+        bool IsNocapture = false;
+
+        if (CB.isDataOperand(U)) {
+          if (CB.doesNotCapture(CB.getDataOperandNo(U)))
+            IsNocapture = true;
+          else if (Function *CalledFunction = CB.getCalledFunction()) {
+            if (CalledFunction->getBasicBlockList().size() > 0 &&
+                CB.getDataOperandNo(U) < CalledFunction->arg_size()) {
+              if (Argument *Arg =
+                      CalledFunction->getArg(CB.getDataOperandNo(U))) {
+                if (Arg->getType()->isPointerTy())
+                  IsNocapture =
+                      !PointerMayBeCaptured(Arg, /* ReturnCaptures=*/true,
+                                            /* StoreCaptures= */ true);
+              }
+            }
+          }
+        }
+
         callUsesLocalStack(CB, IsNocapture);
         if (IsNocapture) {
           // If the alloca-derived argument is passed in as nocapture, then it
@@ -185,11 +192,9 @@
 };
 }
 
-static bool markTails(Function &F, bool &AllCallsAreTailCalls,
-                      OptimizationRemarkEmitter *ORE) {
+static bool markTails(Function &F, OptimizationRemarkEmitter *ORE) {
   if (F.callsFunctionThatReturnsTwice())
     return false;
-  AllCallsAreTailCalls = true;
 
   // The local stack holds all alloca instructions and all byval arguments.
   AllocaDerivedValueTracker Tracker;
@@ -272,11 +277,8 @@
         }
       }
 
-      if (!IsNoTail && Escaped == UNESCAPED && !Tracker.AllocaUsers.count(CI)) {
+      if (!IsNoTail && Escaped == UNESCAPED && !Tracker.AllocaUsers.count(CI))
         DeferredTails.push_back(CI);
-      } else {
-        AllCallsAreTailCalls = false;
-      }
     }
 
     for (auto *SuccBB : make_range(succ_begin(BB), succ_end(BB))) {
@@ -313,8 +315,6 @@
       LLVM_DEBUG(dbgs() << "Marked as tail call candidate: " << *CI << "\n");
       CI->setTailCall();
       Modified = true;
-    } else {
-      AllCallsAreTailCalls = false;
     }
   }
 
@@ -326,6 +326,14 @@
 /// instructions between the call and this instruction are movable.
 ///
 static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) {
+  if (isa<DbgInfoIntrinsic>(I))
+    return true;
+
+  if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I))
+    if (II->getIntrinsicID() == Intrinsic::lifetime_end ||
+        II->getIntrinsicID() == Intrinsic::assume)
+      return true;
+
   // FIXME: We can move load/store/call/free instructions above the call if the
   // call does not mod/ref the memory location being processed.
   if (I->mayHaveSideEffects())  // This also handles volatile loads.
@@ -392,7 +400,6 @@
   // createTailRecurseLoopHeader the first time we find a call we can eliminate.
   BasicBlock *HeaderBB = nullptr;
   SmallVector<PHINode *, 8> ArgumentPHIs;
-  bool RemovableCallsMustBeMarkedTail = false;
 
   // PHI node to store our return value.
   PHINode *RetPN = nullptr;
@@ -419,8 +426,7 @@
                           DomTreeUpdater &DTU)
       : F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU) {}
 
-  CallInst *findTRECandidate(Instruction *TI,
-                             bool CannotTailCallElimCallsMarkedTail);
+  CallInst *findTRECandidate(Instruction *TI);
 
   void createTailRecurseLoopHeader(CallInst *CI);
 
@@ -428,14 +434,14 @@
 
   bool eliminateCall(CallInst *CI);
 
-  bool foldReturnAndProcessPred(ReturnInst *Ret,
-                                bool CannotTailCallElimCallsMarkedTail);
+  bool foldReturnAndProcessPred(ReturnInst *Ret);
 
-  bool processReturningBlock(ReturnInst *Ret,
-                             bool CannotTailCallElimCallsMarkedTail);
+  bool processReturningBlock(ReturnInst *Ret);
 
   void cleanupAndFinalize();
 
+  bool canTRE(Function &F);
+
 public:
   static bool eliminate(Function &F, const TargetTransformInfo *TTI,
                         AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
@@ -443,8 +449,7 @@
 };
 } // namespace
 
-CallInst *TailRecursionEliminator::findTRECandidate(
-    Instruction *TI, bool CannotTailCallElimCallsMarkedTail) {
+CallInst *TailRecursionEliminator::findTRECandidate(Instruction *TI) {
   BasicBlock *BB = TI->getParent();
 
   if (&BB->front() == TI) // Make sure there is something before the terminator.
@@ -464,11 +469,6 @@
     --BBI;
   }
 
-  // If this call is marked as a tail call, and if there are dynamic allocas in
-  // the function, we cannot perform this optimization.
-  if (CI->isTailCall() && CannotTailCallElimCallsMarkedTail)
-    return nullptr;
-
   // As a special case, detect code like this:
   //   double fabs(double f) { return __builtin_fabs(f); } // a 'fabs' call
   // and disable this xform in this case, because the code generator will
@@ -498,26 +498,13 @@
   BranchInst *BI = BranchInst::Create(HeaderBB, NewEntry);
   BI->setDebugLoc(CI->getDebugLoc());
 
-  // If this function has self recursive calls in the tail position where some
-  // are marked tail and some are not, only transform one flavor or another.
-  // We have to choose whether we move allocas in the entry block to the new
-  // entry block or not, so we can't make a good choice for both. We make this
-  // decision here based on whether the first call we found to remove is
-  // marked tail.
-  // NOTE: We could do slightly better here in the case that the function has
-  // no entry block allocas.
-  RemovableCallsMustBeMarkedTail = CI->isTailCall();
-
-  // If this tail call is marked 'tail' and if there are any allocas in the
-  // entry block, move them up to the new entry block.
-  if (RemovableCallsMustBeMarkedTail)
-    // Move all fixed sized allocas from HeaderBB to NewEntry.
-    for (BasicBlock::iterator OEBI = HeaderBB->begin(), E = HeaderBB->end(),
-                              NEBI = NewEntry->begin();
-         OEBI != E;)
-      if (AllocaInst *AI = dyn_cast<AllocaInst>(OEBI++))
-        if (isa<ConstantInt>(AI->getArraySize()))
-          AI->moveBefore(&*NEBI);
+  // Move all fixed sized allocas from HeaderBB to NewEntry.
+  for (BasicBlock::iterator OEBI = HeaderBB->begin(), E = HeaderBB->end(),
+                            NEBI = NewEntry->begin();
+       OEBI != E;)
+    if (AllocaInst *AI = dyn_cast<AllocaInst>(OEBI++))
+      if (isa<ConstantInt>(AI->getArraySize()))
+        AI->moveBefore(&*NEBI);
 
   // Now that we have created a new block, which jumps to the entry
   // block, insert a PHI node for each argument of the function.
@@ -620,9 +607,6 @@
   if (!HeaderBB)
     createTailRecurseLoopHeader(CI);
 
-  if (RemovableCallsMustBeMarkedTail && !CI->isTailCall())
-    return false;
-
   // Ok, now that we know we have a pseudo-entry block WITH all of the
   // required PHI nodes, add entries into the PHI node for the actual
   // parameters passed into the tail-recursive call.
@@ -672,8 +656,7 @@
   return true;
 }
 
-bool TailRecursionEliminator::foldReturnAndProcessPred(
-    ReturnInst *Ret, bool CannotTailCallElimCallsMarkedTail) {
+bool TailRecursionEliminator::foldReturnAndProcessPred(ReturnInst *Ret) {
   BasicBlock *BB = Ret->getParent();
 
   bool Change = false;
@@ -698,8 +681,7 @@
   while (!UncondBranchPreds.empty()) {
     BranchInst *BI = UncondBranchPreds.pop_back_val();
     BasicBlock *Pred = BI->getParent();
-    if (CallInst *CI =
-            findTRECandidate(BI, CannotTailCallElimCallsMarkedTail)) {
+    if (CallInst *CI = findTRECandidate(BI)) {
       LLVM_DEBUG(dbgs() << "FOLDING: " << *BB
                         << "INTO UNCOND BRANCH PRED: " << *Pred);
       FoldReturnIntoUncondBranch(Ret, BB, Pred, &DTU);
@@ -720,9 +702,8 @@
   return Change;
 }
 
-bool TailRecursionEliminator::processReturningBlock(
-    ReturnInst *Ret, bool CannotTailCallElimCallsMarkedTail) {
-  CallInst *CI = findTRECandidate(Ret, CannotTailCallElimCallsMarkedTail);
+bool TailRecursionEliminator::processReturningBlock(ReturnInst *Ret) {
+  CallInst *CI = findTRECandidate(Ret);
   if (!CI)
     return false;
 
@@ -801,6 +782,88 @@
   }
 }
 
+/// Returns true if this block ends with a "return", an "unreachable",
+/// or an unconditional branch to another block whose first
+/// instruction is a "return".
+static bool isTailBlock(BasicBlock *BB) {
+  Instruction *LastBlockInstr = BB->getTerminator();
+
+  // Check that the last instruction is a "return", an "unreachable",
+  // or an unconditional branch to another block containing a "return".
+  if (ReturnInst *Ret = dyn_cast<ReturnInst>(LastBlockInstr))
+    return true;
+
+  if (isa<UnreachableInst>(LastBlockInstr))
+    return true;
+
+  if (BranchInst *Branch = dyn_cast<BranchInst>(LastBlockInstr)) {
+    if (Branch->isUnconditional())
+      if (ReturnInst *Ret = dyn_cast<ReturnInst>(
+              Branch->getSuccessor(0)->getFirstNonPHIOrDbg()))
+        return true;
+  }
+
+  return false;
+}
+
+/// Checks that the specified call instruction is part of a chain of
+/// recursive calls executed before the return.
+static bool areAllLastFuncCallsRecursive(CallInst *Inst, Function &F,
+                                         AliasAnalysis *AA) {
+  BasicBlock::iterator BBI(Inst->getParent()->getTerminator());
+  for (--BBI; &*BBI != Inst; --BBI) {
+    if (CallInst *CI = dyn_cast<CallInst>(&*BBI))
+      if (!canMoveAboveCall(CI, Inst, AA) && CI->getCalledFunction() != &F)
+        return false;
+  }
+
+  return true;
+}
+
+bool TailRecursionEliminator::canTRE(Function &F) {
+  // The local stack holds all alloca instructions and all byval arguments.
+  AllocaDerivedValueTracker Tracker;
+  for (Argument &Arg : F.args()) {
+    if (Arg.hasByValAttr())
+      Tracker.walk(&Arg);
+  }
+  for (auto &BB : F) {
+    for (auto &I : BB)
+      if (AllocaInst *AI = dyn_cast<AllocaInst>(&I))
+        Tracker.walk(AI);
+  }
+
+  // Do not do TRE if any pointer to the local stack has escaped.
+  if (!Tracker.EscapePoints.empty())
+    return false;
+
+  return !llvm::any_of(instructions(F), [&](Instruction &I) {
+    // Because of PR962, we don't TRE dynamic allocas.
+    if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
+      if (AI && !AI->isStaticAlloca())
+        return true;
+    } else if (CallInst *CI = dyn_cast<CallInst>(&I)) {
+      if (CI->getCalledFunction() == &F) {
+        // Do not do TRE if CI is explicitly marked as NoTailCall or has
+        // operand bundles.
+        if (CI->isNoTailCall() || CI->hasOperandBundles())
+          return true;
+
+        // Do not do TRE if there are recursive calls which are not the last calls.
+        if (!isTailBlock(CI->getParent()) ||
+            !areAllLastFuncCallsRecursive(CI, F, AA))
+          return true;
+
+        // Do not do TRE if the recursive call receives a pointer to the local stack.
+        if (Tracker.AllocaUsers.count(CI) > 0)
+          return true;
+      }
+    }
+
+    return false;
+  });
+}
+
 bool TailRecursionEliminator::eliminate(Function &F,
                                         const TargetTransformInfo *TTI,
                                         AliasAnalysis *AA,
@@ -810,23 +873,18 @@
     return false;
 
   bool MadeChange = false;
-  bool AllCallsAreTailCalls = false;
-  MadeChange |= markTails(F, AllCallsAreTailCalls, ORE);
-  if (!AllCallsAreTailCalls)
-    return MadeChange;
+  MadeChange |= markTails(F, ORE);
 
   // If this function is a varargs function, we won't be able to PHI the args
   // right, so don't even try to convert it...
   if (F.getFunctionType()->isVarArg())
     return MadeChange;
 
-  // If false, we cannot perform TRE on tail calls marked with the 'tail'
-  // attribute, because doing so would cause the stack size to increase (real
-  // TRE would deallocate variable sized allocas, TRE doesn't).
-  bool CanTRETailMarkedCall = canTRE(F);
-
   TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU);
 
+  if (!TRE.canTRE(F))
+    return MadeChange;
+
   // Change any tail recursive calls to loops.
   //
   // FIXME: The code generator produces really bad code when an 'escaping
@@ -836,9 +894,9 @@
   for (Function::iterator BBI = F.begin(), E = F.end(); BBI != E; /*in loop*/) {
     BasicBlock *BB = &*BBI++; // foldReturnAndProcessPred may delete BB.
     if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator())) {
-      bool Change = TRE.processReturningBlock(Ret, !CanTRETailMarkedCall);
+      bool Change = TRE.processReturningBlock(Ret);
       if (!Change && BB->getFirstNonPHIOrDbg() == Ret)
-        Change = TRE.foldReturnAndProcessPred(Ret, !CanTRETailMarkedCall);
+        Change = TRE.foldReturnAndProcessPred(Ret);
       MadeChange |= Change;
     }
   }
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to