https://github.com/usx95 updated 
https://github.com/llvm/llvm-project/pull/149158

>From 0311169154e4db2bb049168a1e73e3ae67d96848 Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <u...@google.com>
Date: Wed, 16 Jul 2025 18:22:39 +0000
Subject: [PATCH] [LifetimeSafety] Revamp test suite using unittests

---
 .../clang/Analysis/Analyses/LifetimeSafety.h  |  78 +++-
 clang/lib/Analysis/LifetimeSafety.cpp         | 201 ++++++---
 clang/lib/Sema/AnalysisBasedWarnings.cpp      |   4 +-
 clang/unittests/Analysis/CMakeLists.txt       |   1 +
 .../unittests/Analysis/LifetimeSafetyTest.cpp | 424 ++++++++++++++++++
 5 files changed, 648 insertions(+), 60 deletions(-)
 create mode 100644 clang/unittests/Analysis/LifetimeSafetyTest.cpp

diff --git a/clang/include/clang/Analysis/Analyses/LifetimeSafety.h b/clang/include/clang/Analysis/Analyses/LifetimeSafety.h
index 9998702a41cab..ff71147a20f6c 100644
--- a/clang/include/clang/Analysis/Analyses/LifetimeSafety.h
+++ b/clang/include/clang/Analysis/Analyses/LifetimeSafety.h
@@ -17,14 +17,89 @@
 //===----------------------------------------------------------------------===//
 #ifndef LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_H
 #define LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_H
-#include "clang/AST/DeclBase.h"
 #include "clang/Analysis/AnalysisDeclContext.h"
 #include "clang/Analysis/CFG.h"
-namespace clang {
+#include "llvm/ADT/ImmutableSet.h"
+#include "llvm/ADT/StringMap.h"
+#include <cstdint>
+#include <memory>
+#include <optional>
 
-void runLifetimeSafetyAnalysis(const DeclContext &DC, const CFG &Cfg,
-                               AnalysisDeclContext &AC);
+namespace clang::lifetimes {
+namespace internal {
+// Forward declarations of internal types.
+class Fact;
+class FactManager;
+class LoanPropagationAnalysis;
+struct LifetimeFactory;
 
-} // namespace clang
+/// A generic, type-safe wrapper for an ID, distinguished by its `Tag` type.
+/// Used for giving IDs to loans and origins.
+template <typename Tag> struct ID {
+  uint32_t Value = 0;
+
+  bool operator==(const ID<Tag> &Other) const { return Value == Other.Value; }
+  bool operator!=(const ID<Tag> &Other) const { return !(*this == Other); }
+  bool operator<(const ID<Tag> &Other) const { return Value < Other.Value; }
+  ID<Tag> operator++(int) {
+    ID<Tag> Tmp = *this;
+    ++Value;
+    return Tmp;
+  }
+  void Profile(llvm::FoldingSetNodeID &IDBuilder) const {
+    IDBuilder.AddInteger(Value);
+  }
+};
+
+using LoanID = ID<struct LoanTag>;
+using OriginID = ID<struct OriginTag>;
+
+// Using LLVM's immutable collections is efficient for dataflow analysis
+// as it avoids deep copies during state transitions.
+// TODO(opt): Consider using a bitset to represent the set of loans.
+using LoanSet = llvm::ImmutableSet<LoanID>;
+using OriginSet = llvm::ImmutableSet<OriginID>;
+
+/// Identifies a program point as a specific fact within a CFG block.
+using ProgramPoint = std::pair<const CFGBlock *, const Fact *>;
+
+/// Runs the lifetime safety analysis and answers queries about its results.
+/// It encapsulates the various dataflow analyses.
+class LifetimeSafetyAnalysis {
+public:
+  explicit LifetimeSafetyAnalysis(AnalysisDeclContext &AC);
+  ~LifetimeSafetyAnalysis();
+
+  /// Generates facts from the CFG and runs the dataflow analyses. Must be
+  /// called before any of the query methods below.
+  void run();
+
+  /// Returns the set of loans an origin holds at a specific program point.
+  LoanSet getLoansAtPoint(OriginID OID, ProgramPoint PP) const;
+
+  /// Finds the OriginID for a given declaration.
+  /// Returns a null optional if not found.
+  std::optional<OriginID> getOriginIDForDecl(const ValueDecl *D) const;
+
+  /// Finds the LoanID for a loan created on a specific variable.
+  /// Returns a null optional if not found.
+  std::optional<LoanID> getLoanIDForVar(const VarDecl *VD) const;
+
+  /// Returns a map from each test-point annotation (the `<name>` in a
+  /// `void("__lifetime_test_point_<name>")` cast) to its program point.
+  llvm::StringMap<ProgramPoint> getTestPoints() const;
+
+private:
+  AnalysisDeclContext &AC;
+  std::unique_ptr<LifetimeFactory> Factory;
+  std::unique_ptr<FactManager> FactMgr;
+  std::unique_ptr<LoanPropagationAnalysis> LoanPropagation;
+};
+} // namespace internal
+
+/// The main entry point for the analysis.
+void runLifetimeSafetyAnalysis(AnalysisDeclContext &AC);
+
+} // namespace clang::lifetimes
 
 #endif // LLVM_CLANG_ANALYSIS_ANALYSES_LIFETIMESAFETY_H
diff --git a/clang/lib/Analysis/LifetimeSafety.cpp b/clang/lib/Analysis/LifetimeSafety.cpp
index e3a03cf93880e..50b47e1431723 100644
--- a/clang/lib/Analysis/LifetimeSafety.cpp
+++ b/clang/lib/Analysis/LifetimeSafety.cpp
@@ -24,8 +24,14 @@
 #include "llvm/Support/TimeProfiler.h"
 #include <cstdint>
 
-namespace clang {
+namespace clang::lifetimes {
+namespace internal {
 namespace {
+template <typename Tag>
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, ID<Tag> ID) {
+  return OS << ID.Value;
+}
+} // anonymous namespace
 
 /// Represents the storage location being borrowed, e.g., a specific stack
 /// variable.
@@ -36,32 +42,6 @@ struct AccessPath {
   AccessPath(const clang::ValueDecl *D) : D(D) {}
 };
 
-/// A generic, type-safe wrapper for an ID, distinguished by its `Tag` type.
-/// Used for giving ID to loans and origins.
-template <typename Tag> struct ID {
-  uint32_t Value = 0;
-
-  bool operator==(const ID<Tag> &Other) const { return Value == Other.Value; }
-  bool operator!=(const ID<Tag> &Other) const { return !(*this == Other); }
-  bool operator<(const ID<Tag> &Other) const { return Value < Other.Value; }
-  ID<Tag> operator++(int) {
-    ID<Tag> Tmp = *this;
-    ++Value;
-    return Tmp;
-  }
-  void Profile(llvm::FoldingSetNodeID &IDBuilder) const {
-    IDBuilder.AddInteger(Value);
-  }
-};
-
-template <typename Tag>
-inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, ID<Tag> ID) {
-  return OS << ID.Value;
-}
-
-using LoanID = ID<struct LoanTag>;
-using OriginID = ID<struct OriginTag>;
-
 /// Information about a single borrow, or "Loan". A loan is created when a
 /// reference or pointer is created.
 struct Loan {
@@ -223,7 +203,9 @@ class Fact {
     /// An origin is propagated from a source to a destination (e.g., p = q).
     AssignOrigin,
     /// An origin escapes the function by flowing into the return value.
-    ReturnOfOrigin
+    ReturnOfOrigin,
+    /// A marker for a specific point in the code, used for testing.
+    TestPoint,
   };
 
 private:
@@ -310,6 +292,24 @@ class ReturnOfOriginFact : public Fact {
   }
 };
 
+/// A marker fact that names a specific point in the code, for testing.
+/// It is generated by recognizing a `void("__lifetime_test_point_...")` cast.
+class TestPointFact : public Fact {
+  std::string Annotation; // Test-point name, with the prefix stripped.
+
+public:
+  static bool classof(const Fact *F) { return F->getKind() == Kind::TestPoint; }
+
+  explicit TestPointFact(std::string Annotation)
+      : Fact(Kind::TestPoint), Annotation(std::move(Annotation)) {}
+
+  const std::string &getAnnotation() const { return Annotation; }
+
+  void dump(llvm::raw_ostream &OS) const override {
+    OS << "TestPoint (Annotation: \"" << getAnnotation() << "\")\n";
+  }
+};
+
 class FactManager {
 public:
   llvm::ArrayRef<const Fact *> getFacts(const CFGBlock *B) const {
@@ -363,6 +363,7 @@ class FactManager {
 };
 
 class FactGenerator : public ConstStmtVisitor<FactGenerator> {
+  using Base = ConstStmtVisitor<FactGenerator>; // For default visitation.
 
 public:
   FactGenerator(FactManager &FactMgr, AnalysisDeclContext &AC)
@@ -458,6 +459,15 @@ class FactGenerator : public ConstStmtVisitor<FactGenerator> {
     }
   }
 
+  void VisitCXXFunctionalCastExpr(const CXXFunctionalCastExpr *FCE) {
+    // A `void("__lifetime_test_point_...")` marker only records a
+    // TestPointFact; it needs no further visiting.
+    if (VisitTestPoint(FCE))
+      return;
+    // Otherwise, defer to the default visitation for this cast.
+    Base::VisitCXXFunctionalCastExpr(FCE);
+  }
+
 private:
   // Check if a type has an origin.
   bool hasOrigin(QualType QT) { return QT->isPointerOrReferenceType(); }
@@ -491,6 +501,27 @@ class FactGenerator : public ConstStmtVisitor<FactGenerator> {
     }
   }
 
+  /// Checks if the expression is a `void("__lifetime_test_point_...")` cast.
+  /// If so, creates a `TestPointFact` and returns true.
+  bool VisitTestPoint(const CXXFunctionalCastExpr *FCE) {
+    if (!FCE->getType()->isVoidType())
+      return false;
+    const auto *SL = dyn_cast<StringLiteral>(
+        FCE->getSubExpr()->IgnoreParenImpCasts());
+    // `getString()` asserts on wide/UTF literals, so accept only 1-byte chars.
+    if (!SL || SL->getCharByteWidth() != 1)
+      return false;
+
+    llvm::StringRef Annotation = SL->getString();
+    // `consume_front` strips the prefix only when it is present.
+    if (!Annotation.consume_front("__lifetime_test_point_"))
+      return false;
+
+    CurrentBlockFacts.push_back(
+        FactMgr.createFact<TestPointFact>(Annotation.str()));
+    return true;
+  }
+
   FactManager &FactMgr;
   AnalysisDeclContext &AC;
   llvm::SmallVector<Fact *> CurrentBlockFacts;
@@ -500,6 +531,9 @@ class FactGenerator : public ConstStmtVisitor<FactGenerator> {
 //                         Generic Dataflow Analysis
 // ========================================================================= //
 
+// Per-fact states are keyed by ProgramPoint (declared in LifetimeSafety.h);
+// each is the state right after that fact is applied, in analysis order.
+
 enum class Direction { Forward, Backward };
 
 /// A generic, policy-based driver for dataflow analyses. It combines
@@ -532,6 +566,7 @@ class DataflowAnalysis {
 
   llvm::DenseMap<const CFGBlock *, Lattice> InStates;
   llvm::DenseMap<const CFGBlock *, Lattice> OutStates;
+  llvm::DenseMap<ProgramPoint, Lattice> PerPointStates;
 
   static constexpr bool isForward() { return Dir == Direction::Forward; }
 
@@ -577,6 +612,8 @@ class DataflowAnalysis {
     }
   }
 
+  Lattice getState(ProgramPoint P) const { return PerPointStates.lookup(P); }
+
   Lattice getInState(const CFGBlock *B) const { return InStates.lookup(B); }
 
   Lattice getOutState(const CFGBlock *B) const { return OutStates.lookup(B); }
@@ -590,18 +627,22 @@ class DataflowAnalysis {
     getOutState(&B).dump(llvm::dbgs());
   }
 
+private:
   /// Computes the state at one end of a block by applying all its facts
   /// sequentially to a given state from the other end.
-  /// TODO: We might need to store intermediate states per-fact in the block for
-  /// later analysis.
   Lattice transferBlock(const CFGBlock *Block, Lattice State) {
     auto Facts = AllFacts.getFacts(Block);
-    if constexpr (isForward())
-      for (const Fact *F : Facts)
+    if constexpr (isForward()) {
+      for (const Fact *F : Facts) {
         State = transferFact(State, F);
-    else
-      for (const Fact *F : llvm::reverse(Facts))
+        PerPointStates[{Block, F}] = State;
+      }
+    } else {
+      for (const Fact *F : llvm::reverse(Facts)) {
         State = transferFact(State, F);
+        PerPointStates[{Block, F}] = State;
+      }
+    }
     return State;
   }
 
@@ -617,6 +658,8 @@ class DataflowAnalysis {
       return D->transfer(In, *F->getAs<AssignOriginFact>());
     case Fact::Kind::ReturnOfOrigin:
       return D->transfer(In, *F->getAs<ReturnOfOriginFact>());
+    case Fact::Kind::TestPoint:
+      return D->transfer(In, *F->getAs<TestPointFact>());
     }
     llvm_unreachable("Unknown fact kind");
   }
@@ -626,14 +669,16 @@ class DataflowAnalysis {
   Lattice transfer(Lattice In, const ExpireFact &) { return In; }
   Lattice transfer(Lattice In, const AssignOriginFact &) { return In; }
   Lattice transfer(Lattice In, const ReturnOfOriginFact &) { return In; }
+  Lattice transfer(Lattice In, const TestPointFact &) { return In; }
 };
 
 namespace utils {
 
 /// Computes the union of two ImmutableSets.
 template <typename T>
-llvm::ImmutableSet<T> join(llvm::ImmutableSet<T> A, llvm::ImmutableSet<T> B,
-                           typename llvm::ImmutableSet<T>::Factory &F) {
+static llvm::ImmutableSet<T> join(llvm::ImmutableSet<T> A,
+                                  llvm::ImmutableSet<T> B,
+                                  typename llvm::ImmutableSet<T>::Factory &F) {
   if (A.getHeight() < B.getHeight())
     std::swap(A, B);
   for (const T &E : B)
@@ -646,7 +691,7 @@ llvm::ImmutableSet<T> join(llvm::ImmutableSet<T> A, llvm::ImmutableSet<T> B,
 // efficient merge could be implemented using a Patricia Trie or HAMT
 // instead of the current AVL-tree-based ImmutableMap.
 template <typename K, typename V, typename Joiner>
-llvm::ImmutableMap<K, V>
+static llvm::ImmutableMap<K, V>
 join(llvm::ImmutableMap<K, V> A, llvm::ImmutableMap<K, V> B,
      typename llvm::ImmutableMap<K, V>::Factory &F, Joiner joinValues) {
   if (A.getHeight() < B.getHeight())
@@ -670,10 +715,6 @@ join(llvm::ImmutableMap<K, V> A, llvm::ImmutableMap<K, V> B,
 //                          Loan Propagation Analysis
 // ========================================================================= //
 
-// Using LLVM's immutable collections is efficient for dataflow analysis
-// as it avoids deep copies during state transitions.
-// TODO(opt): Consider using a bitset to represent the set of loans.
-using LoanSet = llvm::ImmutableSet<LoanID>;
 using OriginLoanMap = llvm::ImmutableMap<OriginID, LoanSet>;
 
 /// An object to hold the factories for immutable collections, ensuring
@@ -769,6 +810,10 @@ class LoanPropagationAnalysis
         Factory.OriginMapFactory.add(In.Origins, DestOID, SrcLoans));
   }
 
+  /// Returns the loans held by `OID` right after the fact at `P`.
+  LoanSet getLoans(OriginID OID, ProgramPoint P) {
+    return getLoans(getState(P), OID);
+  }
 private:
   LoanSet getLoans(Lattice L, OriginID OID) {
     if (auto *Loans = L.Origins.lookup(OID))
@@ -779,22 +824,31 @@ class LoanPropagationAnalysis
 
 // ========================================================================= //
 //  TODO:
-// - Modifying loan propagation to answer `LoanSet getLoans(Origin O, Point P)`
 // - Modify loan expiry analysis to answer `bool isExpired(Loan L, Point P)`
 // - Modify origin liveness analysis to answer `bool isLive(Origin O, Point P)`
 // - Using the above three to perform the final error reporting.
 // ========================================================================= //
-} // anonymous namespace
 
-void runLifetimeSafetyAnalysis(const DeclContext &DC, const CFG &Cfg,
-                               AnalysisDeclContext &AC) {
+// ========================================================================= //
+//                  LifetimeSafetyAnalysis Class Implementation
+// ========================================================================= //
+
+LifetimeSafetyAnalysis::~LifetimeSafetyAnalysis() = default;
+
+LifetimeSafetyAnalysis::LifetimeSafetyAnalysis(AnalysisDeclContext &AC)
+    : AC(AC), Factory(std::make_unique<LifetimeFactory>()),
+      FactMgr(std::make_unique<FactManager>()) {}
+
+void LifetimeSafetyAnalysis::run() {
   llvm::TimeTraceScope TimeProfile("LifetimeSafetyAnalysis");
+
+  const CFG &Cfg = *AC.getCFG();
   DEBUG_WITH_TYPE("PrintCFG", Cfg.dump(AC.getASTContext().getLangOpts(),
                                        /*ShowColors=*/true));
-  FactManager FactMgr;
-  FactGenerator FactGen(FactMgr, AC);
+
+  FactGenerator FactGen(*FactMgr, AC);
   FactGen.run();
-  DEBUG_WITH_TYPE("LifetimeFacts", FactMgr.dump(Cfg, AC));
+  DEBUG_WITH_TYPE("LifetimeFacts", FactMgr->dump(Cfg, AC));
 
   /// TODO(opt): Consider optimizing individual blocks before running the
   /// dataflow analysis.
@@ -805,9 +859,50 @@ void runLifetimeSafetyAnalysis(const DeclContext &DC, const CFG &Cfg,
   ///    blocks; only Decls are visible.  Therefore, loans in a block that
   ///    never reach an Origin associated with a Decl can be safely dropped by
   ///    the analysis.
-  LifetimeFactory Factory;
-  LoanPropagationAnalysis LoanPropagation(Cfg, AC, FactMgr, Factory);
-  LoanPropagation.run();
-  DEBUG_WITH_TYPE("LifetimeLoanPropagation", LoanPropagation.dump());
+  LoanPropagation =
+      std::make_unique<LoanPropagationAnalysis>(Cfg, AC, *FactMgr, *Factory);
+  LoanPropagation->run();
+  DEBUG_WITH_TYPE("LifetimeLoanPropagation", LoanPropagation->dump());
+}
+
+LoanSet LifetimeSafetyAnalysis::getLoansAtPoint(OriginID OID,
+                                                ProgramPoint PP) const {
+  assert(LoanPropagation && "Analysis has not been run.");
+  return LoanPropagation->getLoans(OID, PP);
+}
+
+std::optional<OriginID>
+LifetimeSafetyAnalysis::getOriginIDForDecl(const ValueDecl *D) const {
+  assert(FactMgr && "FactManager not initialized");
+  // NOTE: assumes OriginManager's `get` only looks up an existing origin. If it
+  // can create one, a dedicated `find` on OriginManager would keep this query
+  // side-effect free.
+  return FactMgr->getOriginMgr().get(*D);
+}
+
+std::optional<LoanID>
+LifetimeSafetyAnalysis::getLoanIDForVar(const VarDecl *VD) const {
+  assert(FactMgr && "FactManager not initialized");
+  for (const Loan &L : FactMgr->getLoanMgr().getLoans()) {
+    if (L.Path.D == VD)
+      return L.ID;
+  }
+  return std::nullopt;
+}
+
+llvm::StringMap<ProgramPoint> LifetimeSafetyAnalysis::getTestPoints() const {
+  assert(FactMgr && "FactManager not initialized");
+  llvm::StringMap<ProgramPoint> AnnotationToPointMap;
+  for (const CFGBlock *Block : *AC.getCFG())
+    for (const Fact *F : FactMgr->getFacts(Block))
+      if (const auto *TPF = F->getAs<TestPointFact>())
+        AnnotationToPointMap[TPF->getAnnotation()] = {Block, F};
+  return AnnotationToPointMap;
+}
+} // namespace internal
+
+void runLifetimeSafetyAnalysis(AnalysisDeclContext &AC) {
+  internal::LifetimeSafetyAnalysis Analysis(AC);
+  Analysis.run();
 }
-} // namespace clang
+} // namespace clang::lifetimes
diff --git a/clang/lib/Sema/AnalysisBasedWarnings.cpp b/clang/lib/Sema/AnalysisBasedWarnings.cpp
index 5eba024e83634..89c5a3596f584 100644
--- a/clang/lib/Sema/AnalysisBasedWarnings.cpp
+++ b/clang/lib/Sema/AnalysisBasedWarnings.cpp
@@ -3030,8 +3030,8 @@ void clang::sema::AnalysisBasedWarnings::IssueWarnings(
   // TODO: Enable lifetime safety analysis for other languages once it is
   // stable.
   if (EnableLifetimeSafetyAnalysis && S.getLangOpts().CPlusPlus) {
-    if (CFG *cfg = AC.getCFG())
-      runLifetimeSafetyAnalysis(*cast<DeclContext>(D), *cfg, AC);
+    if (AC.getCFG()) // The analysis requires a CFG.
+      lifetimes::runLifetimeSafetyAnalysis(AC);
   }
   // Check for violations of "called once" parameter properties.
   if (S.getLangOpts().ObjC && !S.getLangOpts().CPlusPlus &&
diff --git a/clang/unittests/Analysis/CMakeLists.txt b/clang/unittests/Analysis/CMakeLists.txt
index 059a74843155c..52e7d2854633d 100644
--- a/clang/unittests/Analysis/CMakeLists.txt
+++ b/clang/unittests/Analysis/CMakeLists.txt
@@ -4,6 +4,7 @@ add_clang_unittest(ClangAnalysisTests
   CloneDetectionTest.cpp
   ExprMutationAnalyzerTest.cpp
   IntervalPartitionTest.cpp
+  LifetimeSafetyTest.cpp
   MacroExpansionContextTest.cpp
   UnsafeBufferUsageTest.cpp
   CLANG_LIBS
diff --git a/clang/unittests/Analysis/LifetimeSafetyTest.cpp b/clang/unittests/Analysis/LifetimeSafetyTest.cpp
new file mode 100644
index 0000000000000..d9da5ce92550c
--- /dev/null
+++ b/clang/unittests/Analysis/LifetimeSafetyTest.cpp
@@ -0,0 +1,424 @@
+//===- LifetimeSafetyTest.cpp - Lifetime Safety Tests -*---------- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Analysis/Analyses/LifetimeSafety.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/ASTMatchers/ASTMatchers.h"
+#include "clang/Testing/TestAST.h"
+#include "llvm/ADT/StringMap.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include <optional>
+#include <vector>
+
+namespace clang::lifetimes::internal {
+namespace {
+
+using namespace ast_matchers;
+using ::testing::UnorderedElementsAreArray;
+
+// A helper class to run the full lifetime analysis on a piece of code
+// and provide an interface for querying the results.
+class LifetimeTestRunner {
+public:
+  explicit LifetimeTestRunner(llvm::StringRef Code) {
+    std::string FullCode = R"(
+      #define POINT(name) void("__lifetime_test_point_" #name)
+      struct MyObj { ~MyObj() {} int i; };
+    )";
+    FullCode += Code.str();
+
+    TestAST = std::make_unique<clang::TestAST>(FullCode);
+    ASTCtx = &TestAST->context();
+
+    // Find the target function using AST matchers.
+    auto MatchResult =
+        match(functionDecl(hasName("target")).bind("target"), *ASTCtx);
+    auto *FD = selectFirst<FunctionDecl>("target", MatchResult);
+    if (!FD) {
+      ADD_FAILURE() << "Test case must have a function named 'target'";
+      return;
+    }
+    AnalysisCtx = std::make_unique<AnalysisDeclContext>(nullptr, FD);
+    AnalysisCtx->getCFGBuildOptions().setAllAlwaysAdd();
+
+    // Run the main analysis.
+    Analysis = std::make_unique<LifetimeSafetyAnalysis>(*AnalysisCtx);
+    Analysis->run();
+
+    AnnotationToPointMap = Analysis->getTestPoints();
+  }
+
+  LifetimeSafetyAnalysis &getAnalysis() { return *Analysis; }
+  ASTContext &getASTContext() { return *ASTCtx; }
+
+  ProgramPoint getProgramPoint(llvm::StringRef Annotation) {
+    auto It = AnnotationToPointMap.find(Annotation);
+    if (It == AnnotationToPointMap.end()) {
+      ADD_FAILURE() << "Annotation '" << Annotation << "' not found.";
+      return {nullptr, nullptr};
+    }
+    return It->second;
+  }
+
+private:
+  std::unique_ptr<clang::TestAST> TestAST; // Qualified; member shadows type.
+  ASTContext *ASTCtx = nullptr;
+  std::unique_ptr<AnalysisDeclContext> AnalysisCtx;
+  std::unique_ptr<LifetimeSafetyAnalysis> Analysis;
+  llvm::StringMap<ProgramPoint> AnnotationToPointMap;
+};
+
+// A convenience wrapper that uses the LifetimeSafetyAnalysis public API.
+class LifetimeTestHelper {
+public:
+  explicit LifetimeTestHelper(LifetimeTestRunner &Runner)
+      : Runner(Runner), Analysis(Runner.getAnalysis()) {}
+
+  std::optional<OriginID> getOriginForDecl(llvm::StringRef VarName) {
+    auto *VD = findDecl<ValueDecl>(VarName);
+    if (!VD)
+      return std::nullopt;
+    auto OID = Analysis.getOriginIDForDecl(VD);
+    if (!OID)
+      ADD_FAILURE() << "Origin for '" << VarName << "' not found.";
+    return OID;
+  }
+
+  std::optional<LoanID> getLoanForVar(llvm::StringRef VarName) {
+    auto *VD = findDecl<VarDecl>(VarName);
+    if (!VD)
+      return std::nullopt;
+    auto LID = Analysis.getLoanIDForVar(VD);
+    if (!LID)
+      ADD_FAILURE() << "Loan for '" << VarName << "' not found.";
+    return LID;
+  }
+
+  std::optional<LoanSet> getLoansAtPoint(OriginID OID,
+                                         llvm::StringRef Annotation) {
+    ProgramPoint PP = Runner.getProgramPoint(Annotation);
+    if (!PP.first)
+      return std::nullopt;
+    return Analysis.getLoansAtPoint(OID, PP);
+  }
+
+private:
+  // Returns the first `DeclT` named `Name`; reports a failure if none exists.
+  template <typename DeclT> DeclT *findDecl(llvm::StringRef Name) {
+    auto &Ctx = Runner.getASTContext();
+    auto Results = match(valueDecl(hasName(Name)).bind("v"), Ctx);
+    const auto *D = selectFirst<DeclT>("v", Results);
+    if (!D)
+      ADD_FAILURE() << "Declaration '" << Name << "' not found in AST.";
+    return const_cast<DeclT *>(D);
+  }
+
+  LifetimeTestRunner &Runner;
+  LifetimeSafetyAnalysis &Analysis;
+};
+
+// ========================================================================= //
+//                         GTest Matchers & Fixture
+// ========================================================================= //
+
+// Matcher argument: the origin variable's name plus the helper to query it.
+class OriginInfo {
+public:
+  OriginInfo(llvm::StringRef OriginVar, LifetimeTestHelper &Helper)
+      : OriginVar(OriginVar), Helper(Helper) {}
+  llvm::StringRef OriginVar; // Non-owning view of the variable name.
+  LifetimeTestHelper &Helper;
+};
+
+// Checks that the origin holds exactly the loans on `LoanVars` at `Annotation`.
+MATCHER_P2(HasLoansToImpl, LoanVars, Annotation, "") {
+  const OriginInfo &Info = arg;
+  std::optional<OriginID> OIDOpt = Info.Helper.getOriginForDecl(Info.OriginVar);
+  if (!OIDOpt) {
+    *result_listener << "could not find origin for '" << Info.OriginVar.str()
+                     << "'";
+    return false;
+  }
+
+  std::optional<LoanSet> ActualLoansSetOpt =
+      Info.Helper.getLoansAtPoint(*OIDOpt, Annotation);
+  if (!ActualLoansSetOpt) {
+    *result_listener << "could not get a valid loan set at point '"
+                     << Annotation << "'";
+    return false;
+  }
+  std::vector<LoanID> ActualLoans(ActualLoansSetOpt->begin(),
+                                  ActualLoansSetOpt->end());
+
+  std::vector<LoanID> ExpectedLoans;
+  for (const auto &LoanVar : LoanVars) {
+    std::optional<LoanID> ExpectedLIDOpt = Info.Helper.getLoanForVar(LoanVar);
+    if (!ExpectedLIDOpt) {
+      *result_listener << "could not find loan for var '" << LoanVar << "'";
+      return false;
+    }
+    ExpectedLoans.push_back(*ExpectedLIDOpt);
+  }
+
+  return ExplainMatchResult(UnorderedElementsAreArray(ExpectedLoans),
+                            ActualLoans, result_listener);
+}
+
+// Base test fixture: parses the test code, runs the analysis, and queries it.
+class LifetimeAnalysisTest : public ::testing::Test {
+protected:
+  void SetupTest(llvm::StringRef Code) {
+    Runner = std::make_unique<LifetimeTestRunner>(Code);
+    Helper = std::make_unique<LifetimeTestHelper>(*Runner);
+  }
+
+  OriginInfo Origin(llvm::StringRef OriginVar) {
+    return OriginInfo(OriginVar, *Helper);
+  }
+
+  // Builds a HasLoansToImpl matcher from a braced list such as {"a", "b"}.
+  auto HasLoansTo(std::initializer_list<std::string> LoanVars,
+                  const char *Annotation) {
+    return HasLoansToImpl(std::vector<std::string>(LoanVars), Annotation);
+  }
+
+  std::unique_ptr<LifetimeTestRunner> Runner;
+  std::unique_ptr<LifetimeTestHelper> Helper;
+};
+
+// ========================================================================= //
+//                                 TEST CASES
+// ========================================================================= //
+
+TEST_F(LifetimeAnalysisTest, SimpleLoanAndOrigin) {
+  SetupTest(R"(
+    void target() {
+      int x;
+      int* p = &x;
+      POINT(p1);
+    }
+  )");
+  EXPECT_THAT(Origin("p"), HasLoansTo({"x"}, "p1")); // Loan from `&x`.
+}
+
+TEST_F(LifetimeAnalysisTest, OverwriteOrigin) {
+  SetupTest(R"(
+    void target() {
+      MyObj s1, s2;
+
+      MyObj* p = &s1;
+      POINT(after_s1);
+
+      p = &s2;
+      POINT(after_s2);
+    }
+  )");
+  EXPECT_THAT(Origin("p"), HasLoansTo({"s1"}, "after_s1"));
+  EXPECT_THAT(Origin("p"), HasLoansTo({"s2"}, "after_s2")); // s1 is gone.
+}
+
+TEST_F(LifetimeAnalysisTest, ConditionalLoan) {
+  SetupTest(R"(
+    void target(bool cond) {
+      int a, b;
+      int *p = nullptr;
+      if (cond) {
+        p = &a;
+        POINT(after_then);
+      } else {
+        p = &b;
+        POINT(after_else);
+      }
+      POINT(after_if);
+    }
+  )");
+  EXPECT_THAT(Origin("p"), HasLoansTo({"a"}, "after_then"));
+  EXPECT_THAT(Origin("p"), HasLoansTo({"b"}, "after_else"));
+  EXPECT_THAT(Origin("p"), HasLoansTo({"a", "b"}, "after_if")); // Both arms.
+}
+
+TEST_F(LifetimeAnalysisTest, PointerChain) {
+  SetupTest(R"(
+    void target() {
+      MyObj y;
+      MyObj* ptr1 = &y;
+      POINT(p1);
+
+      MyObj* ptr2 = ptr1;
+      POINT(p2);
+
+      ptr2 = ptr1;
+      POINT(p3);
+
+      ptr2 = ptr2; // Self assignment
+      POINT(p4);
+    }
+  )");
+  EXPECT_THAT(Origin("ptr1"), HasLoansTo({"y"}, "p1"));
+  EXPECT_THAT(Origin("ptr2"), HasLoansTo({"y"}, "p2"));
+  EXPECT_THAT(Origin("ptr2"), HasLoansTo({"y"}, "p3"));
+  EXPECT_THAT(Origin("ptr2"), HasLoansTo({"y"}, "p4")); // Unchanged.
+}
+
+TEST_F(LifetimeAnalysisTest, ReassignToNull) {
+  SetupTest(R"(
+    void target() {
+      MyObj s1;
+      MyObj* p = &s1;
+      POINT(before_null);
+      p = nullptr;
+      POINT(after_null);
+    }
+  )");
+  EXPECT_THAT(Origin("p"), HasLoansTo({"s1"}, "before_null"));
+  // Assigning `nullptr` leaves the origin of `p` holding no loans.
+  EXPECT_THAT(Origin("p"), HasLoansTo({}, "after_null"));
+}
+
+TEST_F(LifetimeAnalysisTest, ReassignInIf) {
+  SetupTest(R"(
+    void target(bool condition) {
+      MyObj s1, s2;
+      MyObj* p = &s1;
+      POINT(before_if);
+      if (condition) {
+        p = &s2;
+        POINT(after_reassign);
+      }
+      POINT(after_if);
+    }
+  )");
+  EXPECT_THAT(Origin("p"), HasLoansTo({"s1"}, "before_if"));
+  EXPECT_THAT(Origin("p"), HasLoansTo({"s2"}, "after_reassign"));
+  EXPECT_THAT(Origin("p"), HasLoansTo({"s1", "s2"}, "after_if")); // Join.
+}
+
+TEST_F(LifetimeAnalysisTest, AssignInSwitch) {
+  SetupTest(R"(
+    void target(int mode) {
+      MyObj s1, s2, s3;
+      MyObj* p = nullptr;
+      switch (mode) {
+        case 1:
+          p = &s1;
+          POINT(case1);
+          break;
+        case 2:
+          p = &s2;
+          POINT(case2);
+          break;
+        default:
+          p = &s3;
+          POINT(case3);
+          break;
+      }
+      POINT(after_switch);
+    }
+  )");
+  EXPECT_THAT(Origin("p"), HasLoansTo({"s1"}, "case1")); // One per case.
+  EXPECT_THAT(Origin("p"), HasLoansTo({"s2"}, "case2"));
+  EXPECT_THAT(Origin("p"), HasLoansTo({"s3"}, "case3"));
+  EXPECT_THAT(Origin("p"), HasLoansTo({"s1", "s2", "s3"}, "after_switch"));
+}
+
+TEST_F(LifetimeAnalysisTest, LoanInLoop) {
+  SetupTest(R"(
+    void target(bool condition) {
+      MyObj* p = nullptr;
+      while (condition) {
+        MyObj inner;
+        p = &inner;
+        POINT(in_loop);
+      }
+      POINT(after_loop);
+    }
+  )");
+  EXPECT_THAT(Origin("p"), HasLoansTo({"inner"}, "in_loop"));
+  EXPECT_THAT(Origin("p"), HasLoansTo({"inner"}, "after_loop")); // If looped.
+}
+
+TEST_F(LifetimeAnalysisTest, LoopWithBreak) {
+  SetupTest(R"(
+    void target(int count) {
+      MyObj s1;
+      MyObj s2;
+      MyObj* p = &s1;
+      POINT(before_loop);
+      for (int i = 0; i < count; ++i) {
+        if (i == 5) {
+          p = &s2;
+          POINT(inside_if);
+          break;
+        }
+        POINT(after_if);
+      }
+      POINT(after_loop);
+    }
+  )");
+  EXPECT_THAT(Origin("p"), HasLoansTo({"s1"}, "before_loop"));
+  EXPECT_THAT(Origin("p"), HasLoansTo({"s2"}, "inside_if"));
+  // `after_if` is reached only when the `if` is not taken, so p lacks s2.
+  EXPECT_THAT(Origin("p"), HasLoansTo({"s1"}, "after_if"));
+  // At the join point after the loop, p could hold a loan to s1 (if the loop
+  // completed normally) or to s2 (if the loop was broken).
+  EXPECT_THAT(Origin("p"), HasLoansTo({"s1", "s2"}, "after_loop"));
+}
+
+TEST_F(LifetimeAnalysisTest, PointersInACycle) {
+  SetupTest(R"(
+    void target(bool condition) {
+      MyObj v1, v2, v3;
+      MyObj *p1 = &v1, *p2 = &v2, *p3 = &v3;
+
+      POINT(before_while);
+      while (condition) {
+        MyObj* temp = p1;
+        p1 = p2;
+        p2 = p3;
+        p3 = temp;
+      }
+      POINT(after_loop);
+    }
+  )");
+  EXPECT_THAT(Origin("p1"), HasLoansTo({"v1"}, "before_while"));
+  EXPECT_THAT(Origin("p2"), HasLoansTo({"v2"}, "before_while"));
+  EXPECT_THAT(Origin("p3"), HasLoansTo({"v3"}, "before_while"));
+
+  // At the fixed point after the loop, each pointer (and `temp`) may hold a
+  // loan to any of the three variables.
+  EXPECT_THAT(Origin("p1"), HasLoansTo({"v1", "v2", "v3"}, "after_loop"));
+  EXPECT_THAT(Origin("p2"), HasLoansTo({"v1", "v2", "v3"}, "after_loop"));
+  EXPECT_THAT(Origin("p3"), HasLoansTo({"v1", "v2", "v3"}, "after_loop"));
+  EXPECT_THAT(Origin("temp"), HasLoansTo({"v1", "v2", "v3"}, "after_loop"));
+}
+
+TEST_F(LifetimeAnalysisTest, NestedScopes) {
+  SetupTest(R"(
+    void target() {
+      MyObj* p = nullptr;
+      {
+        MyObj outer;
+        p = &outer;
+        POINT(before_inner_scope);
+        {
+          MyObj inner;
+          p = &inner;
+          POINT(inside_inner_scope);
+        } // inner expires
+        POINT(after_inner_scope);
+      } // outer expires
+    }
+  )");
+  EXPECT_THAT(Origin("p"), HasLoansTo({"outer"}, "before_inner_scope"));
+  EXPECT_THAT(Origin("p"), HasLoansTo({"inner"}, "inside_inner_scope"));
+  EXPECT_THAT(Origin("p"), HasLoansTo({"inner"}, "after_inner_scope"));
+}
+
+} // anonymous namespace
+} // namespace clang::lifetimes::internal

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to