This revision was automatically updated to reflect the committed changes.
Closed by commit rGb13d9878b8dc: [analyzer][solver] Track symbol equivalence 
(authored by vsavchenko).

Changed prior to commit:
  https://reviews.llvm.org/D82445?vs=275981&id=279081#toc

Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D82445/new/

https://reviews.llvm.org/D82445

Files:
  clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h
  clang/include/clang/StaticAnalyzer/Core/PathSensitive/RangedConstraintManager.h
  clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp
  clang/lib/StaticAnalyzer/Core/RangeConstraintManager.cpp
  clang/lib/StaticAnalyzer/Core/RangedConstraintManager.cpp
  clang/test/Analysis/equality_tracking.c

Index: clang/test/Analysis/equality_tracking.c
===================================================================
--- /dev/null
+++ clang/test/Analysis/equality_tracking.c
@@ -0,0 +1,167 @@
+// RUN: %clang_analyze_cc1 -verify %s \
+// RUN:   -analyzer-checker=core,debug.ExprInspection \
+// RUN:   -analyzer-config eagerly-assume=false
+
+#define NULL (void *)0
+
+#define UCHAR_MAX (unsigned char)(~0U)
+#define CHAR_MAX (char)(UCHAR_MAX & (UCHAR_MAX >> 1))
+#define CHAR_MIN (char)(UCHAR_MAX & ~(UCHAR_MAX >> 1))
+
+void clang_analyzer_eval(int);
+void clang_analyzer_warnIfReached();
+
+int getInt();
+
+void zeroImpliesEquality(int a, int b) {
+  clang_analyzer_eval((a - b) == 0); // expected-warning{{UNKNOWN}}
+  if ((a - b) == 0) {
+    clang_analyzer_eval(b != a);    // expected-warning{{FALSE}}
+    clang_analyzer_eval(b == a);    // expected-warning{{TRUE}}
+    clang_analyzer_eval(!(a != b)); // expected-warning{{TRUE}}
+    clang_analyzer_eval(!(b == a)); // expected-warning{{FALSE}}
+    return;
+  }
+  clang_analyzer_eval((a - b) == 0); // expected-warning{{FALSE}}
+  // FIXME: we should track disequality information as well
+  clang_analyzer_eval(b == a); // expected-warning{{UNKNOWN}}
+  clang_analyzer_eval(b != a); // expected-warning{{UNKNOWN}}
+}
+
+void zeroImpliesReversedEqual(int a, int b) {
+  clang_analyzer_eval((b - a) == 0); // expected-warning{{UNKNOWN}}
+  if ((b - a) == 0) {
+    clang_analyzer_eval(b != a); // expected-warning{{FALSE}}
+    clang_analyzer_eval(b == a); // expected-warning{{TRUE}}
+    return;
+  }
+  clang_analyzer_eval((b - a) == 0); // expected-warning{{FALSE}}
+  // FIXME: we should track disequality information as well
+  clang_analyzer_eval(b == a); // expected-warning{{UNKNOWN}}
+  clang_analyzer_eval(b != a); // expected-warning{{UNKNOWN}}
+}
+
+void canonicalEqual(int a, int b) {
+  clang_analyzer_eval(a == b); // expected-warning{{UNKNOWN}}
+  if (a == b) {
+    clang_analyzer_eval(b == a); // expected-warning{{TRUE}}
+    return;
+  }
+  clang_analyzer_eval(a == b); // expected-warning{{FALSE}}
+  clang_analyzer_eval(b == a); // expected-warning{{FALSE}}
+}
+
+void test(int a, int b, int c, int d) {
+  if (a == b && c == d) {
+    if (a == 0 && b == d) {
+      clang_analyzer_eval(c == 0); // expected-warning{{TRUE}}
+    }
+    c = 10;
+    if (b == d) {
+      clang_analyzer_eval(c == 10); // expected-warning{{TRUE}}
+      clang_analyzer_eval(d == 10); // expected-warning{{UNKNOWN}}
+                                    // expected-warning@-1{{FALSE}}
+      clang_analyzer_eval(b == a);  // expected-warning{{TRUE}}
+      clang_analyzer_eval(a == d);  // expected-warning{{TRUE}}
+
+      b = getInt();
+      clang_analyzer_eval(a == d); // expected-warning{{TRUE}}
+      clang_analyzer_eval(a == b); // expected-warning{{UNKNOWN}}
+    }
+  }
+
+  if (a != b && b == c) {
+    if (c == 42) {
+      clang_analyzer_eval(b == 42); // expected-warning{{TRUE}}
+      // FIXME: we should track disequality information as well
+      clang_analyzer_eval(a != 42); // expected-warning{{UNKNOWN}}
+    }
+  }
+}
+
+void testIntersection(int a, int b, int c) {
+  if (a < 42 && b > 15 && c >= 25 && c <= 30) {
+    if (a != b)
+      return;
+
+    clang_analyzer_eval(a > 15);  // expected-warning{{TRUE}}
+    clang_analyzer_eval(b < 42);  // expected-warning{{TRUE}}
+    clang_analyzer_eval(a <= 30); // expected-warning{{UNKNOWN}}
+
+    if (c == b) {
+      // For all equal symbols, we should track the minimal common range.
+      //
+      // Also, it should be noted that c is dead at this point, but the
+      // constraint initially associated with c is still around.
+      clang_analyzer_eval(a >= 25 && a <= 30); // expected-warning{{TRUE}}
+      clang_analyzer_eval(b >= 25 && b <= 30); // expected-warning{{TRUE}}
+    }
+  }
+}
+
+void testPromotion(int a, char b) {
+  if (b > 10) {
+    if (a == b) {
+      // FIXME: support transferring char ranges onto equal int symbols
+      //        when char is promoted to int
+      clang_analyzer_eval(a > 10);        // expected-warning{{UNKNOWN}}
+      clang_analyzer_eval(a <= CHAR_MAX); // expected-warning{{UNKNOWN}}
+    }
+  }
+}
+
+void testPromotionOnlyTypes(int a, char b) {
+  if (a == b) {
+    // FIXME: support transferring char ranges onto equal int symbols
+    //        when char is promoted to int
+    clang_analyzer_eval(a <= CHAR_MAX); // expected-warning{{UNKNOWN}}
+  }
+}
+
+void testDowncast(int a, unsigned char b) {
+  if (a <= -10) {
+    if ((unsigned char)a == b) {
+      // Even though ranges for a and b do not intersect,
+      // ranges for (unsigned char)a and b do.
+      clang_analyzer_warnIfReached(); // expected-warning{{REACHABLE}}
+    }
+    if (a == b) {
+      // FIXME: This case on the other hand is different, it shouldn't be
+      //        reachable.  However, the current symbolic information available
+      //        to the solver doesn't allow it to distinguish this expression
+      //        from the previous one.
+      clang_analyzer_warnIfReached(); // expected-warning{{REACHABLE}}
+    }
+  }
+}
+
+void testPointers(int *a, int *b, int *c, int *d) {
+  if (a == b && c == d) {
+    if (a == NULL && b == d) {
+      clang_analyzer_eval(c == NULL); // expected-warning{{TRUE}}
+    }
+  }
+
+  if (a != b && b == c) {
+    if (c == NULL) {
+      // FIXME: we should track disequality information as well
+      clang_analyzer_eval(a != NULL); // expected-warning{{UNKNOWN}}
+    }
+  }
+}
+
+void avoidInfeasibleConstraintsForClasses(int a, int b) {
+  if (a >= 0 && a <= 10 && b >= 20 && b <= 50) {
+    if ((b - a) == 0) {
+      clang_analyzer_warnIfReached(); // no warning
+    }
+    if (a == b) {
+      clang_analyzer_warnIfReached(); // no warning
+    }
+    if (a != b) {
+      clang_analyzer_warnIfReached(); // expected-warning{{REACHABLE}}
+    } else {
+      clang_analyzer_warnIfReached(); // no warning
+    }
+  }
+}
Index: clang/lib/StaticAnalyzer/Core/RangedConstraintManager.cpp
===================================================================
--- clang/lib/StaticAnalyzer/Core/RangedConstraintManager.cpp
+++ clang/lib/StaticAnalyzer/Core/RangedConstraintManager.cpp
@@ -40,19 +40,20 @@
     }
 
   } else if (const SymSymExpr *SSE = dyn_cast<SymSymExpr>(Sym)) {
-    // Translate "a != b" to "(b - a) != 0".
-    // We invert the order of the operands as a heuristic for how loop
-    // conditions are usually written ("begin != end") as compared to length
-    // calculations ("end - begin"). The more correct thing to do would be to
-    // canonicalize "a - b" and "b - a", which would allow us to treat
-    // "a != b" and "b != a" the same.
-    SymbolManager &SymMgr = getSymbolManager();
     BinaryOperator::Opcode Op = SSE->getOpcode();
     assert(BinaryOperator::isComparisonOp(Op));
 
-    // For now, we only support comparing pointers.
+    // We convert equality operations for pointers only.
     if (Loc::isLocType(SSE->getLHS()->getType()) &&
         Loc::isLocType(SSE->getRHS()->getType())) {
+      // Translate "a != b" to "(b - a) != 0".
+      // We invert the order of the operands as a heuristic for how loop
+      // conditions are usually written ("begin != end") as compared to length
+      // calculations ("end - begin"). The more correct thing to do would be to
+      // canonicalize "a - b" and "b - a", which would allow us to treat
+      // "a != b" and "b != a" the same.
+
+      SymbolManager &SymMgr = getSymbolManager();
       QualType DiffTy = SymMgr.getContext().getPointerDiffType();
       SymbolRef Subtraction =
           SymMgr.getSymSymExpr(SSE->getRHS(), BO_Sub, SSE->getLHS(), DiffTy);
@@ -63,6 +64,25 @@
         Op = BinaryOperator::negateComparisonOp(Op);
       return assumeSymRel(State, Subtraction, Op, Zero);
     }
+
+    if (BinaryOperator::isEqualityOp(Op)) {
+      SymbolManager &SymMgr = getSymbolManager();
+
+      QualType ExprType = SSE->getType();
+      SymbolRef CanonicalEquality =
+          SymMgr.getSymSymExpr(SSE->getLHS(), BO_EQ, SSE->getRHS(), ExprType);
+
+      bool WasEqual = SSE->getOpcode() == BO_EQ;
+      bool IsExpectedEqual = WasEqual == Assumption;
+
+      const llvm::APSInt &Zero = getBasicVals().getValue(0, ExprType);
+
+      if (IsExpectedEqual) {
+        return assumeSymNE(State, CanonicalEquality, Zero, Zero);
+      }
+
+      return assumeSymEQ(State, CanonicalEquality, Zero, Zero);
+    }
   }
 
   // If we get here, there's nothing else we can do but treat the symbol as
@@ -199,11 +219,6 @@
   }
 }
 
-void *ProgramStateTrait<ConstraintRange>::GDMIndex() {
-  static int Index;
-  return &Index;
-}
-
 } // end of namespace ento
 
 } // end of namespace clang
Index: clang/lib/StaticAnalyzer/Core/RangeConstraintManager.cpp
===================================================================
--- clang/lib/StaticAnalyzer/Core/RangeConstraintManager.cpp
+++ clang/lib/StaticAnalyzer/Core/RangeConstraintManager.cpp
@@ -391,7 +391,197 @@
   os << " }";
 }
 
+REGISTER_SET_FACTORY_WITH_PROGRAMSTATE(SymbolSet, SymbolRef)
+
+namespace {
+class EquivalenceClass;
+} // end anonymous namespace
+
+REGISTER_MAP_WITH_PROGRAMSTATE(ClassMap, SymbolRef, EquivalenceClass)
+REGISTER_MAP_WITH_PROGRAMSTATE(ClassMembers, EquivalenceClass, SymbolSet)
+REGISTER_MAP_WITH_PROGRAMSTATE(ConstraintRange, EquivalenceClass, RangeSet)
+
 namespace {
+/// This class encapsulates a set of symbols equal to each other.
+///
+/// The main idea of the approach requiring such classes is in narrowing
+/// and sharing constraints between symbols within the class.  Also we can
+/// conclude that there is no practical need in storing constraints for
+/// every member of the class separately.
+///
+/// Main terminology:
+///
+///   * "Equivalence class" is an object of this class, which can be efficiently
+///     compared to other classes.  It represents the whole class without
+///     storing the actual members in it.  The members of the class, however,
+///     can be retrieved from the state.
+///
+///   * "Class members" are the symbols corresponding to the class.  This means
+///     that A == B for every member symbols A and B from the class.  Members of
+///     each class are stored in the state.
+///
+///   * "Trivial class" is a class that has, and has only ever had, one symbol.
+///
+///   * "Merge operation" merges two classes into one.  It is the main operation
+///     to produce non-trivial classes.
+///     If, at some point, we can assume that two symbols from two distinct
+///     classes are equal, we can merge these classes.
+class EquivalenceClass : public llvm::FoldingSetNode {
+public:
+  /// Find equivalence class for the given symbol in the given state.
+  LLVM_NODISCARD static inline EquivalenceClass find(ProgramStateRef State,
+                                                     SymbolRef Sym);
+
+  /// Merge classes for the given symbols and return a new state.
+  LLVM_NODISCARD static inline ProgramStateRef
+  merge(BasicValueFactory &BV, RangeSet::Factory &F, ProgramStateRef State,
+        SymbolRef First, SymbolRef Second);
+  /// Merge this class with the given class and return a new state.
+  LLVM_NODISCARD inline ProgramStateRef merge(BasicValueFactory &BV,
+                                              RangeSet::Factory &F,
+                                              ProgramStateRef State,
+                                              EquivalenceClass Other);
+
+  /// Return a set of class members for the given state.
+  LLVM_NODISCARD inline SymbolSet getClassMembers(ProgramStateRef State);
+  /// Return true if the current class is trivial in the given state.
+  LLVM_NODISCARD inline bool isTrivial(ProgramStateRef State);
+  /// Return true if the current class is trivial and its only member is dead.
+  LLVM_NODISCARD inline bool isTriviallyDead(ProgramStateRef State,
+                                             SymbolReaper &Reaper);
+
+  /// Check equivalence data for consistency.
+  LLVM_NODISCARD LLVM_ATTRIBUTE_UNUSED static bool
+  isClassDataConsistent(ProgramStateRef State);
+
+  LLVM_NODISCARD QualType getType() const {
+    return getRepresentativeSymbol()->getType();
+  }
+
+  EquivalenceClass() = delete;
+  EquivalenceClass(const EquivalenceClass &) = default;
+  EquivalenceClass &operator=(const EquivalenceClass &) = delete;
+  EquivalenceClass(EquivalenceClass &&) = default;
+  EquivalenceClass &operator=(EquivalenceClass &&) = delete;
+
+  bool operator==(const EquivalenceClass &Other) const {
+    return ID == Other.ID;
+  }
+  bool operator<(const EquivalenceClass &Other) const { return ID < Other.ID; }
+  bool operator!=(const EquivalenceClass &Other) const {
+    return !operator==(Other);
+  }
+
+  static void Profile(llvm::FoldingSetNodeID &ID, uintptr_t CID) {
+    ID.AddInteger(CID);
+  }
+
+  void Profile(llvm::FoldingSetNodeID &ID) const { Profile(ID, this->ID); }
+
+private:
+  /* implicit */ EquivalenceClass(SymbolRef Sym)
+      : ID(reinterpret_cast<uintptr_t>(Sym)) {}
+
+  /// This function is intended to be used ONLY within the class.
+  /// The fact that ID is a pointer to a symbol is an implementation detail
+  /// and should stay that way.
+  /// In the current implementation, we use it to retrieve the only member
+  /// of the trivial class.
+  SymbolRef getRepresentativeSymbol() const {
+    return reinterpret_cast<SymbolRef>(ID);
+  }
+  static inline SymbolSet::Factory &getMembersFactory(ProgramStateRef State);
+
+  inline ProgramStateRef mergeImpl(BasicValueFactory &BV, RangeSet::Factory &F,
+                                   ProgramStateRef State, SymbolSet Members,
+                                   EquivalenceClass Other,
+                                   SymbolSet OtherMembers);
+
+  /// This is a unique identifier of the class.
+  uintptr_t ID;
+};
+
+//===----------------------------------------------------------------------===//
+//                             Constraint functions
+//===----------------------------------------------------------------------===//
+
+LLVM_NODISCARD inline ProgramStateRef setConstraint(ProgramStateRef State,
+                                                    EquivalenceClass Class,
+                                                    RangeSet Constraint) {
+  return State->set<ConstraintRange>(Class, Constraint);
+}
+
+LLVM_NODISCARD inline ProgramStateRef
+setConstraint(ProgramStateRef State, SymbolRef Sym, RangeSet Constraint) {
+  return setConstraint(State, EquivalenceClass::find(State, Sym), Constraint);
+}
+
+LLVM_NODISCARD inline const RangeSet *getConstraint(ProgramStateRef State,
+                                                    EquivalenceClass Class) {
+  return State->get<ConstraintRange>(Class);
+}
+
+LLVM_NODISCARD inline const RangeSet *getConstraint(ProgramStateRef State,
+                                                    SymbolRef Sym) {
+  return getConstraint(State, EquivalenceClass::find(State, Sym));
+}
+
+//===----------------------------------------------------------------------===//
+//                       Equality/disequality abstraction
+//===----------------------------------------------------------------------===//
+
+/// A small helper structure representing symbolic equality.
+///
+/// An equality check can take different forms (like a == b or a - b).  This
+/// class abstracts those forms away when all the user wants is to know whether
+/// the check is an equality or a disequality, and to have easy access to the
+/// compared symbols.
+struct EqualityInfo {
+public:
+  SymbolRef Left, Right;
+  // true for equality and false for disequality.
+  bool IsEquality = true;
+
+  void invert() { IsEquality = !IsEquality; }
+  /// Extract equality information from the given symbol and the constants.
+  ///
+  /// This function assumes the following expression Sym + Adjustment != Int.
+  /// It is a default because the most widespread case of the equality check
+  /// is (A == B) + 0 != 0.
+  static Optional<EqualityInfo> extract(SymbolRef Sym, const llvm::APSInt &Int,
+                                        const llvm::APSInt &Adjustment) {
+    // As of now, the only equality form supported is Sym + 0 != 0.
+    if (!Int.isNullValue() || !Adjustment.isNullValue())
+      return llvm::None;
+
+    return extract(Sym);
+  }
+  /// Extract equality information from the given symbol.
+  static Optional<EqualityInfo> extract(SymbolRef Sym) {
+    return EqualityExtractor().Visit(Sym);
+  }
+
+private:
+  class EqualityExtractor
+      : public SymExprVisitor<EqualityExtractor, Optional<EqualityInfo>> {
+  public:
+    Optional<EqualityInfo> VisitSymSymExpr(const SymSymExpr *Sym) const {
+      switch (Sym->getOpcode()) {
+      case BO_Sub:
+        // This case is: A - B != 0 -> disequality check.
+        return EqualityInfo{Sym->getLHS(), Sym->getRHS(), false};
+      case BO_EQ:
+        // This case is: A == B != 0 -> equality check.
+        return EqualityInfo{Sym->getLHS(), Sym->getRHS(), true};
+      case BO_NE:
+        // This case is: A != B != 0 -> disequality check.
+        return EqualityInfo{Sym->getLHS(), Sym->getRHS(), false};
+      default:
+        return llvm::None;
+      }
+    }
+  };
+};
 
 //===----------------------------------------------------------------------===//
 //                            Intersection functions
@@ -556,15 +746,16 @@
 
   RangeSet infer(SymbolRef Sym) {
     if (Optional<RangeSet> ConstraintBasedRange = intersect(
-            ValueFactory, RangeFactory, State->get<ConstraintRange>(Sym),
+            ValueFactory, RangeFactory, getConstraint(State, Sym),
             // If Sym is a difference of symbols A - B, then maybe we have range
             // set stored for B - A.
             //
             // If we have range set stored for both A - B and B - A then
             // calculate the effective range set by intersecting the range set
             // for A - B and the negated range set of B - A.
-            getRangeForNegatedSub(Sym)))
+            getRangeForNegatedSub(Sym), getRangeForEqualities(Sym))) {
       return *ConstraintBasedRange;
+    }
 
     // If Sym is a comparison expression (except <=>),
     // find any other comparisons with the same operands.
@@ -745,8 +936,7 @@
         SymbolRef NegatedSym =
             SymMgr.getSymSymExpr(SSE->getRHS(), BO_Sub, SSE->getLHS(), T);
 
-        if (const RangeSet *NegatedRange =
-                State->get<ConstraintRange>(NegatedSym)) {
+        if (const RangeSet *NegatedRange = getConstraint(State, NegatedSym)) {
           return NegatedRange->Negate(ValueFactory, RangeFactory);
         }
       }
@@ -792,7 +982,7 @@
       // Let's find an expression e.g. (x < y).
       BinaryOperatorKind QueriedOP = OperatorRelationsTable::getOpFromIndex(i);
       const SymSymExpr *SymSym = SymMgr.getSymSymExpr(LHS, QueriedOP, RHS, T);
-      const RangeSet *QueriedRangeSet = State->get<ConstraintRange>(SymSym);
+      const RangeSet *QueriedRangeSet = getConstraint(State, SymSym);
 
       // If ranges were not previously found,
       // try to find a reversed expression (y > x).
@@ -800,7 +990,7 @@
         const BinaryOperatorKind ROP =
             BinaryOperator::reverseComparisonOp(QueriedOP);
         SymSym = SymMgr.getSymSymExpr(RHS, ROP, LHS, T);
-        QueriedRangeSet = State->get<ConstraintRange>(SymSym);
+        QueriedRangeSet = getConstraint(State, SymSym);
       }
 
       if (!QueriedRangeSet || QueriedRangeSet->isEmpty())
@@ -838,6 +1028,27 @@
     return llvm::None;
   }
 
+  Optional<RangeSet> getRangeForEqualities(SymbolRef Sym) {
+    Optional<EqualityInfo> Equality = EqualityInfo::extract(Sym);
+
+    if (!Equality)
+      return llvm::None;
+
+    EquivalenceClass LHS = EquivalenceClass::find(State, Equality->Left);
+    EquivalenceClass RHS = EquivalenceClass::find(State, Equality->Right);
+
+    if (LHS != RHS)
+      // Can't really say anything at this point.
+      // We can add more logic here if we track disequalities as well.
+      return llvm::None;
+
+    // At this point, operands of the equality operation are known to be equal.
+    if (Equality->IsEquality) {
+      return getTrueRange(Sym->getType());
+    }
+    return getFalseRange(Sym->getType());
+  }
+
   RangeSet getTrueRange(QualType T) {
     RangeSet TypeRange = infer(T);
     return assumeNonZero(TypeRange, T);
@@ -1032,7 +1243,11 @@
 
   bool haveEqualConstraints(ProgramStateRef S1,
                             ProgramStateRef S2) const override {
-    return S1->get<ConstraintRange>() == S2->get<ConstraintRange>();
+    // NOTE: ClassMembers are as simple as back pointers for ClassMap,
+    //       so comparing constraint ranges and class maps should be
+    //       sufficient.
+    return S1->get<ConstraintRange>() == S2->get<ConstraintRange>() &&
+           S1->get<ClassMap>() == S2->get<ClassMap>();
   }
 
   bool canReasonAbout(SVal X) const override;
@@ -1104,6 +1319,49 @@
   RangeSet getSymGERange(ProgramStateRef St, SymbolRef Sym,
                          const llvm::APSInt &Int,
                          const llvm::APSInt &Adjustment);
+
+  //===------------------------------------------------------------------===//
+  // Equality tracking implementation
+  //===------------------------------------------------------------------===//
+
+  ProgramStateRef trackEQ(ProgramStateRef State, SymbolRef Sym,
+                          const llvm::APSInt &Int,
+                          const llvm::APSInt &Adjustment) {
+    if (auto Equality = EqualityInfo::extract(Sym, Int, Adjustment)) {
+      // Extract function assumes that we gave it Sym + Adjustment != Int,
+      // so the result should be opposite.
+      Equality->invert();
+      return track(State, *Equality);
+    }
+
+    return State;
+  }
+
+  ProgramStateRef trackNE(ProgramStateRef State, SymbolRef Sym,
+                          const llvm::APSInt &Int,
+                          const llvm::APSInt &Adjustment) {
+    if (auto Equality = EqualityInfo::extract(Sym, Int, Adjustment)) {
+      return track(State, *Equality);
+    }
+
+    return State;
+  }
+
+  ProgramStateRef track(ProgramStateRef State, EqualityInfo ToTrack) {
+    if (ToTrack.IsEquality) {
+      return trackEquality(State, ToTrack.Left, ToTrack.Right);
+    }
+    return trackDisequality(State, ToTrack.Left, ToTrack.Right);
+  }
+
+  ProgramStateRef trackDisequality(ProgramStateRef State, SymbolRef LHS,
+                                   SymbolRef RHS) {
+    // TODO: track disequalities
+    return State;
+  }
+
+  ProgramStateRef trackEquality(ProgramStateRef State, SymbolRef LHS,
+                                SymbolRef RHS);
 };
 
 } // end anonymous namespace
@@ -1114,6 +1372,200 @@
   return std::make_unique<RangeConstraintManager>(Eng, StMgr.getSValBuilder());
 }
 
+ConstraintMap ento::getConstraintMap(ProgramStateRef State) {
+  ConstraintMap::Factory &F = State->get_context<ConstraintMap>();
+  ConstraintMap Result = F.getEmptyMap();
+
+  ConstraintRangeTy Constraints = State->get<ConstraintRange>();
+  for (std::pair<EquivalenceClass, RangeSet> ClassConstraint : Constraints) {
+    EquivalenceClass Class = ClassConstraint.first;
+    SymbolSet ClassMembers = Class.getClassMembers(State);
+    assert(!ClassMembers.isEmpty() &&
+           "Class must always have at least one member!");
+
+    SymbolRef Representative = *ClassMembers.begin();
+    Result = F.add(Result, Representative, ClassConstraint.second);
+  }
+
+  return Result;
+}
+
+//===----------------------------------------------------------------------===//
+//                     EquivalenceClass implementation details
+//===----------------------------------------------------------------------===//
+
+inline EquivalenceClass EquivalenceClass::find(ProgramStateRef State,
+                                               SymbolRef Sym) {
+  // Not every Symbol -> Class mapping is stored: trivial classes are implicit.
+  if (const EquivalenceClass *NontrivialClass = State->get<ClassMap>(Sym))
+    return *NontrivialClass;
+
+  // This is a trivial class of Sym.
+  return Sym;
+}
+
+inline ProgramStateRef EquivalenceClass::merge(BasicValueFactory &BV,
+                                               RangeSet::Factory &F,
+                                               ProgramStateRef State,
+                                               SymbolRef First,
+                                               SymbolRef Second) {
+  EquivalenceClass FirstClass = find(State, First);
+  EquivalenceClass SecondClass = find(State, Second);
+
+  return FirstClass.merge(BV, F, State, SecondClass);
+}
+
+inline ProgramStateRef EquivalenceClass::merge(BasicValueFactory &BV,
+                                               RangeSet::Factory &F,
+                                               ProgramStateRef State,
+                                               EquivalenceClass Other) {
+  // It is already the same class.
+  if (*this == Other)
+    return State;
+
+  // FIXME: As of now, we support only equivalence classes of the same type.
+  //        This limitation is connected to the lack of explicit casts in
+  //        our symbolic expression model.
+  //
+  //        That means that for `int x` and `char y` we don't distinguish
+  //        between these two very different cases:
+  //          * `x == y`
+  //          * `(char)x == y`
+  //
+  //        The moment we introduce symbolic casts, this restriction can be
+  //        lifted.
+  if (getType() != Other.getType())
+    return State;
+
+  SymbolSet Members = getClassMembers(State);
+  SymbolSet OtherMembers = Other.getClassMembers(State);
+
+  // We estimate the size of the class by the height of tree containing
+  // its members.  Merging is not a trivial operation, so it's easier to
+  // merge the smaller class into the bigger one.
+  if (Members.getHeight() >= OtherMembers.getHeight()) {
+    return mergeImpl(BV, F, State, Members, Other, OtherMembers);
+  } else {
+    return Other.mergeImpl(BV, F, State, OtherMembers, *this, Members);
+  }
+}
+
+inline ProgramStateRef
+EquivalenceClass::mergeImpl(BasicValueFactory &ValueFactory,
+                            RangeSet::Factory &RangeFactory,
+                            ProgramStateRef State, SymbolSet MyMembers,
+                            EquivalenceClass Other, SymbolSet OtherMembers) {
+  // Essentially what we try to recreate here is some kind of union-find
+  // data structure.  It does have certain limitations due to persistence
+  // and the need to remove elements from classes.
+  //
+  // In this setting, EquivalenceClass object is the representative of the class
+  // or the parent element.  ClassMap is a mapping of class members to their
+  // parent. Unlike the union-find structure, they all point directly to the
+  // class representative because we don't have an opportunity to actually do
+  // path compression when dealing with immutability.  This means that we
+  // compress paths every time we do merges.  It also means that we lose
+  // the main amortized complexity benefit from the original data structure.
+  ConstraintRangeTy Constraints = State->get<ConstraintRange>();
+  ConstraintRangeTy::Factory &CRF = State->get_context<ConstraintRange>();
+
+  // 1. If the merged classes have any constraints associated with them, we
+  //    need to transfer them to the class we have left.
+  //
+  // Intersection here makes perfect sense because both of these constraints
+  // must hold for the whole new class.
+  if (Optional<RangeSet> NewClassConstraint =
+          intersect(ValueFactory, RangeFactory, getConstraint(State, *this),
+                    getConstraint(State, Other))) {
+    // NOTE: Essentially, NewClassConstraint should NEVER be infeasible because
+    //       range inferrer shouldn't generate ranges incompatible with
+    //       equivalence classes. However, at the moment, due to imperfections
+    //       in the solver, it is possible and the merge function can also
+    //       return infeasible states aka null states.
+    if (NewClassConstraint->isEmpty())
+      // Infeasible state
+      return nullptr;
+
+    // No need to track constraints of a now-dissolved class.
+    Constraints = CRF.remove(Constraints, Other);
+    // Assign new constraints for this class.
+    Constraints = CRF.add(Constraints, *this, *NewClassConstraint);
+
+    State = State->set<ConstraintRange>(Constraints);
+  }
+
+  // Preparation for steps 2-4: get ALL equivalence-related maps.
+  ClassMapTy Classes = State->get<ClassMap>();
+  ClassMapTy::Factory &CF = State->get_context<ClassMap>();
+
+  ClassMembersTy Members = State->get<ClassMembers>();
+  ClassMembersTy::Factory &MF = State->get_context<ClassMembers>();
+
+  SymbolSet::Factory &F = getMembersFactory(State);
+
+  // 2. Merge members of the Other class into the current class.
+  SymbolSet NewClassMembers = MyMembers;
+  for (SymbolRef Sym : OtherMembers) {
+    NewClassMembers = F.add(NewClassMembers, Sym);
+    // *this is now the class for all these new symbols.
+    Classes = CF.add(Classes, Sym, *this);
+  }
+
+  // 3. Adjust member mapping.
+  //
+  // No need to track members of a now-dissolved class.
+  Members = MF.remove(Members, Other);
+  // Now only the current class is mapped to all the symbols.
+  Members = MF.add(Members, *this, NewClassMembers);
+
+  // 4. Update the state
+  State = State->set<ClassMap>(Classes);
+  State = State->set<ClassMembers>(Members);
+
+  return State;
+}
+
+inline SymbolSet::Factory &
+EquivalenceClass::getMembersFactory(ProgramStateRef State) {
+  return State->get_context<SymbolSet>();
+}
+
+SymbolSet EquivalenceClass::getClassMembers(ProgramStateRef State) {
+  if (const SymbolSet *Members = State->get<ClassMembers>(*this))
+    return *Members;
+
+  // This class is trivial, so we need to construct a set
+  // with just that one symbol from the class.
+  SymbolSet::Factory &F = getMembersFactory(State);
+  return F.add(F.getEmptySet(), getRepresentativeSymbol());
+}
+
+bool EquivalenceClass::isTrivial(ProgramStateRef State) {
+  return State->get<ClassMembers>(*this) == nullptr;
+}
+
+bool EquivalenceClass::isTriviallyDead(ProgramStateRef State,
+                                       SymbolReaper &Reaper) {
+  return isTrivial(State) && Reaper.isDead(getRepresentativeSymbol());
+}
+
+bool EquivalenceClass::isClassDataConsistent(ProgramStateRef State) {
+  ClassMembersTy Members = State->get<ClassMembers>();
+
+  for (std::pair<EquivalenceClass, SymbolSet> ClassMembersPair : Members) {
+    for (SymbolRef Member : ClassMembersPair.second) {
+      // Every member of the class should have a mapping back to the class.
+      if (find(State, Member) == ClassMembersPair.first) {
+        continue;
+      }
+
+      return false;
+    }
+  }
+
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 //                    RangeConstraintManager implementation
 //===----------------------------------------------------------------------===//
@@ -1166,7 +1618,7 @@
 
 ConditionTruthVal RangeConstraintManager::checkNull(ProgramStateRef State,
                                                     SymbolRef Sym) {
-  const RangeSet *Ranges = State->get<ConstraintRange>(Sym);
+  const RangeSet *Ranges = getConstraint(State, Sym);
 
   // If we don't have any information about this symbol, it's underconstrained.
   if (!Ranges)
@@ -1190,7 +1642,7 @@
 
 const llvm::APSInt *RangeConstraintManager::getSymVal(ProgramStateRef St,
                                                       SymbolRef Sym) const {
-  const ConstraintRangeTy::data_type *T = St->get<ConstraintRange>(Sym);
+  const RangeSet *T = getConstraint(St, Sym); // may be null; checked below
   return T ? T->getConcreteValue() : nullptr;
 }
 
@@ -1203,19 +1655,96 @@
 ProgramStateRef
 RangeConstraintManager::removeDeadBindings(ProgramStateRef State,
                                            SymbolReaper &SymReaper) {
-  bool Changed = false;
-  ConstraintRangeTy CR = State->get<ConstraintRange>();
-  ConstraintRangeTy::Factory &CRFactory = State->get_context<ConstraintRange>();
+  ClassMembersTy ClassMembersMap = State->get<ClassMembers>();
+  ClassMembersTy NewClassMembersMap = ClassMembersMap;
+  ClassMembersTy::Factory &EMFactory = State->get_context<ClassMembers>();
+  SymbolSet::Factory &SetFactory = State->get_context<SymbolSet>();
+
+  ConstraintRangeTy Constraints = State->get<ConstraintRange>();
+  ConstraintRangeTy NewConstraints = Constraints;
+  ConstraintRangeTy::Factory &ConstraintFactory =
+      State->get_context<ConstraintRange>();
+
+  ClassMapTy Map = State->get<ClassMap>();
+  ClassMapTy NewMap = Map;
+  ClassMapTy::Factory &ClassFactory = State->get_context<ClassMap>();
+
+  bool ClassMapChanged = false;
+  bool MembersMapChanged = false;
+  bool ConstraintMapChanged = false;
+
+  // 1. Drop constraints of trivial classes whose representative is dead.
+  for (std::pair<EquivalenceClass, RangeSet> ClassConstraintPair :
+       Constraints) {
+    EquivalenceClass Class = ClassConstraintPair.first;
+    if (Class.isTriviallyDead(State, SymReaper)) {
+      // Iterate the original map and edit the copy, never the iterated one.
+      NewConstraints = ConstraintFactory.remove(NewConstraints, Class);
+      ConstraintMapChanged = true;
+    }
+  }
+
+  // 2. We don't need to track classes for dead symbols.
+  for (std::pair<SymbolRef, EquivalenceClass> SymbolClassPair : Map) {
+    SymbolRef Sym = SymbolClassPair.first;
 
-  for (ConstraintRangeTy::iterator I = CR.begin(), E = CR.end(); I != E; ++I) {
-    SymbolRef Sym = I.getKey();
     if (SymReaper.isDead(Sym)) {
-      Changed = true;
-      CR = CRFactory.remove(CR, Sym);
+      ClassMapChanged = true;
+      NewMap = ClassFactory.remove(NewMap, Sym);
     }
   }
 
-  return Changed ? State->set<ConstraintRange>(CR) : State;
+  // 3. Remove dead members from classes and remove dead non-trivial classes
+  //    and their constraints.
+  for (std::pair<EquivalenceClass, SymbolSet> ClassMembersPair :
+       ClassMembersMap) {
+    SymbolSet LiveMembers = ClassMembersPair.second;
+    bool MembersChanged = false;
+
+    for (SymbolRef Member : ClassMembersPair.second) {
+      if (SymReaper.isDead(Member)) {
+        MembersChanged = true;
+        LiveMembers = SetFactory.remove(LiveMembers, Member);
+      }
+    }
+
+    // Classes without dead members keep their entry untouched.
+    if (!MembersChanged)
+      continue;
+
+    MembersMapChanged = true;
+
+    if (LiveMembers.isEmpty()) {
+      // The class is dead now, we need to wipe it out of the members map...
+      NewClassMembersMap =
+          EMFactory.remove(NewClassMembersMap, ClassMembersPair.first);
+
+      // ...and remove all of its constraints.
+      NewConstraints =
+          ConstraintFactory.remove(NewConstraints, ClassMembersPair.first);
+      ConstraintMapChanged = true;
+    } else {
+      // We need to change the members associated with the class.
+      NewClassMembersMap = EMFactory.add(NewClassMembersMap,
+                                         ClassMembersPair.first, LiveMembers);
+    }
+  }
+
+  // 4. Update the state with new maps.
+  //
+  // Each map is written back to the state only if it actually changed.
+  if (ClassMapChanged)
+    State = State->set<ClassMap>(NewMap);
+
+  if (MembersMapChanged)
+    State = State->set<ClassMembers>(NewClassMembersMap);
+
+  if (ConstraintMapChanged)
+    State = State->set<ConstraintRange>(NewConstraints);
+
+  assert(EquivalenceClass::isClassDataConsistent(State));
+
+  return State;
 }
 
 RangeSet RangeConstraintManager::getRange(ProgramStateRef State,
@@ -1247,7 +1776,13 @@
   llvm::APSInt Point = AdjustmentType.convert(Int) - Adjustment;
 
   RangeSet New = getRange(St, Sym).Delete(getBasicVals(), F, Point);
-  return New.isEmpty() ? nullptr : St->set<ConstraintRange>(Sym, New);
+
+  if (New.isEmpty())
+    // This is an infeasible assumption.
+    return nullptr;
+
+  ProgramStateRef NewState = setConstraint(St, Sym, New);
+  return trackNE(NewState, Sym, Int, Adjustment);
 }
 
 ProgramStateRef
@@ -1262,7 +1797,13 @@
   // [Int-Adjustment, Int-Adjustment]
   llvm::APSInt AdjInt = AdjustmentType.convert(Int) - Adjustment;
   RangeSet New = getRange(St, Sym).Intersect(getBasicVals(), F, AdjInt, AdjInt);
-  return New.isEmpty() ? nullptr : St->set<ConstraintRange>(Sym, New);
+
+  if (New.isEmpty())
+    // This is an infeasible assumption.
+    return nullptr;
+
+  ProgramStateRef NewState = setConstraint(St, Sym, New);
+  return trackEQ(NewState, Sym, Int, Adjustment);
 }
 
 RangeSet RangeConstraintManager::getSymLTRange(ProgramStateRef St,
@@ -1298,7 +1839,7 @@
                                     const llvm::APSInt &Int,
                                     const llvm::APSInt &Adjustment) {
   RangeSet New = getSymLTRange(St, Sym, Int, Adjustment);
-  return New.isEmpty() ? nullptr : St->set<ConstraintRange>(Sym, New);
+  return New.isEmpty() ? nullptr : setConstraint(St, Sym, New);
 }
 
 RangeSet RangeConstraintManager::getSymGTRange(ProgramStateRef St,
@@ -1334,7 +1875,7 @@
                                     const llvm::APSInt &Int,
                                     const llvm::APSInt &Adjustment) {
   RangeSet New = getSymGTRange(St, Sym, Int, Adjustment);
-  return New.isEmpty() ? nullptr : St->set<ConstraintRange>(Sym, New);
+  return New.isEmpty() ? nullptr : setConstraint(St, Sym, New);
 }
 
 RangeSet RangeConstraintManager::getSymGERange(ProgramStateRef St,
@@ -1370,13 +1911,13 @@
                                     const llvm::APSInt &Int,
                                     const llvm::APSInt &Adjustment) {
   RangeSet New = getSymGERange(St, Sym, Int, Adjustment);
-  return New.isEmpty() ? nullptr : St->set<ConstraintRange>(Sym, New);
+  return New.isEmpty() ? nullptr : setConstraint(St, Sym, New);
 }
 
-RangeSet RangeConstraintManager::getSymLERange(
-      llvm::function_ref<RangeSet()> RS,
-      const llvm::APSInt &Int,
-      const llvm::APSInt &Adjustment) {
+RangeSet
+RangeConstraintManager::getSymLERange(llvm::function_ref<RangeSet()> RS,
+                                      const llvm::APSInt &Int,
+                                      const llvm::APSInt &Adjustment) {
   // Before we do any real work, see if the value can even show up.
   APSIntType AdjustmentType(Adjustment);
   switch (AdjustmentType.testInRange(Int, true)) {
@@ -1413,7 +1954,7 @@
                                     const llvm::APSInt &Int,
                                     const llvm::APSInt &Adjustment) {
   RangeSet New = getSymLERange(St, Sym, Int, Adjustment);
-  return New.isEmpty() ? nullptr : St->set<ConstraintRange>(Sym, New);
+  return New.isEmpty() ? nullptr : setConstraint(St, Sym, New);
 }
 
 ProgramStateRef RangeConstraintManager::assumeSymWithinInclusiveRange(
@@ -1423,7 +1964,7 @@
   if (New.isEmpty())
     return nullptr;
   RangeSet Out = getSymLERange([&] { return New; }, To, Adjustment);
-  return Out.isEmpty() ? nullptr : State->set<ConstraintRange>(Sym, Out);
+  return Out.isEmpty() ? nullptr : setConstraint(State, Sym, Out);
 }
 
 ProgramStateRef RangeConstraintManager::assumeSymOutsideInclusiveRange(
@@ -1432,7 +1973,13 @@
   RangeSet RangeLT = getSymLTRange(State, Sym, From, Adjustment);
   RangeSet RangeGT = getSymGTRange(State, Sym, To, Adjustment);
   RangeSet New(RangeLT.addRange(F, RangeGT));
-  return New.isEmpty() ? nullptr : State->set<ConstraintRange>(Sym, New);
+  return New.isEmpty() ? nullptr : setConstraint(State, Sym, New);
+}
+
+ProgramStateRef RangeConstraintManager::trackEquality(ProgramStateRef State,
+                                                      SymbolRef LHS,
+                                                      SymbolRef RHS) {
+  return EquivalenceClass::merge(getBasicVals(), F, State, LHS, RHS);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1452,17 +1999,25 @@
 
   ++Space;
   Out << '[' << NL;
-  for (ConstraintRangeTy::iterator I = Constraints.begin();
-       I != Constraints.end(); ++I) {
-    Indent(Out, Space, IsDot)
-        << "{ \"symbol\": \"" << I.getKey() << "\", \"range\": \"";
-    I.getData().print(Out);
-    Out << "\" }";
-
-    if (std::next(I) != Constraints.end())
-      Out << ',';
-    Out << NL;
+  bool First = true;
+  for (std::pair<EquivalenceClass, RangeSet> P : Constraints) {
+    SymbolSet ClassMembers = P.first.getClassMembers(State);
+
+    // We can print the same constraint for every class member.
+    for (SymbolRef ClassMember : ClassMembers) {
+      if (First) {
+        First = false;
+      } else {
+        Out << ',';
+        Out << NL;
+      }
+      Indent(Out, Space, IsDot)
+          << "{ \"symbol\": \"" << ClassMember << "\", \"range\": \"";
+      P.second.print(Out);
+      Out << "\" }";
+    }
   }
+  Out << NL;
 
   --Space;
   Indent(Out, Space, IsDot) << "]," << NL;
Index: clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp
===================================================================
--- clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp
+++ clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp
@@ -2813,7 +2813,7 @@
 //===----------------------------------------------------------------------===//
 
 FalsePositiveRefutationBRVisitor::FalsePositiveRefutationBRVisitor()
-    : Constraints(ConstraintRangeTy::Factory().getEmptyMap()) {}
+    : Constraints(ConstraintMap::Factory().getEmptyMap()) {}
 
 void FalsePositiveRefutationBRVisitor::finalizeVisitor(
     BugReporterContext &BRC, const ExplodedNode *EndPathNode,
@@ -2855,9 +2855,8 @@
 void FalsePositiveRefutationBRVisitor::addConstraints(
     const ExplodedNode *N, bool OverwriteConstraintsOnExistingSyms) {
   // Collect new constraints
-  const ConstraintRangeTy &NewCs = N->getState()->get<ConstraintRange>();
-  ConstraintRangeTy::Factory &CF =
-      N->getState()->get_context<ConstraintRange>();
+  ConstraintMap NewCs = getConstraintMap(N->getState());
+  ConstraintMap::Factory &CF = N->getState()->get_context<ConstraintMap>();
 
   // Add constraints if we don't have them yet
   for (auto const &C : NewCs) {
Index: clang/include/clang/StaticAnalyzer/Core/PathSensitive/RangedConstraintManager.h
===================================================================
--- clang/include/clang/StaticAnalyzer/Core/PathSensitive/RangedConstraintManager.h
+++ clang/include/clang/StaticAnalyzer/Core/PathSensitive/RangedConstraintManager.h
@@ -136,14 +136,8 @@
   }
 };
 
-class ConstraintRange {};
-using ConstraintRangeTy = llvm::ImmutableMap<SymbolRef, RangeSet>;
-
-template <>
-struct ProgramStateTrait<ConstraintRange>
-    : public ProgramStatePartialTrait<ConstraintRangeTy> {
-  static void *GDMIndex();
-};
+using ConstraintMap = llvm::ImmutableMap<SymbolRef, RangeSet>;
+ConstraintMap getConstraintMap(ProgramStateRef State);
 
 class RangedConstraintManager : public SimpleConstraintManager {
 public:
@@ -222,4 +216,6 @@
 } // namespace ento
 } // namespace clang
 
+REGISTER_FACTORY_WITH_PROGRAMSTATE(ConstraintMap)
+
 #endif
Index: clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h
===================================================================
--- clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h
+++ clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h
@@ -373,7 +373,7 @@
 class FalsePositiveRefutationBRVisitor final : public BugReporterVisitor {
 private:
   /// Holds the constraints in a given path
-  ConstraintRangeTy Constraints;
+  ConstraintMap Constraints;
 
 public:
   FalsePositiveRefutationBRVisitor();
@@ -390,7 +390,6 @@
                       bool OverwriteConstraintsOnExistingSyms);
 };
 
-
 /// The visitor detects NoteTags and displays the event notes they contain.
 class TagVisitor : public BugReporterVisitor {
 public:
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to