https://github.com/malek203 updated https://github.com/llvm/llvm-project/pull/131831

>From 7157d4c42e534c5c8c06b564055aed1030b9f690 Mon Sep 17 00:00:00 2001
From: Malek Ben Slimane <malek.ben.slim...@sap.com>
Date: Wed, 25 Sep 2024 15:21:08 +0200
Subject: [PATCH] Thread Safety Analysis: Check managed capabilities of
 returned scoped capability

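Verify that a returned scoped capability (an object of SCOPED_LOCKABLE
type) manages the mutexes named by the function's capability attributes.
Mismatches are reported with the existing underlying-mutex warnings; the
accompanying note now says whether the expectation comes from a function
or a parameter attribute.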
---
 .../clang/Analysis/Analyses/ThreadSafety.h    | 16 +++-
 .../clang/Basic/DiagnosticSemaKinds.td        |  4 +-
 clang/lib/Analysis/ThreadSafety.cpp           | 90 ++++++++++++++++++-
 clang/lib/Sema/AnalysisBasedWarnings.cpp      | 23 ++---
 .../SemaCXX/warn-thread-safety-analysis.cpp   | 75 +++++++++++-----
 5 files changed, 165 insertions(+), 43 deletions(-)

diff --git a/clang/include/clang/Analysis/Analyses/ThreadSafety.h b/clang/include/clang/Analysis/Analyses/ThreadSafety.h
index 20b75c46593e0..210610f672933 100644
--- a/clang/include/clang/Analysis/Analyses/ThreadSafety.h
+++ b/clang/include/clang/Analysis/Analyses/ThreadSafety.h
@@ -243,10 +243,13 @@ class ThreadSafetyHandler {
   /// \param Kind -- The kind of the expected mutex.
   /// \param Expected -- The name of the expected mutex.
   /// \param Actual -- The name of the actual mutex.
+  /// \param ForParam -- Indicates whether the note applies to a function
+  /// parameter.
   virtual void handleUnmatchedUnderlyingMutexes(SourceLocation Loc,
                                                 SourceLocation DLoc,
                                                 Name ScopeName, StringRef Kind,
-                                                Name Expected, Name Actual) {}
+                                                Name Expected, Name Actual,
+                                                bool ForParam) {}
 
   /// Warn when we get fewer underlying mutexes than expected.
   /// \param Loc -- The location of the call expression.
@@ -254,10 +257,13 @@ class ThreadSafetyHandler {
   /// \param ScopeName -- The name of the scope passed to the function.
   /// \param Kind -- The kind of the expected mutex.
   /// \param Expected -- The name of the expected mutex.
+  /// \param ForParam -- Indicates whether the note applies to a function
+  /// parameter.
   virtual void handleExpectMoreUnderlyingMutexes(SourceLocation Loc,
                                                  SourceLocation DLoc,
                                                  Name ScopeName, StringRef Kind,
-                                                 Name Expected) {}
+                                                 Name Expected, bool ForParam) {
+  }
 
   /// Warn when we get more underlying mutexes than expected.
   /// \param Loc -- The location of the call expression.
@@ -265,11 +271,13 @@ class ThreadSafetyHandler {
   /// \param ScopeName -- The name of the scope passed to the function.
   /// \param Kind -- The kind of the actual mutex.
   /// \param Actual -- The name of the actual mutex.
+  /// \param ForParam -- Indicates whether the note applies to a function
+  /// parameter.
   virtual void handleExpectFewerUnderlyingMutexes(SourceLocation Loc,
                                                   SourceLocation DLoc,
                                                   Name ScopeName,
-                                                  StringRef Kind, Name Actual) {
-  }
+                                                  StringRef Kind, Name Actual,
+                                                  bool ForParam) {}
 
   /// Warn that L1 cannot be acquired before L2.
   virtual void handleLockAcquiredBefore(StringRef Kind, Name L1Name,
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 627bebb31fc8d..82d9cb02960dd 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -4117,8 +4117,8 @@ def warn_expect_more_underlying_mutexes : Warning<
 def warn_expect_fewer_underlying_mutexes : Warning<
   "did not expect %0 '%2' to be managed by '%1'">,
   InGroup<ThreadSafetyAnalysis>, DefaultIgnore;
-def note_managed_mismatch_here_for_param : Note<
-  "see attribute on parameter here">;
+def note_managed_mismatch_here : Note<
+  "see attribute on %select{function|parameter}0 here">;
 
 
 // Thread safety warnings negative capabilities
diff --git a/clang/lib/Analysis/ThreadSafety.cpp b/clang/lib/Analysis/ThreadSafety.cpp
index 6b5b49377fa08..e019ee9073efd 100644
--- a/clang/lib/Analysis/ThreadSafety.cpp
+++ b/clang/lib/Analysis/ThreadSafety.cpp
@@ -1040,6 +1040,7 @@ class ThreadSafetyAnalyzer {
   std::vector<CFGBlockInfo> BlockInfo;
 
   BeforeSet *GlobalBeforeSet;
+  CapExprSet ExpectedReturnedCapabilities;
 
 public:
   ThreadSafetyAnalyzer(ThreadSafetyHandler &H, BeforeSet* Bset)
@@ -2041,15 +2042,16 @@ void BuildLockset::handleCall(const Expr *Exp, const NamedDecl *D,
         if (!a.has_value()) {
           Analyzer->Handler.handleExpectFewerUnderlyingMutexes(
               Exp->getExprLoc(), D->getLocation(), Scope->toString(),
-              b.value().getKind(), b.value().toString());
+              b.value().getKind(), b.value().toString(), true);
         } else if (!b.has_value()) {
           Analyzer->Handler.handleExpectMoreUnderlyingMutexes(
               Exp->getExprLoc(), D->getLocation(), Scope->toString(),
-              a.value().getKind(), a.value().toString());
-        } else if (!a.value().equals(b.value())) {
+              a.value().getKind(), a.value().toString(), true);
+        } else if (!a.value().matches(b.value())) {
           Analyzer->Handler.handleUnmatchedUnderlyingMutexes(
               Exp->getExprLoc(), D->getLocation(), Scope->toString(),
-              a.value().getKind(), a.value().toString(), b.value().toString());
+              a.value().getKind(), a.value().toString(), b.value().toString(),
+              true);
           break;
         }
       }
@@ -2294,6 +2296,25 @@ void BuildLockset::VisitMaterializeTemporaryExpr(
   }
 }
 
+/// Returns true if \p Ty is a class type that is SCOPED_LOCKABLE itself or
+/// has a (direct or indirect) base class that is.
+static bool checkRecordTypeForScopedCapability(QualType Ty) {
+  const auto *RecTy = Ty->getAs<RecordType>();
+  if (RecTy == nullptr)
+    return false;
+
+  const RecordDecl *RD = RecTy->getDecl();
+  if (RD->hasAttr<ScopedLockableAttr>())
+    return true;
+
+  // forallBases() is false when the callback rejects some base, so any base
+  // carrying the attribute makes this return true.
+  const auto *CRD = dyn_cast<CXXRecordDecl>(RD);
+  return CRD && !CRD->forallBases([](const CXXRecordDecl *Base) {
+           return !Base->hasAttr<ScopedLockableAttr>();
+         });
+}
+
 void BuildLockset::VisitReturnStmt(const ReturnStmt *S) {
   if (Analyzer->CurrentFunction == nullptr)
     return;
@@ -2316,6 +2337,49 @@ void BuildLockset::VisitReturnStmt(const ReturnStmt *S) {
         ReturnType->getPointeeType().isConstQualified() ? AK_Read : AK_Written,
         POK_ReturnPointer);
   }
+  // When a scoped capability is returned, check which mutexes it manages.
+  if (!checkRecordTypeForScopedCapability(ReturnType))
+    return;
+
+  if (const auto *EWC = dyn_cast<ExprWithCleanups>(RetVal))
+    RetVal = EWC->getSubExpr();
+  RetVal = RetVal->IgnoreCasts();
+  if (const auto *BTE = dyn_cast<CXXBindTemporaryExpr>(RetVal))
+    RetVal = BTE->getSubExpr();
+  CapabilityExpr Cp;
+  if (auto Object = Analyzer->ConstructedObjects.find(RetVal);
+      Object != Analyzer->ConstructedObjects.end()) {
+    Cp = CapabilityExpr(Object->second, StringRef(), false);
+    Analyzer->ConstructedObjects.erase(Object);
+  }
+  if (!Cp.shouldIgnore()) {
+    const FactEntry *Fact = FSet.findLock(Analyzer->FactMan, Cp);
+    if (const ScopedLockableFactEntry *Scope =
+            dyn_cast_or_null<ScopedLockableFactEntry>(Fact)) {
+      CapExprSet LocksInReturnVal = Scope->getUnderlyingMutexes();
+      for (const auto &[a, b] : zip_longest(
+               Analyzer->ExpectedReturnedCapabilities, LocksInReturnVal)) {
+        if (!a.has_value()) {
+          Analyzer->Handler.handleExpectFewerUnderlyingMutexes(
+              RetVal->getExprLoc(), Analyzer->CurrentFunction->getLocation(),
+              Scope->toString(), b.value().getKind(), b.value().toString(),
+              false);
+        } else if (!b.has_value()) {
+          Analyzer->Handler.handleExpectMoreUnderlyingMutexes(
+              RetVal->getExprLoc(), Analyzer->CurrentFunction->getLocation(),
+              Scope->toString(), a.value().getKind(), a.value().toString(),
+              false);
+          // Keep iterating to report every missing mutex, as handleCall does.
+        } else if (!a.value().matches(b.value())) {
+          Analyzer->Handler.handleUnmatchedUnderlyingMutexes(
+              RetVal->getExprLoc(), Analyzer->CurrentFunction->getLocation(),
+              Scope->toString(), a.value().getKind(), a.value().toString(),
+              b.value().toString(), false);
+          break;
+        }
+      }
+    }
+  }
 }
 
 /// Given two facts merging on a join point, possibly warn and decide whether to
@@ -2480,11 +2544,22 @@ void ThreadSafetyAnalyzer::runAnalysis(AnalysisDeclContext &AC) {
     CapExprSet SharedLocksToAdd;
 
     SourceLocation Loc = D->getLocation();
+    bool ReturnsScopedCapability;
+    if (CurrentFunction)
+      ReturnsScopedCapability = checkRecordTypeForScopedCapability(
+          CurrentFunction->getReturnType().getCanonicalType());
+    else if (auto CurrentMethod = dyn_cast<ObjCMethodDecl>(D))
+      ReturnsScopedCapability = checkRecordTypeForScopedCapability(
+          CurrentMethod->getReturnType().getCanonicalType());
+    else
+      llvm_unreachable("Unknown function kind");
     for (const auto *Attr : D->attrs()) {
       Loc = Attr->getLocation();
       if (const auto *A = dyn_cast<RequiresCapabilityAttr>(Attr)) {
         getMutexIDs(A->isShared() ? SharedLocksToAdd : ExclusiveLocksToAdd, A,
                     nullptr, D);
+        if (ReturnsScopedCapability)
+          getMutexIDs(ExpectedReturnedCapabilities, A, nullptr, D);
       } else if (const auto *A = dyn_cast<ReleaseCapabilityAttr>(Attr)) {
         // UNLOCK_FUNCTION() is used to hide the underlying lock implementation.
         // We must ignore such methods.
@@ -2493,12 +2568,19 @@ void ThreadSafetyAnalyzer::runAnalysis(AnalysisDeclContext &AC) {
         getMutexIDs(A->isShared() ? SharedLocksToAdd : ExclusiveLocksToAdd, A,
                     nullptr, D);
         getMutexIDs(LocksReleased, A, nullptr, D);
+        if (ReturnsScopedCapability)
+          getMutexIDs(ExpectedReturnedCapabilities, A, nullptr, D);
       } else if (const auto *A = dyn_cast<AcquireCapabilityAttr>(Attr)) {
         if (A->args_size() == 0)
           return;
         getMutexIDs(A->isShared() ? SharedLocksAcquired
                                   : ExclusiveLocksAcquired,
                     A, nullptr, D);
+        if (ReturnsScopedCapability)
+          getMutexIDs(ExpectedReturnedCapabilities, A, nullptr, D);
+      } else if (const auto *A = dyn_cast<LocksExcludedAttr>(Attr)) {
+        if (ReturnsScopedCapability)
+          getMutexIDs(ExpectedReturnedCapabilities, A, nullptr, D);
       } else if (isa<ExclusiveTrylockFunctionAttr>(Attr)) {
         // Don't try to check trylock functions for now.
         return;
diff --git a/clang/lib/Sema/AnalysisBasedWarnings.cpp b/clang/lib/Sema/AnalysisBasedWarnings.cpp
index 3d6da4f70f99e..da9151893c262 100644
--- a/clang/lib/Sema/AnalysisBasedWarnings.cpp
+++ b/clang/lib/Sema/AnalysisBasedWarnings.cpp
@@ -1799,11 +1799,11 @@ class ThreadSafetyReporter : public clang::threadSafety::ThreadSafetyHandler {
                : getNotes();
   }
 
-  OptionalNotes makeManagedMismatchNoteForParam(SourceLocation DeclLoc) {
+  OptionalNotes makeManagedMismatchNote(SourceLocation DeclLoc, bool forParam) {
     return DeclLoc.isValid()
                ? getNotes(PartialDiagnosticAt(
-                     DeclLoc,
-                     S.PDiag(diag::note_managed_mismatch_here_for_param)))
+                     DeclLoc, S.PDiag(diag::note_managed_mismatch_here)
+                                  << forParam))
                : getNotes();
   }
 
@@ -1829,34 +1829,35 @@ class ThreadSafetyReporter : public clang::threadSafety::ThreadSafetyHandler {
 
   void handleUnmatchedUnderlyingMutexes(SourceLocation Loc, SourceLocation DLoc,
                                         Name scopeName, StringRef Kind,
-                                        Name expected, Name actual) override {
+                                        Name expected, Name actual,
+                                        bool forParam) override {
     PartialDiagnosticAt Warning(Loc,
                                 S.PDiag(diag::warn_unmatched_underlying_mutexes)
                                     << Kind << scopeName << expected << actual);
     Warnings.emplace_back(std::move(Warning),
-                          makeManagedMismatchNoteForParam(DLoc));
+                          makeManagedMismatchNote(DLoc, forParam));
   }
 
   void handleExpectMoreUnderlyingMutexes(SourceLocation Loc,
                                          SourceLocation DLoc, Name scopeName,
-                                         StringRef Kind,
-                                         Name expected) override {
+                                         StringRef Kind, Name expected,
+                                         bool forParam) override {
     PartialDiagnosticAt Warning(
         Loc, S.PDiag(diag::warn_expect_more_underlying_mutexes)
                  << Kind << scopeName << expected);
     Warnings.emplace_back(std::move(Warning),
-                          makeManagedMismatchNoteForParam(DLoc));
+                          makeManagedMismatchNote(DLoc, forParam));
   }
 
   void handleExpectFewerUnderlyingMutexes(SourceLocation Loc,
                                           SourceLocation DLoc, Name scopeName,
-                                          StringRef Kind,
-                                          Name actual) override {
+                                          StringRef Kind, Name actual,
+                                          bool forParam) override {
     PartialDiagnosticAt Warning(
         Loc, S.PDiag(diag::warn_expect_fewer_underlying_mutexes)
                  << Kind << scopeName << actual);
     Warnings.emplace_back(std::move(Warning),
-                          makeManagedMismatchNoteForParam(DLoc));
+                          makeManagedMismatchNote(DLoc, forParam));
   }
 
   void handleInvalidLockExp(SourceLocation Loc) override {
diff --git a/clang/test/SemaCXX/warn-thread-safety-analysis.cpp b/clang/test/SemaCXX/warn-thread-safety-analysis.cpp
index ac3ca5e0c12a8..6bb786736d18a 100644
--- a/clang/test/SemaCXX/warn-thread-safety-analysis.cpp
+++ b/clang/test/SemaCXX/warn-thread-safety-analysis.cpp
@@ -57,6 +57,27 @@ class SCOPED_LOCKABLE DoubleMutexLock {
   ~DoubleMutexLock() UNLOCK_FUNCTION();
 };
 
+class DeferTraits {};
+struct SharedTraits {};
+struct ExclusiveTraits {};
+
+class SCOPED_LOCKABLE RelockableMutexLock {
+public:
+  RelockableMutexLock(Mutex *mu, DeferTraits) LOCKS_EXCLUDED(mu);
+  RelockableMutexLock(Mutex *mu, SharedTraits) SHARED_LOCK_FUNCTION(mu);
+  RelockableMutexLock(Mutex *mu, ExclusiveTraits) EXCLUSIVE_LOCK_FUNCTION(mu);
+  ~RelockableMutexLock() UNLOCK_FUNCTION();
+
+  void Lock() EXCLUSIVE_LOCK_FUNCTION();
+  void Unlock() UNLOCK_FUNCTION();
+
+  void ReaderLock() SHARED_LOCK_FUNCTION();
+  void ReaderUnlock() UNLOCK_FUNCTION();
+
+  void PromoteShared() UNLOCK_FUNCTION() EXCLUSIVE_LOCK_FUNCTION();
+  void DemoteExclusive() UNLOCK_FUNCTION() SHARED_LOCK_FUNCTION();
+};
+
 // The universal lock, written "*", allows checking to be selectively turned
 // off for a particular piece of code.
 void beginNoWarnOnReads()  SHARED_LOCK_FUNCTION("*");
@@ -2753,8 +2774,6 @@ void Foo::test6() {
 
 namespace RelockableScopedLock {
 
-class DeferTraits {};
-
 class SCOPED_LOCKABLE RelockableExclusiveMutexLock {
 public:
   RelockableExclusiveMutexLock(Mutex *mu) EXCLUSIVE_LOCK_FUNCTION(mu);
@@ -2765,26 +2784,6 @@ class SCOPED_LOCKABLE RelockableExclusiveMutexLock {
   void Unlock() UNLOCK_FUNCTION();
 };
 
-struct SharedTraits {};
-struct ExclusiveTraits {};
-
-class SCOPED_LOCKABLE RelockableMutexLock {
-public:
-  RelockableMutexLock(Mutex *mu, DeferTraits) LOCKS_EXCLUDED(mu);
-  RelockableMutexLock(Mutex *mu, SharedTraits) SHARED_LOCK_FUNCTION(mu);
-  RelockableMutexLock(Mutex *mu, ExclusiveTraits) EXCLUSIVE_LOCK_FUNCTION(mu);
-  ~RelockableMutexLock() UNLOCK_FUNCTION();
-
-  void Lock() EXCLUSIVE_LOCK_FUNCTION();
-  void Unlock() UNLOCK_FUNCTION();
-
-  void ReaderLock() SHARED_LOCK_FUNCTION();
-  void ReaderUnlock() UNLOCK_FUNCTION();
-
-  void PromoteShared() UNLOCK_FUNCTION() EXCLUSIVE_LOCK_FUNCTION();
-  void DemoteExclusive() UNLOCK_FUNCTION() SHARED_LOCK_FUNCTION();
-};
-
 Mutex mu;
 int x GUARDED_BY(mu);
 bool b;
@@ -3566,6 +3565,38 @@ void releaseMemberCall() {
   ReleasableMutexLock lock(&obj.mu);
   releaseMember(obj, lock);
 }
+#ifdef __cpp_guaranteed_copy_elision
+// expected-note@+2{{mutex acquired here}}
+// expected-note@+1{{see attribute on function here}}
+RelockableScope returnUnmatchTest() EXCLUSIVE_LOCK_FUNCTION(mu) {
+  // expected-note@+1{{mutex acquired here}}
+  return RelockableScope(&mu2); // expected-warning{{mutex managed by '<temporary>' is 'mu2' instead of 'mu'}}
+} // expected-warning{{mutex 'mu2' is still held at the end of function}}
+  // expected-warning@-1{{expecting mutex 'mu' to be held at the end of function}}
+
+// expected-note@+2{{mutex acquired here}}
+// expected-note@+1{{see attribute on function here}}
+RelockableScope returnMoreTest() EXCLUSIVE_LOCK_FUNCTION(mu, mu2) {
+  return RelockableScope(&mu); // expected-warning{{mutex 'mu2' not managed by '<temporary>'}}
+} // expected-warning{{expecting mutex 'mu2' to be held at the end of function}}
+
+// expected-note@+1{{see attribute on function here}}
+DoubleMutexLock returnFewerTest() EXCLUSIVE_LOCK_FUNCTION(mu) {
+  // expected-note@+1{{mutex acquired here}}
+  return DoubleMutexLock(&mu, &mu2); // expected-warning{{did not expect mutex 'mu2' to be managed by '<temporary>'}}
+} // expected-warning{{mutex 'mu2' is still held at the end of function}}
+
+// expected-note@+1{{see attribute on function here}}
+RelockableMutexLock lockTest() EXCLUSIVE_LOCK_FUNCTION(mu) {
+  mu.Lock();
+  return RelockableMutexLock(&mu2, DeferTraits{}); // expected-warning{{mutex managed by '<temporary>' is 'mu2' instead of 'mu'}}
+}
+
+// expected-note@+1{{mutex acquired here}}
+RelockableMutexLock lockTest2() EXCLUSIVE_LOCK_FUNCTION(mu) {
+  return RelockableMutexLock(&mu, DeferTraits{});
+} // expected-warning{{expecting mutex 'mu' to be held at the end of function}}
+#endif
 
 } // end namespace PassingScope
 

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to