Author: Balazs Benics Date: 2024-06-18T09:42:29+02:00 New Revision: 89c26f6c7b0a6dfa257ec090fcf5b6e6e0c89aab
URL: https://github.com/llvm/llvm-project/commit/89c26f6c7b0a6dfa257ec090fcf5b6e6e0c89aab DIFF: https://github.com/llvm/llvm-project/commit/89c26f6c7b0a6dfa257ec090fcf5b6e6e0c89aab.diff LOG: [analyzer][NFC] Reorganize Z3 report refutation This change keeps existing behavior, namely that if we hit a Z3 timeout we will accept the report as "satisfiable". This prepares for the commit "Harden safeguards for Z3 query times". https://discourse.llvm.org/t/analyzer-rfc-taming-z3-query-times/79520 Reviewers: NagyDonat, haoNoQ, Xazax-hun, mikhailramalho, Szelethus Reviewed By: NagyDonat Pull Request: https://github.com/llvm/llvm-project/pull/95128 Added: clang/include/clang/StaticAnalyzer/Core/BugReporter/Z3CrosscheckVisitor.h clang/lib/StaticAnalyzer/Core/Z3CrosscheckVisitor.cpp clang/test/Analysis/z3/crosscheck-statistics.c clang/unittests/StaticAnalyzer/Z3CrosscheckOracleTest.cpp Modified: clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h clang/include/clang/StaticAnalyzer/Core/PathSensitive/SMTConstraintManager.h clang/lib/StaticAnalyzer/Core/BugReporter.cpp clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp clang/lib/StaticAnalyzer/Core/CMakeLists.txt clang/unittests/StaticAnalyzer/CMakeLists.txt llvm/include/llvm/Support/SMTAPI.h llvm/lib/Support/Z3Solver.cpp Removed: ################################################################################ diff --git a/clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h b/clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h index cc3d93aabafda..f97514955a591 100644 --- a/clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h +++ b/clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h @@ -597,29 +597,6 @@ class SuppressInlineDefensiveChecksVisitor final : public BugReporterVisitor { PathSensitiveBugReport &BR) override; }; -/// The bug visitor will walk all the nodes in a path and collect all the -/// constraints. 
When it reaches the root node, will create a refutation -/// manager and check if the constraints are satisfiable -class FalsePositiveRefutationBRVisitor final : public BugReporterVisitor { -private: - /// Holds the constraints in a given path - ConstraintMap Constraints; - -public: - FalsePositiveRefutationBRVisitor(); - - void Profile(llvm::FoldingSetNodeID &ID) const override; - - PathDiagnosticPieceRef VisitNode(const ExplodedNode *N, - BugReporterContext &BRC, - PathSensitiveBugReport &BR) override; - - void finalizeVisitor(BugReporterContext &BRC, const ExplodedNode *EndPathNode, - PathSensitiveBugReport &BR) override; - void addConstraints(const ExplodedNode *N, - bool OverwriteConstraintsOnExistingSyms); -}; - /// The visitor detects NoteTags and displays the event notes they contain. class TagVisitor : public BugReporterVisitor { public: diff --git a/clang/include/clang/StaticAnalyzer/Core/BugReporter/Z3CrosscheckVisitor.h b/clang/include/clang/StaticAnalyzer/Core/BugReporter/Z3CrosscheckVisitor.h new file mode 100644 index 0000000000000..9413fd739f607 --- /dev/null +++ b/clang/include/clang/StaticAnalyzer/Core/BugReporter/Z3CrosscheckVisitor.h @@ -0,0 +1,66 @@ +//===- Z3CrosscheckVisitor.h - Crosscheck reports with Z3 -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the visitor and utilities around it for Z3 report +// refutation. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_STATICANALYZER_CORE_BUGREPORTER_Z3CROSSCHECKVISITOR_H +#define LLVM_CLANG_STATICANALYZER_CORE_BUGREPORTER_Z3CROSSCHECKVISITOR_H + +#include "clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h" + +namespace clang::ento { + +/// The bug visitor will walk all the nodes in a path and collect all the +/// constraints. When it reaches the root node, it will create a refutation +/// manager and check if the constraints are satisfiable. +class Z3CrosscheckVisitor final : public BugReporterVisitor { +public: + struct Z3Result { + std::optional<bool> IsSAT = std::nullopt; + }; + explicit Z3CrosscheckVisitor(Z3CrosscheckVisitor::Z3Result &Result); + + void Profile(llvm::FoldingSetNodeID &ID) const override; + + PathDiagnosticPieceRef VisitNode(const ExplodedNode *N, + BugReporterContext &BRC, + PathSensitiveBugReport &BR) override; + + void finalizeVisitor(BugReporterContext &BRC, const ExplodedNode *EndPathNode, + PathSensitiveBugReport &BR) override; + +private: + void addConstraints(const ExplodedNode *N, + bool OverwriteConstraintsOnExistingSyms); + + /// Holds the constraints in a given path. + ConstraintMap Constraints; + Z3Result &Result; +}; + +/// The oracle will decide if a report should be accepted or rejected based on +/// the results of the Z3 solver. +class Z3CrosscheckOracle { +public: + enum Z3Decision { + AcceptReport, // The query was SAT or undecided (e.g. Z3 timeout). + RejectReport, // The query was UNSAT. + }; + + /// Makes a decision for accepting or rejecting the report based on the + /// result of the corresponding Z3 query.
+ static Z3Decision + interpretQueryResult(const Z3CrosscheckVisitor::Z3Result &Query); +}; + +} // namespace clang::ento + +#endif // LLVM_CLANG_STATICANALYZER_CORE_BUGREPORTER_Z3CROSSCHECKVISITOR_H diff --git a/clang/include/clang/StaticAnalyzer/Core/PathSensitive/SMTConstraintManager.h b/clang/include/clang/StaticAnalyzer/Core/PathSensitive/SMTConstraintManager.h index 5116a4c06850d..bf18c353b8508 100644 --- a/clang/include/clang/StaticAnalyzer/Core/PathSensitive/SMTConstraintManager.h +++ b/clang/include/clang/StaticAnalyzer/Core/PathSensitive/SMTConstraintManager.h @@ -34,7 +34,10 @@ class SMTConstraintManager : public clang::ento::SimpleConstraintManager { public: SMTConstraintManager(clang::ento::ExprEngine *EE, clang::ento::SValBuilder &SB) - : SimpleConstraintManager(EE, SB) {} + : SimpleConstraintManager(EE, SB) { + Solver->setBoolParam("model", true); // Enable model finding + Solver->setUnsignedParam("timeout", 15000 /*milliseconds*/); + } virtual ~SMTConstraintManager() = default; //===------------------------------------------------------------------===// diff --git a/clang/lib/StaticAnalyzer/Core/BugReporter.cpp b/clang/lib/StaticAnalyzer/Core/BugReporter.cpp index 14ca507a16d55..c9a7fd0e035c2 100644 --- a/clang/lib/StaticAnalyzer/Core/BugReporter.cpp +++ b/clang/lib/StaticAnalyzer/Core/BugReporter.cpp @@ -35,6 +35,7 @@ #include "clang/StaticAnalyzer/Core/AnalyzerOptions.h" #include "clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h" #include "clang/StaticAnalyzer/Core/BugReporter/BugType.h" +#include "clang/StaticAnalyzer/Core/BugReporter/Z3CrosscheckVisitor.h" #include "clang/StaticAnalyzer/Core/Checker.h" #include "clang/StaticAnalyzer/Core/CheckerManager.h" #include "clang/StaticAnalyzer/Core/CheckerRegistryData.h" @@ -86,6 +87,11 @@ STATISTIC(MaxValidBugClassSize, "The maximum number of bug reports in the same equivalence class " "where at least one report is valid (not suppressed)"); +STATISTIC(NumTimesReportPassesZ3, "Number of 
reports passed Z3"); +STATISTIC(NumTimesReportRefuted, "Number of reports refuted by Z3"); +STATISTIC(NumTimesReportEQClassWasExhausted, + "Number of times all reports of an equivalence class was refuted"); + BugReporterVisitor::~BugReporterVisitor() = default; void BugReporterContext::anchor() {} @@ -2864,21 +2870,31 @@ std::optional<PathDiagnosticBuilder> PathDiagnosticBuilder::findValidReport( // If crosscheck is enabled, remove all visitors, add the refutation // visitor and check again R->clearVisitors(); - R->addVisitor<FalsePositiveRefutationBRVisitor>(); + Z3CrosscheckVisitor::Z3Result CrosscheckResult; + R->addVisitor<Z3CrosscheckVisitor>(CrosscheckResult); // We don't overwrite the notes inserted by other visitors because the // refutation manager does not add any new note to the path generateVisitorsDiagnostics(R, BugPath->ErrorNode, BRC); + switch (Z3CrosscheckOracle::interpretQueryResult(CrosscheckResult)) { + case Z3CrosscheckOracle::RejectReport: + ++NumTimesReportRefuted; + R->markInvalid("Infeasible constraints", /*Data=*/nullptr); + continue; + case Z3CrosscheckOracle::AcceptReport: + ++NumTimesReportPassesZ3; + break; + } } - // Check if the bug is still valid - if (R->isValid()) - return PathDiagnosticBuilder( - std::move(BRC), std::move(BugPath->BugPath), BugPath->Report, - BugPath->ErrorNode, std::move(visitorNotes)); + assert(R->isValid()); + return PathDiagnosticBuilder(std::move(BRC), std::move(BugPath->BugPath), + BugPath->Report, BugPath->ErrorNode, + std::move(visitorNotes)); } } + ++NumTimesReportEQClassWasExhausted; return {}; } diff --git a/clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp b/clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp index 487a3bd16b674..68dac65949117 100644 --- a/clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp +++ b/clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp @@ -3446,82 +3446,6 @@ UndefOrNullArgVisitor::VisitNode(const ExplodedNode *N, BugReporterContext &BRC, return nullptr; } 
-//===----------------------------------------------------------------------===// -// Implementation of FalsePositiveRefutationBRVisitor. -//===----------------------------------------------------------------------===// - -FalsePositiveRefutationBRVisitor::FalsePositiveRefutationBRVisitor() - : Constraints(ConstraintMap::Factory().getEmptyMap()) {} - -void FalsePositiveRefutationBRVisitor::finalizeVisitor( - BugReporterContext &BRC, const ExplodedNode *EndPathNode, - PathSensitiveBugReport &BR) { - // Collect new constraints - addConstraints(EndPathNode, /*OverwriteConstraintsOnExistingSyms=*/true); - - // Create a refutation manager - llvm::SMTSolverRef RefutationSolver = llvm::CreateZ3Solver(); - ASTContext &Ctx = BRC.getASTContext(); - - // Add constraints to the solver - for (const auto &I : Constraints) { - const SymbolRef Sym = I.first; - auto RangeIt = I.second.begin(); - - llvm::SMTExprRef SMTConstraints = SMTConv::getRangeExpr( - RefutationSolver, Ctx, Sym, RangeIt->From(), RangeIt->To(), - /*InRange=*/true); - while ((++RangeIt) != I.second.end()) { - SMTConstraints = RefutationSolver->mkOr( - SMTConstraints, SMTConv::getRangeExpr(RefutationSolver, Ctx, Sym, - RangeIt->From(), RangeIt->To(), - /*InRange=*/true)); - } - - RefutationSolver->addConstraint(SMTConstraints); - } - - // And check for satisfiability - std::optional<bool> IsSAT = RefutationSolver->check(); - if (!IsSAT) - return; - - if (!*IsSAT) - BR.markInvalid("Infeasible constraints", EndPathNode->getLocationContext()); -} - -void FalsePositiveRefutationBRVisitor::addConstraints( - const ExplodedNode *N, bool OverwriteConstraintsOnExistingSyms) { - // Collect new constraints - ConstraintMap NewCs = getConstraintMap(N->getState()); - ConstraintMap::Factory &CF = N->getState()->get_context<ConstraintMap>(); - - // Add constraints if we don't have them yet - for (auto const &C : NewCs) { - const SymbolRef &Sym = C.first; - if (!Constraints.contains(Sym)) { - // This symbol is new, just add the 
- constraint. - Constraints = CF.add(Constraints, Sym, C.second); - } else if (OverwriteConstraintsOnExistingSyms) { - // Overwrite the associated constraint of the Symbol. - Constraints = CF.remove(Constraints, Sym); - Constraints = CF.add(Constraints, Sym, C.second); - } - } -} - -PathDiagnosticPieceRef FalsePositiveRefutationBRVisitor::VisitNode( - const ExplodedNode *N, BugReporterContext &, PathSensitiveBugReport &) { - addConstraints(N, /*OverwriteConstraintsOnExistingSyms=*/false); - return nullptr; -} - -void FalsePositiveRefutationBRVisitor::Profile( - llvm::FoldingSetNodeID &ID) const { - static int Tag = 0; - ID.AddPointer(&Tag); -} - //===----------------------------------------------------------------------===// // Implementation of TagVisitor. //===----------------------------------------------------------------------===// diff --git a/clang/lib/StaticAnalyzer/Core/CMakeLists.txt b/clang/lib/StaticAnalyzer/Core/CMakeLists.txt index 8672876c0608d..fb9394a519eb7 100644 --- a/clang/lib/StaticAnalyzer/Core/CMakeLists.txt +++ b/clang/lib/StaticAnalyzer/Core/CMakeLists.txt @@ -51,6 +51,7 @@ add_clang_library(clangStaticAnalyzerCore SymbolManager.cpp TextDiagnostics.cpp WorkList.cpp + Z3CrosscheckVisitor.cpp LINK_LIBS clangAST diff --git a/clang/lib/StaticAnalyzer/Core/Z3CrosscheckVisitor.cpp b/clang/lib/StaticAnalyzer/Core/Z3CrosscheckVisitor.cpp new file mode 100644 index 0000000000000..a7db44ef8ea30 --- /dev/null +++ b/clang/lib/StaticAnalyzer/Core/Z3CrosscheckVisitor.cpp @@ -0,0 +1,118 @@ +//===- Z3CrosscheckVisitor.cpp - Crosscheck reports with Z3 -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the visitor and utilities around it for Z3 report +// refutation.
+// +//===----------------------------------------------------------------------===// + +#include "clang/StaticAnalyzer/Core/BugReporter/Z3CrosscheckVisitor.h" +#include "clang/StaticAnalyzer/Core/BugReporter/BugReporter.h" +#include "clang/StaticAnalyzer/Core/PathSensitive/SMTConv.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Support/SMTAPI.h" + +#define DEBUG_TYPE "Z3CrosscheckOracle" + +STATISTIC(NumZ3QueriesDone, "Number of Z3 queries done"); +STATISTIC(NumTimesZ3TimedOut, "Number of times Z3 query timed out"); + +STATISTIC(NumTimesZ3QueryAcceptsReport, + "Number of Z3 queries accepting a report"); +STATISTIC(NumTimesZ3QueryRejectReport, + "Number of Z3 queries rejecting a report"); + +using namespace clang; +using namespace ento; + +Z3CrosscheckVisitor::Z3CrosscheckVisitor(Z3CrosscheckVisitor::Z3Result &Result) + : Constraints(ConstraintMap::Factory().getEmptyMap()), Result(Result) {} + +void Z3CrosscheckVisitor::finalizeVisitor(BugReporterContext &BRC, + const ExplodedNode *EndPathNode, + PathSensitiveBugReport &BR) { + // Collect new constraints + addConstraints(EndPathNode, /*OverwriteConstraintsOnExistingSyms=*/true); + + // Create a refutation manager + llvm::SMTSolverRef RefutationSolver = llvm::CreateZ3Solver(); + RefutationSolver->setBoolParam("model", true); // Enable model finding + RefutationSolver->setUnsignedParam("timeout", 15000); // ms + + ASTContext &Ctx = BRC.getASTContext(); + + // Add constraints to the solver + for (const auto &[Sym, Range] : Constraints) { + auto RangeIt = Range.begin(); + + llvm::SMTExprRef SMTConstraints = SMTConv::getRangeExpr( + RefutationSolver, Ctx, Sym, RangeIt->From(), RangeIt->To(), + /*InRange=*/true); + while ((++RangeIt) != Range.end()) { + SMTConstraints = RefutationSolver->mkOr( + SMTConstraints, SMTConv::getRangeExpr(RefutationSolver, Ctx, Sym, + RangeIt->From(), RangeIt->To(), + /*InRange=*/true)); + } + RefutationSolver->addConstraint(SMTConstraints); + } + + // And check for satisfiability + 
std::optional<bool> IsSAT = RefutationSolver->check(); + Result = Z3Result{IsSAT}; +} + +void Z3CrosscheckVisitor::addConstraints( + const ExplodedNode *N, bool OverwriteConstraintsOnExistingSyms) { + // Collect new constraints + ConstraintMap NewCs = getConstraintMap(N->getState()); + ConstraintMap::Factory &CF = N->getState()->get_context<ConstraintMap>(); + + // Add constraints if we don't have them yet + for (auto const &[Sym, Range] : NewCs) { + if (!Constraints.contains(Sym)) { + // This symbol is new, just add the constraint. + Constraints = CF.add(Constraints, Sym, Range); + } else if (OverwriteConstraintsOnExistingSyms) { + // Overwrite the associated constraint of the Symbol. + Constraints = CF.remove(Constraints, Sym); + Constraints = CF.add(Constraints, Sym, Range); + } + } +} + +PathDiagnosticPieceRef +Z3CrosscheckVisitor::VisitNode(const ExplodedNode *N, BugReporterContext &, + PathSensitiveBugReport &) { + addConstraints(N, /*OverwriteConstraintsOnExistingSyms=*/false); + return nullptr; +} + +void Z3CrosscheckVisitor::Profile(llvm::FoldingSetNodeID &ID) const { + static int Tag = 0; + ID.AddPointer(&Tag); +} + +Z3CrosscheckOracle::Z3Decision Z3CrosscheckOracle::interpretQueryResult( + const Z3CrosscheckVisitor::Z3Result &Query) { + ++NumZ3QueriesDone; + + if (!Query.IsSAT.has_value()) { + // Undecided (e.g. timeout): accept, for backward compatibility.
+ ++NumTimesZ3TimedOut; + return AcceptReport; + } + + if (Query.IsSAT.value()) { + ++NumTimesZ3QueryAcceptsReport; + return AcceptReport; // sat + } + + ++NumTimesZ3QueryRejectReport; + return RejectReport; // unsat +} diff --git a/clang/test/Analysis/z3/crosscheck-statistics.c b/clang/test/Analysis/z3/crosscheck-statistics.c new file mode 100644 index 0000000000000..7192824c5be31 --- /dev/null +++ b/clang/test/Analysis/z3/crosscheck-statistics.c @@ -0,0 +1,33 @@ +// RUN: %clang_analyze_cc1 -analyzer-checker=core -verify %s \ +// RUN: -analyzer-config crosscheck-with-z3=true \ +// RUN: -analyzer-stats 2>&1 | FileCheck %s + +// REQUIRES: z3 + +// Checks the Z3 crosscheck statistics printed by -analyzer-stats. + +int accepting(int n) { + if (n == 4) { + n = n / (n-4); // expected-warning {{Division by zero}} + } + return n; +} + +int rejecting(int n, int x) { + // Let's make the path infeasible. + if (2 < x && x < 5 && x*x == x*x*x) { + // Have the same condition as in 'accepting'. + if (n == 4) { + n = x / (n-4); // no-warning: refuted + } + } + return n; +} + +// CHECK: 1 BugReporter - Number of times all reports of an equivalence class was refuted +// CHECK-NEXT: 1 BugReporter - Number of reports passed Z3 +// CHECK-NEXT: 1 BugReporter - Number of reports refuted by Z3 + +// CHECK: 1 Z3CrosscheckOracle - Number of Z3 queries accepting a report +// CHECK-NEXT: 1 Z3CrosscheckOracle - Number of Z3 queries rejecting a report +// CHECK-NEXT: 2 Z3CrosscheckOracle - Number of Z3 queries done diff --git a/clang/unittests/StaticAnalyzer/CMakeLists.txt b/clang/unittests/StaticAnalyzer/CMakeLists.txt index ff34d5747cc81..dcc557b44fb31 100644 --- a/clang/unittests/StaticAnalyzer/CMakeLists.txt +++ b/clang/unittests/StaticAnalyzer/CMakeLists.txt @@ -21,6 +21,7 @@ add_clang_unittest(StaticAnalysisTests SymbolReaperTest.cpp SValTest.cpp TestReturnValueUnderConstruction.cpp + Z3CrosscheckOracleTest.cpp ) clang_target_link_libraries(StaticAnalysisTests diff --git
a/clang/unittests/StaticAnalyzer/Z3CrosscheckOracleTest.cpp b/clang/unittests/StaticAnalyzer/Z3CrosscheckOracleTest.cpp new file mode 100644 index 0000000000000..efad4dd3f03b9 --- /dev/null +++ b/clang/unittests/StaticAnalyzer/Z3CrosscheckOracleTest.cpp @@ -0,0 +1,59 @@ +//===- unittests/StaticAnalyzer/Z3CrosscheckOracleTest.cpp ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "clang/StaticAnalyzer/Core/BugReporter/Z3CrosscheckVisitor.h" +#include "gtest/gtest.h" + +using namespace clang; +using namespace ento; + +using Z3Result = Z3CrosscheckVisitor::Z3Result; +using Z3Decision = Z3CrosscheckOracle::Z3Decision; + +static constexpr Z3Decision AcceptReport = Z3Decision::AcceptReport; +static constexpr Z3Decision RejectReport = Z3Decision::RejectReport; + +static constexpr std::optional<bool> SAT = true; +static constexpr std::optional<bool> UNSAT = false; +static constexpr std::optional<bool> UNDEF = std::nullopt; + +namespace { + +struct Z3CrosscheckOracleTest : public testing::Test { + Z3Decision interpretQueryResult(const Z3Result &Result) const { + return Z3CrosscheckOracle::interpretQueryResult(Result); + } +}; + +TEST_F(Z3CrosscheckOracleTest, AcceptsFirstSAT) { + ASSERT_EQ(AcceptReport, interpretQueryResult({SAT})); +} + +TEST_F(Z3CrosscheckOracleTest, AcceptsSAT) { + ASSERT_EQ(RejectReport, interpretQueryResult({UNSAT})); + ASSERT_EQ(AcceptReport, interpretQueryResult({SAT})); +} + +TEST_F(Z3CrosscheckOracleTest, AcceptsFirstTimeout) { + ASSERT_EQ(AcceptReport, interpretQueryResult({UNDEF})); +} + +TEST_F(Z3CrosscheckOracleTest, AcceptsTimeout) { + ASSERT_EQ(RejectReport, interpretQueryResult({UNSAT})); + ASSERT_EQ(RejectReport, interpretQueryResult({UNSAT})); + 
ASSERT_EQ(AcceptReport, interpretQueryResult({UNDEF})); +} + +TEST_F(Z3CrosscheckOracleTest, RejectsUNSATs) { + ASSERT_EQ(RejectReport, interpretQueryResult({UNSAT})); + ASSERT_EQ(RejectReport, interpretQueryResult({UNSAT})); + ASSERT_EQ(RejectReport, interpretQueryResult({UNSAT})); + ASSERT_EQ(RejectReport, interpretQueryResult({UNSAT})); +} + +} // namespace diff --git a/llvm/include/llvm/Support/SMTAPI.h b/llvm/include/llvm/Support/SMTAPI.h index 9389c96956dd1..a2a89674414f4 100644 --- a/llvm/include/llvm/Support/SMTAPI.h +++ b/llvm/include/llvm/Support/SMTAPI.h @@ -125,6 +125,19 @@ class SMTExpr { virtual bool equal_to(SMTExpr const &other) const = 0; }; +class SMTSolverStatistics { +public: + SMTSolverStatistics() = default; + virtual ~SMTSolverStatistics() = default; + + virtual double getDouble(llvm::StringRef) const = 0; + virtual unsigned getUnsigned(llvm::StringRef) const = 0; + + virtual void print(raw_ostream &OS) const = 0; + + LLVM_DUMP_METHOD void dump() const; +}; + /// Shared pointer for SMTExprs, used by SMTSolver API. using SMTExprRef = const SMTExpr *; @@ -434,6 +447,12 @@ class SMTSolver { virtual bool isFPSupported() = 0; virtual void print(raw_ostream &OS) const = 0; + + /// Sets the requested option. + virtual void setBoolParam(StringRef Key, bool Value) = 0; + virtual void setUnsignedParam(StringRef Key, unsigned Value) = 0; + + virtual std::unique_ptr<SMTSolverStatistics> getStatistics() const = 0; }; /// Shared pointer for SMTSolvers. 
diff --git a/llvm/lib/Support/Z3Solver.cpp b/llvm/lib/Support/Z3Solver.cpp index eb671fe2596db..5a34ff160f6cf 100644 --- a/llvm/lib/Support/Z3Solver.cpp +++ b/llvm/lib/Support/Z3Solver.cpp @@ -6,7 +6,9 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/ScopeExit.h" #include "llvm/Config/config.h" +#include "llvm/Support/NativeFormatting.h" #include "llvm/Support/SMTAPI.h" using namespace llvm; @@ -26,18 +28,14 @@ namespace { class Z3Config { friend class Z3Context; - Z3_config Config; + Z3_config Config = Z3_mk_config(); public: - Z3Config() : Config(Z3_mk_config()) { - // Enable model finding - Z3_set_param_value(Config, "model", "true"); - // Disable proof generation - Z3_set_param_value(Config, "proof", "false"); - // Set timeout to 15000ms = 15s - Z3_set_param_value(Config, "timeout", "15000"); - } - + Z3Config() = default; + Z3Config(const Z3Config &) = delete; + Z3Config(Z3Config &&) = delete; // Owns the raw Z3_config handle. + Z3Config &operator=(const Z3Config &) = delete; + Z3Config &operator=(Z3Config &&) = delete; ~Z3Config() { Z3_del_config(Config); } }; // end class Z3Config @@ -50,16 +48,22 @@ void Z3ErrorHandler(Z3_context Context, Z3_error_code Error) { /// Wrapper for Z3 context class Z3Context { public: + Z3Config Config; Z3_context Context; Z3Context() { - Context = Z3_mk_context_rc(Z3Config().Config); + Context = Z3_mk_context_rc(Config.Config); // The error function is set here because the context is the first object // created by the backend Z3_set_error_handler(Context, Z3ErrorHandler); } - virtual ~Z3Context() { + Z3Context(const Z3Context &) = delete; + Z3Context(Z3Context &&) = delete; // Owns the raw Z3_context handle. + Z3Context &operator=(const Z3Context &) = delete; + Z3Context &operator=(Z3Context &&) = delete; + + ~Z3Context() { Z3_del_context(Context); Context = nullptr; } @@ -262,7 +266,17 @@ class Z3Solver : public SMTSolver { Z3Context Context; - Z3_solver Solver; + Z3_solver Solver = [this] { + Z3_solver S = Z3_mk_simple_solver(Context.Context); +
Z3_solver_inc_ref(Context.Context, S); + return S; + }(); + + Z3_params Params = [this] { + Z3_params P = Z3_mk_params(Context.Context); + Z3_params_inc_ref(Context.Context, P); + return P; + }(); // Cache Sorts std::set<Z3Sort> CachedSorts; @@ -271,18 +285,15 @@ class Z3Solver : public SMTSolver { std::set<Z3Expr> CachedExprs; public: - Z3Solver() : Solver(Z3_mk_simple_solver(Context.Context)) { - Z3_solver_inc_ref(Context.Context, Solver); - } - + Z3Solver() = default; Z3Solver(const Z3Solver &Other) = delete; Z3Solver(Z3Solver &&Other) = delete; Z3Solver &operator=(Z3Solver &Other) = delete; Z3Solver &operator=(Z3Solver &&Other) = delete; - ~Z3Solver() { - if (Solver) - Z3_solver_dec_ref(Context.Context, Solver); + ~Z3Solver() override { + Z3_params_dec_ref(Context.Context, Params); + Z3_solver_dec_ref(Context.Context, Solver); } void addConstraint(const SMTExprRef &Exp) const override { @@ -871,6 +882,7 @@ class Z3Solver : public SMTSolver { } std::optional<bool> check() const override { + Z3_solver_set_params(Context.Context, Solver, Params); Z3_lbool res = Z3_solver_check(Context.Context, Solver); if (res == Z3_L_TRUE) return true; @@ -896,8 +908,71 @@ class Z3Solver : public SMTSolver { void print(raw_ostream &OS) const override { OS << Z3_solver_to_string(Context.Context, Solver); } + + void setUnsignedParam(StringRef Key, unsigned Value) override { + Z3_symbol Sym = Z3_mk_string_symbol(Context.Context, Key.str().c_str()); + Z3_params_set_uint(Context.Context, Params, Sym, Value); + } + + void setBoolParam(StringRef Key, bool Value) override { + Z3_symbol Sym = Z3_mk_string_symbol(Context.Context, Key.str().c_str()); + Z3_params_set_bool(Context.Context, Params, Sym, Value); + } + + std::unique_ptr<SMTSolverStatistics> getStatistics() const override; }; // end class Z3Solver +class Z3Statistics final : public SMTSolverStatistics { +public: + double getDouble(StringRef Key) const override { + auto It = DoubleValues.find(Key.str()); + assert(It != 
DoubleValues.end()); + return It->second; + } + unsigned getUnsigned(StringRef Key) const override { + auto It = UnsignedValues.find(Key.str()); + assert(It != UnsignedValues.end()); + return It->second; + } + + void print(raw_ostream &OS) const override { + for (auto const &[K, V] : UnsignedValues) { + OS << K << ": " << V << '\n'; + } + for (auto const &[K, V] : DoubleValues) { + write_double(OS << K << ": ", V, FloatStyle::Fixed); + OS << '\n'; + } + } + +private: + friend class Z3Solver; + std::unordered_map<std::string, unsigned> UnsignedValues; + std::unordered_map<std::string, double> DoubleValues; +}; + +std::unique_ptr<SMTSolverStatistics> Z3Solver::getStatistics() const { + auto const &C = Context.Context; + Z3_stats S = Z3_solver_get_statistics(C, Solver); + Z3_stats_inc_ref(C, S); + auto StatsGuard = llvm::make_scope_exit([&C, &S] { Z3_stats_dec_ref(C, S); }); + Z3Statistics Result; + + unsigned NumKeys = Z3_stats_size(C, S); + for (unsigned Idx = 0; Idx < NumKeys; ++Idx) { + const char *Key = Z3_stats_get_key(C, S, Idx); + if (Z3_stats_is_uint(C, S, Idx)) { + auto Value = Z3_stats_get_uint_value(C, S, Idx); + Result.UnsignedValues.try_emplace(Key, Value); + } else { + assert(Z3_stats_is_double(C, S, Idx)); + auto Value = Z3_stats_get_double_value(C, S, Idx); + Result.DoubleValues.try_emplace(Key, Value); + } + } + return std::make_unique<Z3Statistics>(std::move(Result)); +} + } // end anonymous namespace #endif @@ -916,3 +991,4 @@ llvm::SMTSolverRef llvm::CreateZ3Solver() { LLVM_DUMP_METHOD void SMTSort::dump() const { print(llvm::errs()); } LLVM_DUMP_METHOD void SMTExpr::dump() const { print(llvm::errs()); } LLVM_DUMP_METHOD void SMTSolver::dump() const { print(llvm::errs()); } +LLVM_DUMP_METHOD void SMTSolverStatistics::dump() const { print(llvm::errs()); } _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits