================
@@ -3068,11 +3111,174 @@ void MallocChecker::checkDeadSymbols(SymbolReaper &SymReaper,
   C.addTransition(state->set<RegionState>(RS), N);
 }
 
+// Note: std-namespace checks should use isWithinStdNamespace() from
+// CheckerHelpers.h rather than a custom reimplementation.
+
+// Allowlist of owning smart pointer class names we want to recognize. Start
+// with unique_ptr and shared_ptr (weak_ptr is intentionally excluded because
+// it is not an owning pointer).
+static bool isSmartOwningPtrName(StringRef Name) {
+  return Name == "unique_ptr" || Name == "shared_ptr";
+}
+
+// Returns true if QT is an owning smart pointer type: a class, or a class
+// template specialization, whose name is in the allowlist above. Both std
+// smart pointers and custom implementations with the same names are accepted.
+//
+// QT is canonicalized first. A concrete specialization such as
+// std::unique_ptr<int> canonicalizes to a RecordType, so the
+// TemplateSpecializationType path below is in practice only reached for
+// dependent specializations. Both paths therefore apply the same name-based
+// allowlist; restricting only one of them to namespace std would make the
+// result depend on whether the type happens to be dependent.
+static bool isSmartOwningPtrType(QualType QT) {
+  QT = QT->getCanonicalTypeUnqualified();
+
+  // Dependent specializations keep their TemplateSpecializationType form.
+  if (const auto *TST = QT->getAs<TemplateSpecializationType>()) {
+    const TemplateDecl *TD = TST->getTemplateName().getAsTemplateDecl();
+    if (!TD)
+      return false;
+
+    // getTemplatedDecl() may be null (e.g. for builtin templates).
+    const auto *ND = dyn_cast_or_null<NamedDecl>(TD->getTemplatedDecl());
+    return ND && isSmartOwningPtrName(ND->getName());
+  }
+
+  // Concrete std or custom smart pointer classes.
+  if (const auto *RD = QT->getAsCXXRecordDecl())
+    return isSmartOwningPtrName(RD->getName());
+
+  return false;
+}
+
+// Returns true if at least one direct (non-inherited) field declared in CRD
+// has a type recognized by isSmartOwningPtrType().
+static bool hasSmartPtrField(const CXXRecordDecl *CRD) {
+  for (const FieldDecl *Field : CRD->fields()) {
+    if (isSmartOwningPtrType(Field->getType()))
+      return true;
+  }
+  return false;
+}
+
+// Returns true if AE is a prvalue of record (class/struct/union) type produced
+// by one of the common temporary/construction forms, i.e. an object handed
+// over by value rather than a reference to an existing object.
+static bool isRvalueByValueRecord(const Expr *AE) {
+  // glvalues designate an existing object. This also rejects
+  // MaterializeTemporaryExpr, which is always an lvalue or xvalue.
+  if (AE->isGLValue())
+    return false;
+
+  // A record type is never a reference type, so no separate reference check
+  // is needed here.
+  if (!AE->getType()->isRecordType())
+    return false;
+
+  // Accept common temp/construct forms but don't overfit.
+  // CXXConstructExpr also matches its subclass CXXTemporaryObjectExpr.
+  return isa<CXXConstructExpr, InitListExpr, ImplicitCastExpr,
+             CXXBindTemporaryExpr>(AE);
+}
+
+// Returns true if AE is a by-value record prvalue (per isRvalueByValueRecord)
+// whose C++ class declares at least one direct owning smart pointer field.
+static bool isRvalueByValueRecordWithSmartPtr(const Expr *AE) {
+  if (isRvalueByValueRecord(AE)) {
+    if (const CXXRecordDecl *Record = AE->getType()->getAsCXXRecordDecl())
+      return hasSmartPtrField(Record);
+  }
+  return false;
+}
+
+// Returns a copy of State in which every RegionState entry that is Allocated
+// or AllocatedOfSizeZero has been replaced by an Escaped entry (keeping the
+// original entry's data via RefState::getEscaped). Entries in any other state
+// are left as they are; if no entry qualifies, State itself is returned.
+static ProgramStateRef escapeAllAllocatedSymbols(ProgramStateRef State) {
+  const RegionStateTy Tracked = State->get<RegionState>();
+  ProgramStateRef Result = State;
+  for (const auto &Entry : Tracked) {
+    const RefState &Ref = Entry.second;
+    if (!Ref.isAllocated() && !Ref.isAllocatedOfSizeZero())
+      continue;
+    Result = Result->set<RegionState>(Entry.first, RefState::getEscaped(&Ref));
+  }
+  return Result;
+}
+
+// Appends to Out the region of each direct field of record type RecQT, laid
+// out at Base, whose type is recognized by isSmartOwningPtrType(). Does nothing
+// when Base is null or RecQT is not a C++ record type. Inherited fields are not
+// visited, and fields whose lvalue has no region are skipped.
+static void collectDirectSmartOwningPtrFieldRegions(
+    const MemRegion *Base, QualType RecQT, CheckerContext &C,
+    SmallVectorImpl<const MemRegion *> &Out) {
+  const CXXRecordDecl *Record = Base ? RecQT->getAsCXXRecordDecl() : nullptr;
+  if (!Record)
+    return;
+
+  // The state does not change while we only compute lvalues, so fetch it once.
+  ProgramStateRef State = C.getState();
+  const loc::MemRegionVal BaseLoc(Base);
+  for (const FieldDecl *Field : Record->fields()) {
+    if (!isSmartOwningPtrType(Field->getType()))
+      continue;
+    SVal FieldLoc = State->getLValue(Field, BaseLoc);
+    if (const MemRegion *FieldRegion = FieldLoc.getAsRegion())
+      Out.push_back(FieldRegion);
+  }
+}
+
 void MallocChecker::checkPostCall(const CallEvent &Call,
                                   CheckerContext &C) const {
+  // Keep existing post-call handlers.
   if (const auto *PostFN = PostFnMap.lookup(Call)) {
     (*PostFN)(this, C.getState(), Call, C);
-    return;
+  }
+
+  SmallVector<const MemRegion *, 8> SmartPtrFieldRoots;
+
+  for (unsigned I = 0, E = Call.getNumArgs(); I != E; ++I) {
+    const Expr *AE = Call.getArgExpr(I);
+    if (!AE)
+      continue;
+    AE = AE->IgnoreParenImpCasts();
+
+    if (!isRvalueByValueRecordWithSmartPtr(AE))
+      continue;
+
+    // Find a region for the argument.
+    SVal VCall = Call.getArgSVal(I);
+    SVal VExpr = C.getSVal(AE);
+    const MemRegion *RCall = VCall.getAsRegion();
+    const MemRegion *RExpr = VExpr.getAsRegion();
+
+    const MemRegion *Base = RCall ? RCall : RExpr;
+    if (!Base) {
+      // Fallback: if we have a by-value record with smart pointer fields but 
no
+      // region, mark all allocated symbols as escaped
+      ProgramStateRef State = C.getState();
+      ProgramStateRef NewState = escapeAllAllocatedSymbols(State);
+      if (NewState != State)
+        C.addTransition(NewState);
+      continue;
+    }
----------------
ivanmurashko wrote:

I refactored that part of the code and addressed the comment. In particular, I
removed the comparison and now simply set `needsStateUpdate = true` whenever we
know a state change is needed.

https://github.com/llvm/llvm-project/pull/152751
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to