vsavchenko created this revision.
vsavchenko added reviewers: NoQ, dcoughlin, ASDenysPetrov, xazax.hun, Szelethus.
Herald added subscribers: cfe-commits, martong, Charusso, dkrupp, donat.nagy, 
mikhail.ramalho, a.sidorin, rnkovacs, szepet, baloghadamsoftware.
Herald added a project: clang.
vsavchenko added a parent revision: D82381: [analyzer] Introduce small 
improvements to the solver infra.

In most cases, we try to reason about a symbol based either on the
information we know about that symbol in particular or on its
composite parts.  This is faster and eliminates costly brute force
searches through existing constraints.

However, we do want to support some cases that are widespread enough
and involve reasoning about different existing constraints at once.
These include:

- reasoning about 'a - b' based on what we know about 'b - a'
- reasoning about 'a <= b' based on what we know about 'a > b' or 'a < b'

This commit expands on that part by tracking symbols known to be equal
while still avoiding brute force searches.  It changes the way we track
constraints for individual symbols.  If we know for a fact that 'a == b'
then there is no need to track constraints for both 'a' and 'b', especially
if these constraints are different.  This additional relationship makes
the dead/live logic for constraints harder, as we want to maintain as much
information about the equivalence class as possible, but we still won't
carry information that we don't need anymore.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D82445

Files:
  clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h
  
clang/include/clang/StaticAnalyzer/Core/PathSensitive/RangedConstraintManager.h
  clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp
  clang/lib/StaticAnalyzer/Core/RangeConstraintManager.cpp
  clang/lib/StaticAnalyzer/Core/RangedConstraintManager.cpp
  clang/test/Analysis/equality_tracking.c

Index: clang/test/Analysis/equality_tracking.c
===================================================================
--- /dev/null
+++ clang/test/Analysis/equality_tracking.c
@@ -0,0 +1,132 @@
+// RUN: %clang_analyze_cc1 -verify %s \
+// RUN:   -analyzer-checker=core,debug.ExprInspection \
+// RUN:   -analyzer-config eagerly-assume=false
+
+#define NULL (void *)0
+
+#define UCHAR_MAX (unsigned char)(~0U)
+#define CHAR_MAX (char)(UCHAR_MAX & (UCHAR_MAX >> 1))
+#define CHAR_MIN (char)(UCHAR_MAX & ~(UCHAR_MAX >> 1))
+
+void clang_analyzer_eval(int);
+
+int getInt();
+
+void zeroImpliesEquality(int a, int b) {
+  clang_analyzer_eval((a - b) == 0); // expected-warning{{UNKNOWN}}
+  if ((a - b) == 0) {
+    clang_analyzer_eval(b != a);    // expected-warning{{FALSE}}
+    clang_analyzer_eval(b == a);    // expected-warning{{TRUE}}
+    clang_analyzer_eval(!(a != b)); // expected-warning{{TRUE}}
+    clang_analyzer_eval(!(b == a)); // expected-warning{{FALSE}}
+    return;
+  }
+  clang_analyzer_eval((a - b) == 0); // expected-warning{{FALSE}}
+  // FIXME: we should track disequality information as well
+  clang_analyzer_eval(b == a); // expected-warning{{UNKNOWN}}
+  clang_analyzer_eval(b != a); // expected-warning{{UNKNOWN}}
+}
+
+void zeroImpliesReversedEqual(int a, int b) {
+  clang_analyzer_eval((b - a) == 0); // expected-warning{{UNKNOWN}}
+  if ((b - a) == 0) {
+    clang_analyzer_eval(b != a); // expected-warning{{FALSE}}
+    clang_analyzer_eval(b == a); // expected-warning{{TRUE}}
+    return;
+  }
+  clang_analyzer_eval((b - a) == 0); // expected-warning{{FALSE}}
+  // FIXME: we should track disequality information as well
+  clang_analyzer_eval(b == a); // expected-warning{{UNKNOWN}}
+  clang_analyzer_eval(b != a); // expected-warning{{UNKNOWN}}
+}
+
+void canonicalEqual(int a, int b) {
+  clang_analyzer_eval(a == b); // expected-warning{{UNKNOWN}}
+  if (a == b) {
+    clang_analyzer_eval(b == a); // expected-warning{{TRUE}}
+    return;
+  }
+  clang_analyzer_eval(a == b); // expected-warning{{FALSE}}
+  clang_analyzer_eval(b == a); // expected-warning{{FALSE}}
+}
+
+void test(int a, int b, int c, int d) {
+  if (a == b && c == d) {
+    if (a == 0 && b == d) {
+      clang_analyzer_eval(c == 0); // expected-warning{{TRUE}}
+    }
+    c = 10;
+    if (b == d) {
+      clang_analyzer_eval(c == 10); // expected-warning{{TRUE}}
+      clang_analyzer_eval(d == 10); // expected-warning{{UNKNOWN}}
+                                    // expected-warning@-1{{FALSE}}
+      clang_analyzer_eval(b == a);  // expected-warning{{TRUE}}
+      clang_analyzer_eval(a == d);  // expected-warning{{TRUE}}
+
+      b = getInt();
+      clang_analyzer_eval(a == d); // expected-warning{{TRUE}}
+      clang_analyzer_eval(a == b); // expected-warning{{UNKNOWN}}
+    }
+  }
+
+  if (a != b && b == c) {
+    if (c == 42) {
+      clang_analyzer_eval(b == 42); // expected-warning{{TRUE}}
+      // FIXME: we should track disequality information as well
+      clang_analyzer_eval(a != 42); // expected-warning{{UNKNOWN}}
+    }
+  }
+}
+
+void testIntersection(int a, int b, int c) {
+  if (a < 42 && b > 15 && c >= 25 && c <= 30) {
+    if (a != b)
+      return;
+
+    clang_analyzer_eval(a > 15);  // expected-warning{{TRUE}}
+    clang_analyzer_eval(b < 42);  // expected-warning{{TRUE}}
+    clang_analyzer_eval(a <= 30); // expected-warning{{UNKNOWN}}
+
+    if (c == b) {
+      // For all equal symbols, we should track the minimal common range.
+      //
+      // Also, it should be noted that c is dead at this point, but the
+      // constraint initially associated with c is still around.
+      clang_analyzer_eval(a >= 25 && a <= 30); // expected-warning{{TRUE}}
+      clang_analyzer_eval(b >= 25 && b <= 30); // expected-warning{{TRUE}}
+    }
+  }
+}
+
+void testPromotion(int a, char b) {
+  if (b > 10) {
+    if (a == b) {
+      clang_analyzer_eval(a > 10);        // expected-warning{{TRUE}}
+      clang_analyzer_eval(a <= CHAR_MAX); // expected-warning{{TRUE}}
+    }
+  }
+}
+
+void testPromotionOnlyTypes(int a, char b) {
+  if (a == b) {
+    // FIXME: even when b doesn't have any constraints we still
+    //        should understand that b has a smaller type and assign
+    //        constraints correspondingly
+    clang_analyzer_eval(a <= CHAR_MAX); // expected-warning{{UNKNOWN}}
+  }
+}
+
+void testPointers(int *a, int *b, int *c, int *d) {
+  if (a == b && c == d) {
+    if (a == NULL && b == d) {
+      clang_analyzer_eval(c == NULL); // expected-warning{{TRUE}}
+    }
+  }
+
+  if (a != b && b == c) {
+    if (c == NULL) {
+      // FIXME: we should track disequality information as well
+      clang_analyzer_eval(a != NULL); // expected-warning{{UNKNOWN}}
+    }
+  }
+}
Index: clang/lib/StaticAnalyzer/Core/RangedConstraintManager.cpp
===================================================================
--- clang/lib/StaticAnalyzer/Core/RangedConstraintManager.cpp
+++ clang/lib/StaticAnalyzer/Core/RangedConstraintManager.cpp
@@ -40,19 +40,20 @@
     }
 
   } else if (const SymSymExpr *SSE = dyn_cast<SymSymExpr>(Sym)) {
-    // Translate "a != b" to "(b - a) != 0".
-    // We invert the order of the operands as a heuristic for how loop
-    // conditions are usually written ("begin != end") as compared to length
-    // calculations ("end - begin"). The more correct thing to do would be to
-    // canonicalize "a - b" and "b - a", which would allow us to treat
-    // "a != b" and "b != a" the same.
-    SymbolManager &SymMgr = getSymbolManager();
     BinaryOperator::Opcode Op = SSE->getOpcode();
     assert(BinaryOperator::isComparisonOp(Op));
 
-    // For now, we only support comparing pointers.
+    // We convert equality operations for pointers only.
     if (Loc::isLocType(SSE->getLHS()->getType()) &&
         Loc::isLocType(SSE->getRHS()->getType())) {
+      // Translate "a != b" to "(b - a) != 0".
+      // We invert the order of the operands as a heuristic for how loop
+      // conditions are usually written ("begin != end") as compared to length
+      // calculations ("end - begin"). The more correct thing to do would be to
+      // canonicalize "a - b" and "b - a", which would allow us to treat
+      // "a != b" and "b != a" the same.
+
+      SymbolManager &SymMgr = getSymbolManager();
       QualType DiffTy = SymMgr.getContext().getPointerDiffType();
       SymbolRef Subtraction =
           SymMgr.getSymSymExpr(SSE->getRHS(), BO_Sub, SSE->getLHS(), DiffTy);
@@ -63,6 +64,25 @@
         Op = BinaryOperator::negateComparisonOp(Op);
       return assumeSymRel(State, Subtraction, Op, Zero);
     }
+
+    if (BinaryOperator::isEqualityOp(Op)) {
+      SymbolManager &SymMgr = getSymbolManager();
+
+      QualType ExprType = SSE->getType();
+      SymbolRef CanonicalEquality =
+          SymMgr.getSymSymExpr(SSE->getLHS(), BO_EQ, SSE->getRHS(), ExprType);
+
+      bool WasEqual = SSE->getOpcode() == BO_EQ;
+      bool IsExpectedEqual = WasEqual == Assumption;
+
+      const llvm::APSInt &Zero = getBasicVals().getValue(0, ExprType);
+
+      if (IsExpectedEqual) {
+        return assumeSymNE(State, CanonicalEquality, Zero, Zero);
+      }
+
+      return assumeSymEQ(State, CanonicalEquality, Zero, Zero);
+    }
   }
 
   // If we get here, there's nothing else we can do but treat the symbol as
@@ -199,11 +219,6 @@
   }
 }
 
-void *ProgramStateTrait<ConstraintRange>::GDMIndex() {
-  static int Index;
-  return &Index;
-}
-
 } // end of namespace ento
 
 } // end of namespace clang
Index: clang/lib/StaticAnalyzer/Core/RangeConstraintManager.cpp
===================================================================
--- clang/lib/StaticAnalyzer/Core/RangeConstraintManager.cpp
+++ clang/lib/StaticAnalyzer/Core/RangeConstraintManager.cpp
@@ -391,7 +391,191 @@
   os << " }";
 }
 
+REGISTER_SET_FACTORY_WITH_PROGRAMSTATE(SymbolSet, SymbolRef)
+
+namespace {
+class EquivalenceClass;
+} // end anonymous namespace
+
+REGISTER_MAP_WITH_PROGRAMSTATE(ClassMap, SymbolRef, EquivalenceClass)
+REGISTER_MAP_WITH_PROGRAMSTATE(ClassMembers, EquivalenceClass, SymbolSet)
+REGISTER_MAP_WITH_PROGRAMSTATE(ConstraintRange, EquivalenceClass, RangeSet)
+
 namespace {
+/// This class encapsulates a set of symbols equal to each other.
+///
+/// The main idea of the approach requiring such classes is in narrowing
+/// and sharing constraints between symbols within the class.  Also we can
+/// conclude that there is no practical need in storing constraints for
+/// every member of the class separately.
+///
+/// Main terminology:
+///
+///   * "Equivalence class" is an object of this class, which can be efficiently
+///     compared to other classes.  It represents the whole class without
+///     storing the actual members in it.  The members of the class, however,
+///     can be retrieved from the state.
+///
+///   * "Class members" are the symbols corresponding to the class.  This means
+///     that A == B for every member symbols A and B from the class.  Members of
+///     each class are stored in the state.
+///
+///   * "Trivial class" is a class that has, and has only ever had, one symbol.
+///
+///   * "Merge (or Union) operation" merges two classes into one.  It is the
+///     main operation to produce non-trivial classes.
+///     If, at some point, we can assume that two symbols from two distinct
+///     classes are equal, we can merge these classes.
+class EquivalenceClass : public llvm::FoldingSetNode {
+public:
+  /// Find equivalence class for the given symbol in the given state.
+  static inline EquivalenceClass find(ProgramStateRef State, SymbolRef Sym);
+
+  /// Merge classes for the given symbols and return a new state.
+  static inline ProgramStateRef merge(BasicValueFactory &BV,
+                                      RangeSet::Factory &F,
+                                      ProgramStateRef State, SymbolRef First,
+                                      SymbolRef Second);
+  // Merge this class with the given class and return a new state.
+  inline ProgramStateRef merge(BasicValueFactory &BV, RangeSet::Factory &F,
+                               ProgramStateRef State, EquivalenceClass Other);
+
+  /// Return a set of class members for the given state.
+  inline SymbolSet getClassMembers(ProgramStateRef State);
+  /// Return true if the current class is trivial in the given state.
+  inline bool isTrivial(ProgramStateRef State);
+  /// Return true if the current class is trivial and its only member is dead.
+  inline bool isTriviallyDead(ProgramStateRef State, SymbolReaper &Reaper);
+
+  EquivalenceClass() = delete;
+  EquivalenceClass(const EquivalenceClass &) = default;
+  EquivalenceClass &operator=(const EquivalenceClass &) = default;
+  EquivalenceClass(EquivalenceClass &&) = default;
+  EquivalenceClass &operator=(EquivalenceClass &&) = default;
+
+  bool operator==(const EquivalenceClass &Other) const {
+    return ID == Other.ID;
+  }
+  bool operator<(const EquivalenceClass &Other) const { return ID < Other.ID; }
+  bool operator!=(const EquivalenceClass &Other) const {
+    return !operator==(Other);
+  }
+
+  static void Profile(llvm::FoldingSetNodeID &ID, uintptr_t CID) {
+    ID.AddInteger(CID);
+  }
+
+  void Profile(llvm::FoldingSetNodeID &ID) const { Profile(ID, this->ID); }
+
+private:
+  /* implicit */ EquivalenceClass(SymbolRef Sym)
+      : ID(reinterpret_cast<uintptr_t>(Sym)) {}
+
+  /// This function is intended to be used ONLY within the class.
+  /// The fact that ID is a pointer to a symbol is an implementation detail
+  /// and should stay that way.
+  /// In the current implementation, we use it to retrieve the only member
+  /// of the trivial class.
+  SymbolRef getRepresentativeSymbol() const {
+    return reinterpret_cast<SymbolRef>(ID);
+  }
+  static inline SymbolSet::Factory &getMembersFactory(ProgramStateRef State);
+
+  inline ProgramStateRef mergeImpl(BasicValueFactory &BV, RangeSet::Factory &F,
+                                   ProgramStateRef State, SymbolSet Members,
+                                   EquivalenceClass Other,
+                                   SymbolSet OtherMembers);
+
+  /// This is a unique identifier of the class.
+  uintptr_t ID;
+};
+
+inline bool isZero(const llvm::APSInt &Int) {
+  APSIntType Type(Int);
+  return Int == Type.getZeroValue();
+}
+
+//===----------------------------------------------------------------------===//
+//                             Constraint functions
+//===----------------------------------------------------------------------===//
+
+LLVM_NODISCARD inline ProgramStateRef setConstraint(ProgramStateRef State,
+                                                    EquivalenceClass Class,
+                                                    RangeSet Constraint) {
+  return State->set<ConstraintRange>(Class, Constraint);
+}
+
+LLVM_NODISCARD inline ProgramStateRef
+setConstraint(ProgramStateRef State, SymbolRef Sym, RangeSet Constraint) {
+  return setConstraint(State, EquivalenceClass::find(State, Sym), Constraint);
+}
+
+LLVM_NODISCARD inline const RangeSet *getConstraint(ProgramStateRef State,
+                                                    EquivalenceClass Class) {
+  return State->get<ConstraintRange>(Class);
+}
+
+LLVM_NODISCARD inline const RangeSet *getConstraint(ProgramStateRef State,
+                                                    SymbolRef Sym) {
+  return getConstraint(State, EquivalenceClass::find(State, Sym));
+}
+
+//===----------------------------------------------------------------------===//
+//                               Equality tracker
+//===----------------------------------------------------------------------===//
+
+/// A small helper structure representing symbolic equality.
+///
+/// Equality check can have different forms (like a == b or a - b) and this
+/// class encapsulates those away if the only thing the user wants to check is
+/// whether it's equality/disequality or not, and to have easy access to the
+/// compared symbols.
+struct EqualityInfo {
+public:
+  SymbolRef Left, Right;
+  // true for equality and false for disequality.
+  bool IsEquality = true;
+
+  void invert() { IsEquality = !IsEquality; }
+  /// Extract equality information from the given symbol and the constants.
+  ///
+  /// This function assumes the following expression Sym + Adjustment != Int.
+  /// It is a default because the most widespread case of the equality check
+  /// is (A == B) + 0 != 0.
+  static Optional<EqualityInfo> extract(SymbolRef Sym, const llvm::APSInt &Int,
+                                        const llvm::APSInt &Adjustment) {
+    // As of now, the only equality form supported is Sym + 0 != 0.
+    if (!isZero(Int) || !isZero(Adjustment))
+      return llvm::None;
+
+    return extract(Sym);
+  }
+  /// Extract equality information from the given symbol.
+  static Optional<EqualityInfo> extract(SymbolRef Sym) {
+    return EqualityExtractor().Visit(Sym);
+  }
+
+private:
+  class EqualityExtractor
+      : public SymExprVisitor<EqualityExtractor, Optional<EqualityInfo>> {
+  public:
+    Optional<EqualityInfo> VisitSymSymExpr(const SymSymExpr *Sym) const {
+      switch (Sym->getOpcode()) {
+      case BO_Sub:
+        // This case is: A - B != 0 -> disequality check.
+        return EqualityInfo{Sym->getLHS(), Sym->getRHS(), false};
+      case BO_EQ:
+        // This case is: A == B != 0 -> equality check.
+        return EqualityInfo{Sym->getLHS(), Sym->getRHS(), true};
+      case BO_NE:
+        // This case is: A != B != 0 -> disequality check.
+        return EqualityInfo{Sym->getLHS(), Sym->getRHS(), false};
+      default:
+        return llvm::None;
+      }
+    }
+  };
+};
 
 //===----------------------------------------------------------------------===//
 //                            Intersection functions
@@ -556,15 +740,16 @@
 
   RangeSet infer(SymbolRef Sym) {
     if (Optional<RangeSet> ConstraintBasedRange = intersect(
-            ValueFactory, RangeFactory, State->get<ConstraintRange>(Sym),
+            ValueFactory, RangeFactory, getConstraint(State, Sym),
             // If Sym is a difference of symbols A - B, then maybe we have range
             // set stored for B - A.
             //
             // If we have range set stored for both A - B and B - A then
             // calculate the effective range set by intersecting the range set
             // for A - B and the negated range set of B - A.
-            getRangeForInvertedSub(Sym)))
+            getRangeForInvertedSub(Sym), getRangeForEqualities(Sym))) {
       return *ConstraintBasedRange;
+    }
 
     // If Sym is a comparison expression (except <=>),
     // find any other comparisons with the same operands.
@@ -745,8 +930,7 @@
         SymbolRef NegatedSym =
             SymMgr.getSymSymExpr(SSE->getRHS(), BO_Sub, SSE->getLHS(), T);
 
-        if (const RangeSet *NegatedRange =
-                State->get<ConstraintRange>(NegatedSym)) {
+        if (const RangeSet *NegatedRange = getConstraint(State, NegatedSym)) {
           return NegatedRange->Negate(ValueFactory, RangeFactory);
         }
       }
@@ -792,7 +976,7 @@
       // Let's find an expression e.g. (x < y).
       BinaryOperatorKind QueriedOP = OperatorRelationsTable::getOpFromIndex(i);
       const SymSymExpr *SymSym = SymMgr.getSymSymExpr(LHS, QueriedOP, RHS, T);
-      const RangeSet *QueriedRangeSet = State->get<ConstraintRange>(SymSym);
+      const RangeSet *QueriedRangeSet = getConstraint(State, SymSym);
 
       // If ranges were not previously found,
       // try to find a reversed expression (y > x).
@@ -800,7 +984,7 @@
         const BinaryOperatorKind ROP =
             BinaryOperator::reverseComparisonOp(QueriedOP);
         SymSym = SymMgr.getSymSymExpr(RHS, ROP, LHS, T);
-        QueriedRangeSet = State->get<ConstraintRange>(SymSym);
+        QueriedRangeSet = getConstraint(State, SymSym);
       }
 
       if (!QueriedRangeSet || QueriedRangeSet->isEmpty())
@@ -838,6 +1022,27 @@
     return llvm::None;
   }
 
+  Optional<RangeSet> getRangeForEqualities(SymbolRef Sym) {
+    Optional<EqualityInfo> Equality = EqualityInfo::extract(Sym);
+
+    if (!Equality)
+      return llvm::None;
+
+    EquivalenceClass LHS = EquivalenceClass::find(State, Equality->Left);
+    EquivalenceClass RHS = EquivalenceClass::find(State, Equality->Right);
+
+    if (LHS != RHS)
+      // Can't really say anything at this point.
+      // We can add more logic here if we track disequalities as well.
+      return llvm::None;
+
+    // At this point, operands of the equality operation are known to be equal.
+    if (Equality->IsEquality) {
+      return getTrueRange(Sym->getType());
+    }
+    return getFalseRange(Sym->getType());
+  }
+
   RangeSet getTrueRange(QualType T) {
     RangeSet TypeRange = infer(T);
     return assumeNonZero(TypeRange, T);
@@ -1032,7 +1237,11 @@
 
   bool haveEqualConstraints(ProgramStateRef S1,
                             ProgramStateRef S2) const override {
-    return S1->get<ConstraintRange>() == S2->get<ConstraintRange>();
+    // NOTE: ClassMembers are as simple as back pointers for ClassMap,
+    //       so comparing constraint ranges and class maps should be
+    //       sufficient.
+    return S1->get<ConstraintRange>() == S2->get<ConstraintRange>() &&
+           S1->get<ClassMap>() == S2->get<ClassMap>();
   }
 
   bool canReasonAbout(SVal X) const override;
@@ -1104,6 +1313,49 @@
   RangeSet getSymGERange(ProgramStateRef St, SymbolRef Sym,
                          const llvm::APSInt &Int,
                          const llvm::APSInt &Adjustment);
+
+  //===------------------------------------------------------------------===//
+  // Equality tracking implementation
+  //===------------------------------------------------------------------===//
+
+  ProgramStateRef trackEQ(ProgramStateRef State, SymbolRef Sym,
+                          const llvm::APSInt &Int,
+                          const llvm::APSInt &Adjustment) {
+    if (auto Equality = EqualityInfo::extract(Sym, Int, Adjustment)) {
+      // Extract function assumes that we gave it Sym + Adjustment != Int,
+      // so the result should be opposite.
+      Equality->invert();
+      return track(State, *Equality);
+    }
+
+    return State;
+  }
+
+  ProgramStateRef trackNE(ProgramStateRef State, SymbolRef Sym,
+                          const llvm::APSInt &Int,
+                          const llvm::APSInt &Adjustment) {
+    if (auto Equality = EqualityInfo::extract(Sym, Int, Adjustment)) {
+      return track(State, *Equality);
+    }
+
+    return State;
+  }
+
+  ProgramStateRef track(ProgramStateRef State, EqualityInfo ToTrack) {
+    if (ToTrack.IsEquality) {
+      return trackEquality(State, ToTrack.Left, ToTrack.Right);
+    }
+    return trackDisequality(State, ToTrack.Left, ToTrack.Right);
+  }
+
+  ProgramStateRef trackDisequality(ProgramStateRef State, SymbolRef LHS,
+                                   SymbolRef RHS) {
+    // TODO: track inequalities
+    return State;
+  }
+
+  ProgramStateRef trackEquality(ProgramStateRef State, SymbolRef LHS,
+                                SymbolRef RHS);
 };
 
 } // end anonymous namespace
@@ -1114,6 +1366,153 @@
   return std::make_unique<RangeConstraintManager>(Eng, StMgr.getSValBuilder());
 }
 
+ConstraintMap ento::getConstraintMap(ProgramStateRef State) {
+  ConstraintMap::Factory &F = State->get_context<ConstraintMap>();
+  ConstraintMap Result = F.getEmptyMap();
+
+  ConstraintRangeTy Constraints = State->get<ConstraintRange>();
+  for (std::pair<EquivalenceClass, RangeSet> ClassConstraint : Constraints) {
+    EquivalenceClass Class = ClassConstraint.first;
+    SymbolSet ClassMembers = Class.getClassMembers(State);
+    assert(!ClassMembers.isEmpty() &&
+           "Class must always have at least one member!");
+
+    SymbolRef Representative = *ClassMembers.begin();
+    Result = F.add(Result, Representative, ClassConstraint.second);
+  }
+
+  return Result;
+}
+
+//===----------------------------------------------------------------------===//
+//                   EquivalenceClass implementation details
+//===----------------------------------------------------------------------===//
+
+inline EquivalenceClass EquivalenceClass::find(ProgramStateRef State,
+                                               SymbolRef Sym) {
+  if (const EquivalenceClass *NontrivialClass = State->get<ClassMap>(Sym))
+    return *NontrivialClass;
+
+  // This is a trivial class of Sym.
+  return Sym;
+}
+
+inline ProgramStateRef EquivalenceClass::merge(BasicValueFactory &BV,
+                                               RangeSet::Factory &F,
+                                               ProgramStateRef State,
+                                               SymbolRef First,
+                                               SymbolRef Second) {
+  EquivalenceClass FirstClass = find(State, First);
+  EquivalenceClass SecondClass = find(State, Second);
+
+  return FirstClass.merge(BV, F, State, SecondClass);
+}
+
+inline ProgramStateRef EquivalenceClass::merge(BasicValueFactory &BV,
+                                               RangeSet::Factory &F,
+                                               ProgramStateRef State,
+                                               EquivalenceClass Other) {
+  // It is already the same class.
+  if (*this == Other)
+    return State;
+
+  SymbolSet Members = getClassMembers(State);
+  SymbolSet OtherMembers = Other.getClassMembers(State);
+
+  // We estimate the size of the class by the height of tree containing
+  // its members.  Merging is not a trivial operation, so it's easier to
+  // merge the smaller class into the bigger one.
+  if (Members.getHeight() >= OtherMembers.getHeight()) {
+    return mergeImpl(BV, F, State, Members, Other, OtherMembers);
+  } else {
+    return Other.mergeImpl(BV, F, State, OtherMembers, *this, Members);
+  }
+}
+
+inline ProgramStateRef
+EquivalenceClass::mergeImpl(BasicValueFactory &ValueFactory,
+                            RangeSet::Factory &RangeFactory,
+                            ProgramStateRef State, SymbolSet MyMembers,
+                            EquivalenceClass Other, SymbolSet OtherMembers) {
+  // 1. Get ALL constraint- and equivalence-related maps
+  ClassMapTy Classes = State->get<ClassMap>();
+  ClassMapTy::Factory &CF = State->get_context<ClassMap>();
+
+  ClassMembersTy Members = State->get<ClassMembers>();
+  ClassMembersTy::Factory &MF = State->get_context<ClassMembers>();
+
+  ConstraintRangeTy Constraints = State->get<ConstraintRange>();
+  ConstraintRangeTy::Factory &CRF = State->get_context<ConstraintRange>();
+
+  SymbolSet::Factory &F = getMembersFactory(State);
+
+  // 2. Merge members of the Other class into the current class.
+  SymbolSet NewClassMembers = MyMembers;
+  for (SymbolRef Sym : OtherMembers) {
+    NewClassMembers = F.add(NewClassMembers, Sym);
+    // *this is now the class for all these new symbols.
+    Classes = CF.add(Classes, Sym, *this);
+  }
+
+  // 3. Adjust member mapping.
+  //
+  // No need to track members of a now-dissolved class.
+  Members = MF.remove(Members, Other);
+  // Now only the current class is mapped to all the symbols.
+  Members = MF.add(Members, *this, NewClassMembers);
+
+  // 4. Update the state
+  State = State->set<ClassMap>(Classes);
+  State = State->set<ClassMembers>(Members);
+
+  // 5. If the merged classes have any constraints associated with them, we
+  //    need to transfer them to the class we have left.
+  //
+  // Intersection here makes perfect sense because both of these constraints
+  // must hold for the whole new class.
+  if (Optional<RangeSet> NewClassConstraint =
+          intersect(ValueFactory, RangeFactory, getConstraint(State, *this),
+                    getConstraint(State, Other))) {
+    // NOTE: Essentially, NewClassConstraint should NEVER be infeasible because
+    //       we shouldn't make assumptions that can lead to that.
+    //       However, at the moment, due to imperfections in the solver, it is
+    //       possible.
+    //
+    // No need to track constraints of a now-dissolved class.
+    Constraints = CRF.remove(Constraints, Other);
+    // Assign new constraints for this class.
+    Constraints = CRF.add(Constraints, *this, *NewClassConstraint);
+
+    State = State->set<ConstraintRange>(Constraints);
+  }
+
+  return State;
+}
+
+inline SymbolSet::Factory &
+EquivalenceClass::getMembersFactory(ProgramStateRef State) {
+  return State->get_context<SymbolSet>();
+}
+
+SymbolSet EquivalenceClass::getClassMembers(ProgramStateRef State) {
+  if (const SymbolSet *Members = State->get<ClassMembers>(*this))
+    return *Members;
+
+  // This class is trivial, so we need to construct a set
+  // with just that one symbol from the class.
+  SymbolSet::Factory &F = getMembersFactory(State);
+  return F.add(F.getEmptySet(), getRepresentativeSymbol());
+}
+
+bool EquivalenceClass::isTrivial(ProgramStateRef State) {
+  return State->get<ClassMembers>(*this) == nullptr;
+}
+
+bool EquivalenceClass::isTriviallyDead(ProgramStateRef State,
+                                       SymbolReaper &Reaper) {
+  return isTrivial(State) && Reaper.isDead(getRepresentativeSymbol());
+}
+
 //===----------------------------------------------------------------------===//
 //                    RangeConstraintManager implementation
 //===----------------------------------------------------------------------===//
@@ -1166,7 +1565,7 @@
 
 ConditionTruthVal RangeConstraintManager::checkNull(ProgramStateRef State,
                                                     SymbolRef Sym) {
-  const RangeSet *Ranges = State->get<ConstraintRange>(Sym);
+  const RangeSet *Ranges = getConstraint(State, Sym);
 
   // If we don't have any information about this symbol, it's underconstrained.
   if (!Ranges)
@@ -1190,7 +1589,7 @@
 
 const llvm::APSInt *RangeConstraintManager::getSymVal(ProgramStateRef St,
                                                       SymbolRef Sym) const {
-  const ConstraintRangeTy::data_type *T = St->get<ConstraintRange>(Sym);
+  const RangeSet *T = getConstraint(St, Sym);
   return T ? T->getConcreteValue() : nullptr;
 }
 
@@ -1203,19 +1602,94 @@
 ProgramStateRef
 RangeConstraintManager::removeDeadBindings(ProgramStateRef State,
                                            SymbolReaper &SymReaper) {
-  bool Changed = false;
-  ConstraintRangeTy CR = State->get<ConstraintRange>();
-  ConstraintRangeTy::Factory &CRFactory = State->get_context<ConstraintRange>();
+  ClassMembersTy ClassMembersMap = State->get<ClassMembers>();
+  ClassMembersTy NewClassMembersMap = ClassMembersMap;
+  ClassMembersTy::Factory &EMFactory = State->get_context<ClassMembers>();
+  SymbolSet::Factory &SetFactory = State->get_context<SymbolSet>();
+
+  ConstraintRangeTy Constraints = State->get<ConstraintRange>();
+  ConstraintRangeTy NewConstraints = Constraints;
+  ConstraintRangeTy::Factory &ConstraintFactory =
+      State->get_context<ConstraintRange>();
+
+  ClassMapTy Map = State->get<ClassMap>();
+  ClassMapTy NewMap = Map;
+  ClassMapTy::Factory &ClassFactory = State->get_context<ClassMap>();
+
+  bool ClassMapChanged = false;
+  bool MembersMapChanged = false;
+  bool ConstraintMapChanged = false;
+
+  // 1. Let's see if dead symbols are trivial and have associated constraints.
+  for (std::pair<EquivalenceClass, RangeSet> ClassConstraintPair :
+       Constraints) {
+    EquivalenceClass Class = ClassConstraintPair.first;
+    if (Class.isTriviallyDead(State, SymReaper)) {
+      // Trivial and dead: drop its constraints from the map being built.
+      NewConstraints = ConstraintFactory.remove(NewConstraints, Class);
+      ConstraintMapChanged = true;
+    }
+  }
+
+  // 2. We don't need to track classes for dead symbols.
+  for (std::pair<SymbolRef, EquivalenceClass> SymbolClassPair : Map) {
+    SymbolRef Sym = SymbolClassPair.first;
 
-  for (ConstraintRangeTy::iterator I = CR.begin(), E = CR.end(); I != E; ++I) {
-    SymbolRef Sym = I.getKey();
     if (SymReaper.isDead(Sym)) {
-      Changed = true;
-      CR = CRFactory.remove(CR, Sym);
+      ClassMapChanged = true;
+      NewMap = ClassFactory.remove(NewMap, Sym);
+    }
+  }
+
+  // 3. Remove dead members from classes and remove dead non-trivial classes
+  //    and their constraints.
+  for (std::pair<EquivalenceClass, SymbolSet> ClassMembersPair :
+       ClassMembersMap) {
+    SymbolSet LiveMembers = ClassMembersPair.second;
+    bool MembersChanged = false;
+
+    for (SymbolRef Member : ClassMembersPair.second) {
+      if (SymReaper.isDead(Member)) {
+        MembersChanged = true;
+        LiveMembers = SetFactory.remove(LiveMembers, Member);
+      }
+    }
+
+    // Check if the class changed.
+    if (!MembersChanged)
+      continue;
+
+    MembersMapChanged = true;
+
+    if (LiveMembers.isEmpty()) {
+      // The class is dead now, we need to wipe it out of the members map...
+      NewClassMembersMap =
+          EMFactory.remove(NewClassMembersMap, ClassMembersPair.first);
+
+      // ...and remove all of its constraints.
+      NewConstraints =
+          ConstraintFactory.remove(NewConstraints, ClassMembersPair.first);
+      ConstraintMapChanged = true;
+    } else {
+      // We need to change the members associated with the class.
+      NewClassMembersMap = EMFactory.add(NewClassMembersMap,
+                                         ClassMembersPair.first, LiveMembers);
     }
   }
 
-  return Changed ? State->set<ConstraintRange>(CR) : State;
+  // 4. Update the state with new maps.
+  //
+  // Here we try to be humble and update a map only if it really changed.
+  if (ClassMapChanged)
+    State = State->set<ClassMap>(NewMap);
+
+  if (MembersMapChanged)
+    State = State->set<ClassMembers>(NewClassMembersMap);
+
+  if (ConstraintMapChanged)
+    State = State->set<ConstraintRange>(NewConstraints);
+
+  return State;
 }
 
 RangeSet RangeConstraintManager::getRange(ProgramStateRef State,
@@ -1247,7 +1721,13 @@
   llvm::APSInt Point = AdjustmentType.convert(Int) - Adjustment;
 
   RangeSet New = getRange(St, Sym).Delete(getBasicVals(), F, Point);
-  return New.isEmpty() ? nullptr : St->set<ConstraintRange>(Sym, New);
+
+  if (New.isEmpty())
+    // This is an infeasible assumption.
+    return nullptr;
+
+  ProgramStateRef NewState = setConstraint(St, Sym, New);
+  return trackNE(NewState, Sym, Int, Adjustment);
 }
 
 ProgramStateRef
@@ -1262,7 +1742,13 @@
   // [Int-Adjustment, Int-Adjustment]
   llvm::APSInt AdjInt = AdjustmentType.convert(Int) - Adjustment;
   RangeSet New = getRange(St, Sym).Intersect(getBasicVals(), F, AdjInt, AdjInt);
-  return New.isEmpty() ? nullptr : St->set<ConstraintRange>(Sym, New);
+
+  if (New.isEmpty())
+    // This is an infeasible assumption.
+    return nullptr;
+
+  ProgramStateRef NewState = setConstraint(St, Sym, New);
+  return trackEQ(NewState, Sym, Int, Adjustment);
 }
 
 RangeSet RangeConstraintManager::getSymLTRange(ProgramStateRef St,
@@ -1298,7 +1784,7 @@
                                     const llvm::APSInt &Int,
                                     const llvm::APSInt &Adjustment) {
   RangeSet New = getSymLTRange(St, Sym, Int, Adjustment);
-  return New.isEmpty() ? nullptr : St->set<ConstraintRange>(Sym, New);
+  return New.isEmpty() ? nullptr : setConstraint(St, Sym, New);
 }
 
 RangeSet RangeConstraintManager::getSymGTRange(ProgramStateRef St,
@@ -1334,7 +1820,7 @@
                                     const llvm::APSInt &Int,
                                     const llvm::APSInt &Adjustment) {
   RangeSet New = getSymGTRange(St, Sym, Int, Adjustment);
-  return New.isEmpty() ? nullptr : St->set<ConstraintRange>(Sym, New);
+  return New.isEmpty() ? nullptr : setConstraint(St, Sym, New);
 }
 
 RangeSet RangeConstraintManager::getSymGERange(ProgramStateRef St,
@@ -1370,13 +1856,13 @@
                                     const llvm::APSInt &Int,
                                     const llvm::APSInt &Adjustment) {
   RangeSet New = getSymGERange(St, Sym, Int, Adjustment);
-  return New.isEmpty() ? nullptr : St->set<ConstraintRange>(Sym, New);
+  return New.isEmpty() ? nullptr : setConstraint(St, Sym, New);
 }
 
-RangeSet RangeConstraintManager::getSymLERange(
-      llvm::function_ref<RangeSet()> RS,
-      const llvm::APSInt &Int,
-      const llvm::APSInt &Adjustment) {
+RangeSet
+RangeConstraintManager::getSymLERange(llvm::function_ref<RangeSet()> RS,
+                                      const llvm::APSInt &Int,
+                                      const llvm::APSInt &Adjustment) {
   // Before we do any real work, see if the value can even show up.
   APSIntType AdjustmentType(Adjustment);
   switch (AdjustmentType.testInRange(Int, true)) {
@@ -1413,7 +1899,7 @@
                                     const llvm::APSInt &Int,
                                     const llvm::APSInt &Adjustment) {
   RangeSet New = getSymLERange(St, Sym, Int, Adjustment);
-  return New.isEmpty() ? nullptr : St->set<ConstraintRange>(Sym, New);
+  return New.isEmpty() ? nullptr : setConstraint(St, Sym, New);
 }
 
 ProgramStateRef RangeConstraintManager::assumeSymWithinInclusiveRange(
@@ -1423,7 +1909,7 @@
   if (New.isEmpty())
     return nullptr;
   RangeSet Out = getSymLERange([&] { return New; }, To, Adjustment);
-  return Out.isEmpty() ? nullptr : State->set<ConstraintRange>(Sym, Out);
+  return Out.isEmpty() ? nullptr : setConstraint(State, Sym, Out);
 }
 
 ProgramStateRef RangeConstraintManager::assumeSymOutsideInclusiveRange(
@@ -1432,7 +1918,13 @@
   RangeSet RangeLT = getSymLTRange(State, Sym, From, Adjustment);
   RangeSet RangeGT = getSymGTRange(State, Sym, To, Adjustment);
   RangeSet New(RangeLT.addRange(F, RangeGT));
-  return New.isEmpty() ? nullptr : State->set<ConstraintRange>(Sym, New);
+  return New.isEmpty() ? nullptr : setConstraint(State, Sym, New);
+}
+
+ProgramStateRef RangeConstraintManager::trackEquality(ProgramStateRef State,
+                                                      SymbolRef LHS,
+                                                      SymbolRef RHS) {
+  return EquivalenceClass::merge(getBasicVals(), F, State, LHS, RHS);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1452,17 +1944,25 @@
 
   ++Space;
   Out << '[' << NL;
-  for (ConstraintRangeTy::iterator I = Constraints.begin();
-       I != Constraints.end(); ++I) {
-    Indent(Out, Space, IsDot)
-        << "{ \"symbol\": \"" << I.getKey() << "\", \"range\": \"";
-    I.getData().print(Out);
-    Out << "\" }";
-
-    if (std::next(I) != Constraints.end())
-      Out << ',';
-    Out << NL;
+  bool First = true;
+  for (std::pair<EquivalenceClass, RangeSet> P : Constraints) {
+    SymbolSet ClassMembers = P.first.getClassMembers(State);
+
+    // We can print the same constraint for every class member.
+    for (SymbolRef ClassMember : ClassMembers) {
+      if (First) {
+        First = false;
+      } else {
+        Out << ',';
+        Out << NL;
+      }
+      Indent(Out, Space, IsDot)
+          << "{ \"symbol\": \"" << ClassMember << "\", \"range\": \"";
+      P.second.print(Out);
+      Out << "\" }";
+    }
   }
+  Out << NL;
 
   --Space;
   Indent(Out, Space, IsDot) << "]," << NL;
Index: clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp
===================================================================
--- clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp
+++ clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp
@@ -2813,7 +2813,7 @@
 //===----------------------------------------------------------------------===//
 
 FalsePositiveRefutationBRVisitor::FalsePositiveRefutationBRVisitor()
-    : Constraints(ConstraintRangeTy::Factory().getEmptyMap()) {}
+    : Constraints(ConstraintMap::Factory().getEmptyMap()) {}
 
 void FalsePositiveRefutationBRVisitor::finalizeVisitor(
     BugReporterContext &BRC, const ExplodedNode *EndPathNode,
@@ -2855,9 +2855,8 @@
 PathDiagnosticPieceRef FalsePositiveRefutationBRVisitor::VisitNode(
     const ExplodedNode *N, BugReporterContext &, PathSensitiveBugReport &) {
   // Collect new constraints
-  const ConstraintRangeTy &NewCs = N->getState()->get<ConstraintRange>();
-  ConstraintRangeTy::Factory &CF =
-      N->getState()->get_context<ConstraintRange>();
+  ConstraintMap NewCs = getConstraintMap(N->getState());
+  ConstraintMap::Factory &CF = N->getState()->get_context<ConstraintMap>();
 
   // Add constraints if we don't have them yet
   for (auto const &C : NewCs) {
Index: clang/include/clang/StaticAnalyzer/Core/PathSensitive/RangedConstraintManager.h
===================================================================
--- clang/include/clang/StaticAnalyzer/Core/PathSensitive/RangedConstraintManager.h
+++ clang/include/clang/StaticAnalyzer/Core/PathSensitive/RangedConstraintManager.h
@@ -136,14 +136,8 @@
   }
 };
 
-class ConstraintRange {};
-using ConstraintRangeTy = llvm::ImmutableMap<SymbolRef, RangeSet>;
-
-template <>
-struct ProgramStateTrait<ConstraintRange>
-    : public ProgramStatePartialTrait<ConstraintRangeTy> {
-  static void *GDMIndex();
-};
+using ConstraintMap = llvm::ImmutableMap<SymbolRef, RangeSet>;
+ConstraintMap getConstraintMap(ProgramStateRef State);
 
 class RangedConstraintManager : public SimpleConstraintManager {
 public:
@@ -222,4 +216,6 @@
 } // namespace ento
 } // namespace clang
 
+REGISTER_FACTORY_WITH_PROGRAMSTATE(ConstraintMap);
+
 #endif
Index: clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h
===================================================================
--- clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h
+++ clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h
@@ -373,7 +373,7 @@
 class FalsePositiveRefutationBRVisitor final : public BugReporterVisitor {
 private:
   /// Holds the constraints in a given path
-  ConstraintRangeTy Constraints;
+  ConstraintMap Constraints;
 
 public:
   FalsePositiveRefutationBRVisitor();
@@ -388,7 +388,6 @@
                        PathSensitiveBugReport &BR) override;
 };
 
-
 /// The visitor detects NoteTags and displays the event notes they contain.
 class TagVisitor : public BugReporterVisitor {
 public:
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to