https://github.com/ivanaivanovska updated https://github.com/llvm/llvm-project/pull/125492
>From 54c7b3c1fb149b82c26927d0fd831d8786f70ac3 Mon Sep 17 00:00:00 2001 From: Ivana Ivanovska <iivanov...@google.com> Date: Mon, 2 Dec 2024 14:17:06 +0000 Subject: [PATCH 1/2] Optimize -Wunsafe-buffer-usage. --- clang/lib/Analysis/UnsafeBufferUsage.cpp | 1427 ++++++++++++++-------- 1 file changed, 906 insertions(+), 521 deletions(-) diff --git a/clang/lib/Analysis/UnsafeBufferUsage.cpp b/clang/lib/Analysis/UnsafeBufferUsage.cpp index c064aa30e8aedc6..4520d28d9e94522 100644 --- a/clang/lib/Analysis/UnsafeBufferUsage.cpp +++ b/clang/lib/Analysis/UnsafeBufferUsage.cpp @@ -8,30 +8,32 @@ #include "clang/Analysis/Analyses/UnsafeBufferUsage.h" #include "clang/AST/ASTContext.h" +#include "clang/AST/ASTTypeTraits.h" #include "clang/AST/Decl.h" +#include "clang/AST/DeclCXX.h" #include "clang/AST/DynamicRecursiveASTVisitor.h" #include "clang/AST/Expr.h" #include "clang/AST/FormatString.h" +#include "clang/AST/ParentMapContext.h" #include "clang/AST/Stmt.h" #include "clang/AST/StmtVisitor.h" #include "clang/AST/Type.h" -#include "clang/ASTMatchers/ASTMatchFinder.h" -#include "clang/ASTMatchers/ASTMatchers.h" #include "clang/Basic/SourceLocation.h" #include "clang/Lex/Lexer.h" #include "clang/Lex/Preprocessor.h" #include "llvm/ADT/APSInt.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include <memory> #include <optional> #include <queue> +#include <set> #include <sstream> using namespace llvm; using namespace clang; -using namespace ast_matchers; #ifndef NDEBUG namespace { @@ -68,7 +70,7 @@ static std::string getDREAncestorString(const DeclRefExpr *DRE, if (StParents.size() > 1) return "unavailable due to multiple parents"; - if (StParents.size() == 0) + if (StParents.empty()) break; St = StParents.begin()->get<Stmt>(); if (St) @@ -76,10 +78,39 @@ static std::string getDREAncestorString(const DeclRefExpr *DRE, } while (St); return SS.str(); } + } // namespace #endif /* NDEBUG */ -namespace clang::ast_matchers { +namespace { +// Using a custom matcher instead of ASTMatchers to achieve better performance. +class FastMatcher { +public: + virtual bool matches(const DynTypedNode &DynNode, ASTContext &Ctx, + const UnsafeBufferUsageHandler &Handler) = 0; + virtual ~FastMatcher() = default; +}; + +class MatchResult { + +public: + template <typename T> const T *getNodeAs(StringRef ID) const { + auto It = Nodes.find(std::string(ID)); + if (It == Nodes.end()) { + return nullptr; + } + return It->second.get<T>(); + } + + void addNode(StringRef ID, const DynTypedNode &Node) { + Nodes[std::string(ID)] = Node; + } + +private: + llvm::StringMap<DynTypedNode> Nodes; +}; +} // namespace + // A `RecursiveASTVisitor` that traverses all descendants of a given node "n" // except for those belonging to a different callable of "n". class MatchDescendantVisitor : public DynamicRecursiveASTVisitor { @@ -87,13 +118,11 @@ class MatchDescendantVisitor : public DynamicRecursiveASTVisitor { // Creates an AST visitor that matches `Matcher` on all // descendants of a given node "n" except for the ones // belonging to a different callable of "n". - MatchDescendantVisitor(const internal::DynTypedMatcher *Matcher, - internal::ASTMatchFinder *Finder, - internal::BoundNodesTreeBuilder *Builder, - internal::ASTMatchFinder::BindKind Bind, + MatchDescendantVisitor(FastMatcher &Matcher, bool FindAll, const bool ignoreUnevaluatedContext) - : Matcher(Matcher), Finder(Finder), Builder(Builder), Bind(Bind), - Matches(false), ignoreUnevaluatedContext(ignoreUnevaluatedContext) { + : Matcher(&Matcher), FindAll(FindAll), Matches(false), + ignoreUnevaluatedContext(ignoreUnevaluatedContext), + ActiveASTContext(nullptr), Handler(nullptr) { ShouldVisitTemplateInstantiations = true; ShouldVisitImplicitCode = false; // TODO: let's ignore implicit code for now } @@ -104,7 +133,6 @@ class MatchDescendantVisitor : public DynamicRecursiveASTVisitor { Matches = false; if (const Stmt *StmtNode = DynNode.get<Stmt>()) { TraverseStmt(const_cast<Stmt *>(StmtNode)); - *Builder = ResultBindings; return Matches; } return false; @@ -186,106 +214,212 @@ class MatchDescendantVisitor : public DynamicRecursiveASTVisitor { return DynamicRecursiveASTVisitor::TraverseStmt(Node); } + void setASTContext(ASTContext &Context) { ActiveASTContext = &Context; } + + void setHandler(const UnsafeBufferUsageHandler &NewHandler) { + Handler = &NewHandler; + } + private: // Sets 'Matched' to true if 'Matcher' matches 'Node' // // Returns 'true' if traversal should continue after this function // returns, i.e. if no match is found or 'Bind' is 'BK_All'. template <typename T> bool match(const T &Node) { - internal::BoundNodesTreeBuilder RecursiveBuilder(*Builder); - - if (Matcher->matches(DynTypedNode::create(Node), Finder, - &RecursiveBuilder)) { - ResultBindings.addMatch(RecursiveBuilder); + if (Matcher->matches(DynTypedNode::create(Node), *ActiveASTContext, + *Handler)) { Matches = true; - if (Bind != internal::ASTMatchFinder::BK_All) + if (!FindAll) return false; // Abort as soon as a match is found. } return true; } - const internal::DynTypedMatcher *const Matcher; - internal::ASTMatchFinder *const Finder; - internal::BoundNodesTreeBuilder *const Builder; - internal::BoundNodesTreeBuilder ResultBindings; - const internal::ASTMatchFinder::BindKind Bind; + FastMatcher *const Matcher; + // When true, finds all matches. When false, finds the first match and stops. + const bool FindAll; bool Matches; bool ignoreUnevaluatedContext; + ASTContext *ActiveASTContext; + const UnsafeBufferUsageHandler *Handler; }; // Because we're dealing with raw pointers, let's define what we mean by that. -static auto hasPointerType() { - return hasType(hasCanonicalType(pointerType())); +static bool hasPointerType(const Expr &E) { + return isa<PointerType>(E.getType().getCanonicalType()); } -static auto hasArrayType() { return hasType(hasCanonicalType(arrayType())); } - -AST_MATCHER_P(Stmt, forEachDescendantEvaluatedStmt, internal::Matcher<Stmt>, - innerMatcher) { - const DynTypedMatcher &DTM = static_cast<DynTypedMatcher>(innerMatcher); - - MatchDescendantVisitor Visitor(&DTM, Finder, Builder, ASTMatchFinder::BK_All, - true); - return Visitor.findMatch(DynTypedNode::create(Node)); +static bool hasArrayType(const Expr &E) { + return isa<ArrayType>(E.getType().getCanonicalType()); } -AST_MATCHER_P(Stmt, forEachDescendantStmt, internal::Matcher<Stmt>, - innerMatcher) { - const DynTypedMatcher &DTM = static_cast<DynTypedMatcher>(innerMatcher); +static void +forEachDescendantEvaluatedStmt(const Stmt *S, ASTContext &Ctx, + const UnsafeBufferUsageHandler &Handler, + FastMatcher &Matcher) { + MatchDescendantVisitor Visitor(Matcher, /* FindAll */ true, + /*ignoreUnevaluatedContext*/ true); + Visitor.setASTContext(Ctx); + Visitor.setHandler(Handler); + Visitor.findMatch(DynTypedNode::create(*S)); +} - MatchDescendantVisitor Visitor(&DTM, Finder, Builder, ASTMatchFinder::BK_All, - false); - return Visitor.findMatch(DynTypedNode::create(Node)); +static void forEachDescendantStmt(const Stmt *S, ASTContext &Ctx, + const UnsafeBufferUsageHandler &Handler, + FastMatcher &Matcher) { + MatchDescendantVisitor Visitor(Matcher, /* FindAll */ true, + /*ignoreUnevaluatedContext*/ false); + Visitor.setASTContext(Ctx); + Visitor.setHandler(Handler); + Visitor.findMatch(DynTypedNode::create(*S)); } // Matches a `Stmt` node iff the node is in a safe-buffer opt-out region -AST_MATCHER_P(Stmt, notInSafeBufferOptOut, const UnsafeBufferUsageHandler *, - Handler) { +static bool notInSafeBufferOptOut(const Stmt &Node, + const UnsafeBufferUsageHandler *Handler) { return !Handler->isSafeBufferOptOut(Node.getBeginLoc()); } -AST_MATCHER_P(Stmt, ignoreUnsafeBufferInContainer, - const UnsafeBufferUsageHandler *, Handler) { +static bool +ignoreUnsafeBufferInContainer(const Stmt &Node, + const UnsafeBufferUsageHandler *Handler) { return Handler->ignoreUnsafeBufferInContainer(Node.getBeginLoc()); } -AST_MATCHER_P(Stmt, ignoreUnsafeLibcCall, const UnsafeBufferUsageHandler *, - Handler) { - if (Finder->getASTContext().getLangOpts().CPlusPlus) +static bool ignoreUnsafeLibcCall(const ASTContext &Ctx, const Stmt &Node, + const UnsafeBufferUsageHandler *Handler) { + if (Ctx.getLangOpts().CPlusPlus) return Handler->ignoreUnsafeBufferInLibcCall(Node.getBeginLoc()); return true; /* Only warn about libc calls for C++ */ } -AST_MATCHER_P(CastExpr, castSubExpr, internal::Matcher<Expr>, innerMatcher) { - return innerMatcher.matches(*Node.getSubExpr(), Finder, Builder); +// Finds any expression 'e' such that `OnResult` +// matches 'e' and 'e' is in an Unspecified Lvalue Context. +static void findStmtsInUnspecifiedLvalueContext( + const Stmt *S, const llvm::function_ref<void(const Expr *)> OnResult) { + if (const auto *CE = dyn_cast<ImplicitCastExpr>(S)) { + if (CE->getCastKind() != CastKind::CK_LValueToRValue) + return; + OnResult(CE->getSubExpr()); + } + if (const auto *BO = dyn_cast<BinaryOperator>(S)) { + if (BO->getOpcode() != BO_Assign) + return; + OnResult(BO->getLHS()); + } } -// Matches a `UnaryOperator` whose operator is pre-increment: -AST_MATCHER(UnaryOperator, isPreInc) { - return Node.getOpcode() == UnaryOperator::Opcode::UO_PreInc; +/// Note: Copied and modified from ASTMatchers. +/// Matches all arguments and their respective types for a \c CallExpr or +/// \c CXXConstructExpr. It is very similar to \c forEachArgumentWithParam but +/// it works on calls through function pointers as well. +/// +/// The difference is, that function pointers do not provide access to a +/// \c ParmVarDecl, but only the \c QualType for each argument. +/// +/// Given +/// \code +/// void f(int i); +/// int y; +/// f(y); +/// void (*f_ptr)(int) = f; +/// f_ptr(y); +/// \endcode +/// callExpr( +/// forEachArgumentWithParamType( +/// declRefExpr(to(varDecl(hasName("y")))), +/// qualType(isInteger()).bind("type) +/// )) +/// matches f(y) and f_ptr(y) +/// with declRefExpr(...) +/// matching int y +/// and qualType(...) +/// matching int +static void forEachArgumentWithParamType( + const CallExpr &Node, + const llvm::function_ref<void(QualType /*Param*/, const Expr * /*Arg*/)> + OnParamAndArg) { + // The first argument of an overloaded member operator is the implicit object + // argument of the method which should not be matched against a parameter, so + // we skip over it here. + unsigned ArgIndex = 0; + if (const auto *CE = dyn_cast<CXXOperatorCallExpr>(&Node)) { + const auto *FD = CE->getDirectCallee(); + if (FD) { + if (const auto *MD = dyn_cast<CXXMethodDecl>(FD); + MD && !MD->isExplicitObjectMemberFunction()) { + // This is an overloaded operator call. + // We need to skip the first argument, which is the implicit object + // argument of the method which should not be matched against a + // parameter. + ++ArgIndex; + } + } + } + + const FunctionProtoType *FProto = nullptr; + + if (const auto *Call = dyn_cast<CallExpr>(&Node)) { + if (const auto *Value = + dyn_cast_or_null<ValueDecl>(Call->getCalleeDecl())) { + QualType QT = Value->getType().getCanonicalType(); + + // This does not necessarily lead to a `FunctionProtoType`, + // e.g. K&R functions do not have a function prototype. + if (QT->isFunctionPointerType()) + FProto = QT->getPointeeType()->getAs<FunctionProtoType>(); + + if (QT->isMemberFunctionPointerType()) { + const auto *MP = QT->getAs<MemberPointerType>(); + assert(MP && "Must be member-pointer if its a memberfunctionpointer"); + FProto = MP->getPointeeType()->getAs<FunctionProtoType>(); + assert(FProto && + "The call must have happened through a member function " + "pointer"); + } + } + } + + unsigned ParamIndex = 0; + unsigned NumArgs = Node.getNumArgs(); + if (FProto && FProto->isVariadic()) + NumArgs = std::min(NumArgs, FProto->getNumParams()); + + const auto GetParamType = + [&FProto, &Node](unsigned int ParamIndex) -> std::optional<QualType> { + // This test is cheaper compared to the big matcher in the next if. + // Therefore, please keep this order. + if (FProto && FProto->getNumParams() > ParamIndex) { + return FProto->getParamType(ParamIndex); + } + if (const auto *E = dyn_cast<Expr>(&Node)) { + if (const auto *CE = dyn_cast<CXXConstructExpr>(E)) { + if (const auto *Ctor = CE->getConstructor(); + Ctor && Ctor->getNumParams() > ParamIndex) { + return CE->getArg(ParamIndex)->getType(); + } + } + if (const auto *CE = dyn_cast<CallExpr>(E)) { + const auto *FD = CE->getDirectCallee(); + if (FD && FD->getNumParams() > ParamIndex) { + return CE->getArg(ParamIndex)->getType(); + } + } + } + return std::nullopt; + }; + + for (; ArgIndex < NumArgs; ++ArgIndex, ++ParamIndex) { + auto ParamType = GetParamType(ParamIndex); + if (ParamType) + OnParamAndArg(*ParamType, Node.getArg(ArgIndex)->IgnoreParenCasts()); + } } -// Returns a matcher that matches any expression 'e' such that `innerMatcher` -// matches 'e' and 'e' is in an Unspecified Lvalue Context. -static auto isInUnspecifiedLvalueContext(internal::Matcher<Expr> innerMatcher) { - // clang-format off - return - expr(anyOf( - implicitCastExpr( - hasCastKind(CastKind::CK_LValueToRValue), - castSubExpr(innerMatcher)), - binaryOperator( - hasAnyOperatorName("="), - hasLHS(innerMatcher) - ) - )); - // clang-format on -} - -// Returns a matcher that matches any expression `e` such that `InnerMatcher` -// matches `e` and `e` is in an Unspecified Pointer Context (UPC). -static internal::Matcher<Stmt> -isInUnspecifiedPointerContext(internal::Matcher<Stmt> InnerMatcher) { +// Finds any expression `e` such that `InnerMatcher` matches `e` and +// `e` is in an Unspecified Pointer Context (UPC). +static void findStmtsInUnspecifiedPointerContext( + const Stmt *S, llvm::function_ref<void(const Stmt *)> InnerMatcher) { // A UPC can be // 1. an argument of a function call (except the callee has [[unsafe_...]] // attribute), or @@ -294,45 +428,57 @@ isInUnspecifiedPointerContext(internal::Matcher<Stmt> InnerMatcher) { // 4. the operand of a pointer subtraction operation // (i.e., computing the distance between two pointers); or ... - // clang-format off - auto CallArgMatcher = callExpr( + if (auto *CE = dyn_cast<CallExpr>(S)) { + if (const auto *FnDecl = CE->getDirectCallee(); + FnDecl && FnDecl->hasAttr<UnsafeBufferUsageAttr>()) + return; forEachArgumentWithParamType( - InnerMatcher, - isAnyPointer() /* array also decays to pointer type*/), - unless(callee( - functionDecl(hasAttr(attr::UnsafeBufferUsage))))); - - auto CastOperandMatcher = - castExpr(anyOf(hasCastKind(CastKind::CK_PointerToIntegral), - hasCastKind(CastKind::CK_PointerToBoolean)), - castSubExpr(allOf(hasPointerType(), InnerMatcher))); - - auto CompOperandMatcher = - binaryOperator(hasAnyOperatorName("!=", "==", "<", "<=", ">", ">="), - eachOf(hasLHS(allOf(hasPointerType(), InnerMatcher)), - hasRHS(allOf(hasPointerType(), InnerMatcher)))); - - // A matcher that matches pointer subtractions: - auto PtrSubtractionMatcher = - binaryOperator(hasOperatorName("-"), - // Note that here we need both LHS and RHS to be - // pointer. Then the inner matcher can match any of - // them: - allOf(hasLHS(hasPointerType()), - hasRHS(hasPointerType())), - eachOf(hasLHS(InnerMatcher), - hasRHS(InnerMatcher))); - // clang-format on - - return stmt(anyOf(CallArgMatcher, CastOperandMatcher, CompOperandMatcher, - PtrSubtractionMatcher)); - // FIXME: any more cases? (UPC excludes the RHS of an assignment. For now we - // don't have to check that.) -} - -// Returns a matcher that matches any expression 'e' such that `innerMatcher` -// matches 'e' and 'e' is in an unspecified untyped context (i.e the expression -// 'e' isn't evaluated to an RValue). For example, consider the following code: + *CE, [&InnerMatcher](QualType Type, const Expr *Arg) { + if (Type->isAnyPointerType()) + InnerMatcher(Arg); + }); + } + + if (auto *CE = dyn_cast<CastExpr>(S)) { + if (CE->getCastKind() != CastKind::CK_PointerToIntegral && + CE->getCastKind() != CastKind::CK_PointerToBoolean) + return; + if (!hasPointerType(*CE->getSubExpr())) + return; + InnerMatcher(CE->getSubExpr()); + } + + // Pointer comparison operator. + if (const auto *BO = dyn_cast<BinaryOperator>(S); + BO && (BO->getOpcode() == BO_EQ || BO->getOpcode() == BO_NE || + BO->getOpcode() == BO_LT || BO->getOpcode() == BO_LE || + BO->getOpcode() == BO_GT || BO->getOpcode() == BO_GE)) { + auto *LHS = BO->getLHS(); + auto *RHS = BO->getRHS(); + if (!hasPointerType(*LHS) || !hasPointerType(*RHS)) + return; + InnerMatcher(LHS); + InnerMatcher(RHS); + } + + // Pointer subtractions. + if (const auto *BO = dyn_cast<BinaryOperator>(S); + BO && BO->getOpcode() == BO_Sub && hasPointerType(*BO->getLHS()) && + hasPointerType(*BO->getRHS())) { + // Note that here we need both LHS and RHS to be + // pointer. Then the inner matcher can match any of + // them: + InnerMatcher(BO->getLHS()); + InnerMatcher(BO->getRHS()); + } + // FIXME: any more cases? (UPC excludes the RHS of an assignment. For now + // we don't have to check that.) +} + +// Finds statements in unspecified untyped context i.e. any expression 'e' such +// that `InnerMatcher` matches 'e' and 'e' is in an unspecified untyped context +// (i.e the expression 'e' isn't evaluated to an RValue). For example, consider +// the following code: // int *p = new int[4]; // int *q = new int[4]; // if ((p = q)) {} @@ -340,17 +486,23 @@ isInUnspecifiedPointerContext(internal::Matcher<Stmt> InnerMatcher) { // The expression `p = q` in the conditional of the `if` statement // `if ((p = q))` is evaluated as an RValue, whereas the expression `p = q;` // in the assignment statement is in an untyped context. -static internal::Matcher<Stmt> -isInUnspecifiedUntypedContext(internal::Matcher<Stmt> InnerMatcher) { +static void findStmtsInUnspecifiedUntypedContext( + const Stmt *S, llvm::function_ref<void(const Stmt *)> InnerMatcher) { // An unspecified context can be // 1. A compound statement, // 2. The body of an if statement // 3. Body of a loop - auto CompStmt = compoundStmt(forEach(InnerMatcher)); - auto IfStmtThen = ifStmt(hasThen(InnerMatcher)); - auto IfStmtElse = ifStmt(hasElse(InnerMatcher)); + if (auto *CS = dyn_cast<CompoundStmt>(S)) { + for (auto *Child : CS->body()) + InnerMatcher(Child); + } + if (auto *IfS = dyn_cast<IfStmt>(S)) { + if (IfS->getThen()) + InnerMatcher(IfS->getThen()); + if (IfS->getElse()) + InnerMatcher(IfS->getElse()); + } // FIXME: Handle loop bodies. - return stmt(anyOf(CompStmt, IfStmtThen, IfStmtElse)); } // Given a two-param std::span construct call, matches iff the call has the @@ -362,14 +514,15 @@ isInUnspecifiedUntypedContext(internal::Matcher<Stmt> InnerMatcher) { // `n` // 5. `std::span<T>{any, 0}` // 6. `std::span<T>{std::addressof(...), 1}` -AST_MATCHER(CXXConstructExpr, isSafeSpanTwoParamConstruct) { +static bool isSafeSpanTwoParamConstruct(const CXXConstructExpr &Node, + const ASTContext &Ctx) { assert(Node.getNumArgs() == 2 && "expecting a two-parameter std::span constructor"); const Expr *Arg0 = Node.getArg(0)->IgnoreImplicit(); const Expr *Arg1 = Node.getArg(1)->IgnoreImplicit(); - auto HaveEqualConstantValues = [&Finder](const Expr *E0, const Expr *E1) { - if (auto E0CV = E0->getIntegerConstantExpr(Finder->getASTContext())) - if (auto E1CV = E1->getIntegerConstantExpr(Finder->getASTContext())) { + auto HaveEqualConstantValues = [&Ctx](const Expr *E0, const Expr *E1) { + if (auto E0CV = E0->getIntegerConstantExpr(Ctx)) + if (auto E1CV = E1->getIntegerConstantExpr(Ctx)) { return APSInt::compareValues(*E0CV, *E1CV) == 0; } return false; @@ -381,8 +534,7 @@ AST_MATCHER(CXXConstructExpr, isSafeSpanTwoParamConstruct) { } return false; }; - std::optional<APSInt> Arg1CV = - Arg1->getIntegerConstantExpr(Finder->getASTContext()); + std::optional<APSInt> Arg1CV = Arg1->getIntegerConstantExpr(Ctx); if (Arg1CV && Arg1CV->isZero()) // Check form 5: @@ -421,8 +573,7 @@ AST_MATCHER(CXXConstructExpr, isSafeSpanTwoParamConstruct) { QualType Arg0Ty = Arg0->IgnoreImplicit()->getType(); - if (auto *ConstArrTy = - Finder->getASTContext().getAsConstantArrayType(Arg0Ty)) { + if (auto *ConstArrTy = Ctx.getAsConstantArrayType(Arg0Ty)) { const APSInt ConstArrSize = APSInt(ConstArrTy->getSize()); // Check form 4: @@ -431,7 +582,8 @@ AST_MATCHER(CXXConstructExpr, isSafeSpanTwoParamConstruct) { return false; } -AST_MATCHER(ArraySubscriptExpr, isSafeArraySubscript) { +static bool isSafeArraySubscript(const ArraySubscriptExpr &Node, + const ASTContext &Ctx) { // FIXME: Proper solution: // - refactor Sema::CheckArrayAccess // - split safe/OOB/unknown decision logic from diagnostics emitting code @@ -446,7 +598,7 @@ AST_MATCHER(ArraySubscriptExpr, isSafeArraySubscript) { ->getType() ->getUnqualifiedDesugaredType())) { limit = CATy->getLimitedSize(); - } else if (const auto *SLiteral = dyn_cast<StringLiteral>( + } else if (const auto *SLiteral = dyn_cast<clang::StringLiteral>( Node.getBase()->IgnoreParenImpCasts())) { limit = SLiteral->getLength() + 1; } else { @@ -456,7 +608,7 @@ AST_MATCHER(ArraySubscriptExpr, isSafeArraySubscript) { Expr::EvalResult EVResult; const Expr *IndexExpr = Node.getIdx(); if (!IndexExpr->isValueDependent() && - IndexExpr->EvaluateAsInt(EVResult, Finder->getASTContext())) { + IndexExpr->EvaluateAsInt(EVResult, Ctx)) { llvm::APSInt ArrIdx = EVResult.Val.getInt(); // FIXME: ArrIdx.isNegative() we could immediately emit an error as that's a // bug @@ -466,10 +618,6 @@ AST_MATCHER(ArraySubscriptExpr, isSafeArraySubscript) { return false; } -AST_MATCHER_P(CallExpr, hasNumArgs, unsigned, Num) { - return Node.getNumArgs() == Num; -} - namespace libc_func_matchers { // Under `libc_func_matchers`, define a set of matchers that match unsafe // functions in libc and unsafe calls to them. @@ -518,7 +666,7 @@ struct LibcFunNamePrefixSuffixParser { // A pointer type expression is known to be null-terminated, if it has the // form: E.c_str(), for any expression E of `std::string` type. static bool isNullTermPointer(const Expr *Ptr) { - if (isa<StringLiteral>(Ptr->IgnoreParenImpCasts())) + if (isa<clang::StringLiteral>(Ptr->IgnoreParenImpCasts())) return true; if (isa<PredefinedExpr>(Ptr->IgnoreParenImpCasts())) return true; @@ -576,7 +724,7 @@ static bool hasUnsafeFormatOrSArg(const CallExpr *Call, const Expr *&UnsafeArg, const Expr *Fmt = Call->getArg(FmtArgIdx); - if (auto *SL = dyn_cast<StringLiteral>(Fmt->IgnoreParenImpCasts())) { + if (auto *SL = dyn_cast<clang::StringLiteral>(Fmt->IgnoreParenImpCasts())) { StringRef FmtStr; if (SL->getCharByteWidth() == 1) @@ -616,7 +764,7 @@ static bool hasUnsafeFormatOrSArg(const CallExpr *Call, const Expr *&UnsafeArg, // Note: For predefined prefix and suffix, see `LibcFunNamePrefixSuffixParser`. // The notation `CoreName[str/wcs]` means a new name obtained from replace // string "wcs" with "str" in `CoreName`. -AST_MATCHER(FunctionDecl, isPredefinedUnsafeLibcFunc) { +static bool isPredefinedUnsafeLibcFunc(const FunctionDecl &Node) { static std::unique_ptr<std::set<StringRef>> PredefinedNames = nullptr; if (!PredefinedNames) PredefinedNames = @@ -723,7 +871,7 @@ AST_MATCHER(FunctionDecl, isPredefinedUnsafeLibcFunc) { // Match a call to one of the `v*printf` functions taking `va_list`. We cannot // check safety for these functions so they should be changed to their // non-va_list versions. -AST_MATCHER(FunctionDecl, isUnsafeVaListPrintfFunc) { +static bool isUnsafeVaListPrintfFunc(const FunctionDecl &Node) { auto *II = Node.getIdentifier(); if (!II) @@ -739,7 +887,7 @@ AST_MATCHER(FunctionDecl, isUnsafeVaListPrintfFunc) { // Matches a call to one of the `sprintf` functions as they are always unsafe // and should be changed to `snprintf`. -AST_MATCHER(FunctionDecl, isUnsafeSprintfFunc) { +static bool isUnsafeSprintfFunc(const FunctionDecl &Node) { auto *II = Node.getIdentifier(); if (!II) @@ -763,7 +911,7 @@ AST_MATCHER(FunctionDecl, isUnsafeSprintfFunc) { // Match function declarations of `printf`, `fprintf`, `snprintf` and their wide // character versions. Calls to these functions can be safe if their arguments // are carefully made safe. -AST_MATCHER(FunctionDecl, isNormalPrintfFunc) { +static bool isNormalPrintfFunc(const FunctionDecl &Node) { auto *II = Node.getIdentifier(); if (!II) @@ -787,9 +935,9 @@ AST_MATCHER(FunctionDecl, isNormalPrintfFunc) { // Then if the format string is a string literal, this matcher matches when at // least one string argument is unsafe. If the format is not a string literal, // this matcher matches when at least one pointer type argument is unsafe. -AST_MATCHER_P(CallExpr, hasUnsafePrintfStringArg, - clang::ast_matchers::internal::Matcher<Expr>, - UnsafeStringArgMatcher) { +static bool hasUnsafePrintfStringArg(const CallExpr &Node, ASTContext &Ctx, + MatchResult &Result, + const char *const Op) { // Determine what printf it is by examining formal parameters: const FunctionDecl *FD = Node.getDirectCallee(); @@ -800,7 +948,6 @@ AST_MATCHER_P(CallExpr, hasUnsafePrintfStringArg, if (NumParms < 1) return false; // possibly some user-defined printf function - ASTContext &Ctx = Finder->getASTContext(); QualType FirstParmTy = FD->getParamDecl(0)->getType(); if (!FirstParmTy->isPointerType()) @@ -814,8 +961,10 @@ AST_MATCHER_P(CallExpr, hasUnsafePrintfStringArg, // It is a fprintf: const Expr *UnsafeArg; - if (hasUnsafeFormatOrSArg(&Node, UnsafeArg, 1, Ctx, false)) - return UnsafeStringArgMatcher.matches(*UnsafeArg, Finder, Builder); + if (hasUnsafeFormatOrSArg(&Node, UnsafeArg, 1, Ctx, false)) { + Result.addNode(Op, DynTypedNode::create(*UnsafeArg)); + return true; + } return false; } @@ -826,8 +975,10 @@ AST_MATCHER_P(CallExpr, hasUnsafePrintfStringArg, if (auto *II = FD->getIdentifier()) isKprintf = II->getName() == "kprintf"; - if (hasUnsafeFormatOrSArg(&Node, UnsafeArg, 0, Ctx, isKprintf)) - return UnsafeStringArgMatcher.matches(*UnsafeArg, Finder, Builder); + if (hasUnsafeFormatOrSArg(&Node, UnsafeArg, 0, Ctx, isKprintf)) { + Result.addNode(Op, DynTypedNode::create(*UnsafeArg)); + return true; + } return false; } @@ -839,17 +990,22 @@ AST_MATCHER_P(CallExpr, hasUnsafePrintfStringArg, // second is an integer, it is a snprintf: const Expr *UnsafeArg; - if (hasUnsafeFormatOrSArg(&Node, UnsafeArg, 2, Ctx, false)) - return UnsafeStringArgMatcher.matches(*UnsafeArg, Finder, Builder); + if (hasUnsafeFormatOrSArg(&Node, UnsafeArg, 2, Ctx, false)) { + Result.addNode(Op, DynTypedNode::create(*UnsafeArg)); + return true; + } return false; } } // We don't really recognize this "normal" printf, the only thing we // can do is to require all pointers to be null-terminated: - for (auto Arg : Node.arguments()) - if (Arg->getType()->isPointerType() && !isNullTermPointer(Arg)) - if (UnsafeStringArgMatcher.matches(*Arg, Finder, Builder)) + for (const auto *Arg : Node.arguments()) + if (Arg->getType()->isPointerType() && !isNullTermPointer(Arg)) { + if (isa<Expr>(Arg)) { + Result.addNode(Op, DynTypedNode::create(*Arg)); return true; + } + } return false; } @@ -869,7 +1025,8 @@ AST_MATCHER_P(CallExpr, hasUnsafePrintfStringArg, // ptr := Constant-Array-DRE; // size:= any expression that has compile-time constant value equivalent to // sizeof (Constant-Array-DRE) -AST_MATCHER(CallExpr, hasUnsafeSnprintfBuffer) { +static bool hasUnsafeSnprintfBuffer(const CallExpr &Node, + const ASTContext &Ctx) { const FunctionDecl *FD = Node.getDirectCallee(); assert(FD && "It should have been checked that FD is non-null."); @@ -923,8 +1080,6 @@ AST_MATCHER(CallExpr, hasUnsafeSnprintfBuffer) { // Pattern 2: if (auto *DRE = dyn_cast<DeclRefExpr>(Buf->IgnoreParenImpCasts())) { - ASTContext &Ctx = Finder->getASTContext(); - if (auto *CAT = Ctx.getAsConstantArrayType(DRE->getType())) { Expr::EvalResult ER; // The array element type must be compatible with `char` otherwise an @@ -940,7 +1095,6 @@ AST_MATCHER(CallExpr, hasUnsafeSnprintfBuffer) { return true; // ptr and size are not in safe pattern } } // namespace libc_func_matchers -} // namespace clang::ast_matchers namespace { // Because the analysis revolves around variables and their types, we'll need to @@ -967,11 +1121,6 @@ class Gadget { #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def" }; - /// Common type of ASTMatchers used for discovering gadgets. - /// Useful for implementing the static matcher() methods - /// that are expected from all non-abstract subclasses. - using Matcher = decltype(stmt()); - Gadget(Kind K) : K(K) {} Kind getKind() const { return K; } @@ -1048,7 +1197,10 @@ class FixableGadget : public Gadget { } }; -static auto toSupportedVariable() { return to(varDecl()); } +static auto toSupportedVariable(const DeclRefExpr &Node) { + const Decl *D = Node.getDecl(); + return D != nullptr && isa<VarDecl>(D); +} using FixableGadgetList = std::vector<std::unique_ptr<FixableGadget>>; using WarningGadgetList = std::vector<std::unique_ptr<WarningGadget>>; @@ -1060,19 +1212,24 @@ class IncrementGadget : public WarningGadget { const UnaryOperator *Op; public: - IncrementGadget(const MatchFinder::MatchResult &Result) + IncrementGadget(const MatchResult &Result) : WarningGadget(Kind::Increment), - Op(Result.Nodes.getNodeAs<UnaryOperator>(OpTag)) {} + Op(Result.getNodeAs<UnaryOperator>(OpTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::Increment; } - static Matcher matcher() { - return stmt( - unaryOperator(hasOperatorName("++"), - hasUnaryOperand(ignoringParenImpCasts(hasPointerType()))) - .bind(OpTag)); + static bool matches(const Stmt *S, const ASTContext &Ctx, + MatchResult &Result) { + const auto *UO = dyn_cast<UnaryOperator>(S); + if (!UO || !UO->isIncrementOp()) + return false; + const auto *Operand = UO->getSubExpr()->IgnoreParenImpCasts(); + if (!hasPointerType(*Operand)) + return false; + Result.addNode(OpTag, DynTypedNode::create(*UO)); + return true; } void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler, @@ -1100,19 +1257,24 @@ class DecrementGadget : public WarningGadget { const UnaryOperator *Op; public: - DecrementGadget(const MatchFinder::MatchResult &Result) + DecrementGadget(const MatchResult &Result) : WarningGadget(Kind::Decrement), - Op(Result.Nodes.getNodeAs<UnaryOperator>(OpTag)) {} + Op(Result.getNodeAs<UnaryOperator>(OpTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::Decrement; } - static Matcher matcher() { - return stmt( - unaryOperator(hasOperatorName("--"), - hasUnaryOperand(ignoringParenImpCasts(hasPointerType()))) - .bind(OpTag)); + static bool matches(const Stmt *S, const ASTContext &Ctx, + MatchResult &Result) { + const auto *UO = dyn_cast<UnaryOperator>(S); + if (!UO || !UO->isDecrementOp()) + return false; + const auto *Operand = UO->getSubExpr()->IgnoreParenImpCasts(); + if (!hasPointerType(*Operand)) + return false; + Result.addNode(OpTag, DynTypedNode::create(*UO)); + return true; } void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler, @@ -1139,26 +1301,30 @@ class ArraySubscriptGadget : public WarningGadget { const ArraySubscriptExpr *ASE; public: - ArraySubscriptGadget(const MatchFinder::MatchResult &Result) + ArraySubscriptGadget(const MatchResult &Result) : WarningGadget(Kind::ArraySubscript), - ASE(Result.Nodes.getNodeAs<ArraySubscriptExpr>(ArraySubscrTag)) {} + ASE(Result.getNodeAs<ArraySubscriptExpr>(ArraySubscrTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::ArraySubscript; } - static Matcher matcher() { - // clang-format off - return stmt(arraySubscriptExpr( - hasBase(ignoringParenImpCasts( - anyOf(hasPointerType(), hasArrayType()))), - unless(anyOf( - isSafeArraySubscript(), - hasIndex( - anyOf(integerLiteral(equals(0)), arrayInitIndexExpr()) - ) - ))).bind(ArraySubscrTag)); - // clang-format on + static bool matches(const Stmt *S, const ASTContext &Ctx, + MatchResult &Result) { + const auto *ASE = dyn_cast<ArraySubscriptExpr>(S); + if (!ASE) + return false; + const auto *const Base = ASE->getBase()->IgnoreParenImpCasts(); + if (!hasPointerType(*Base) && !hasArrayType(*Base)) + return false; + bool isSafeIndex = + (isa<IntegerLiteral>(ASE->getIdx()) && + cast<IntegerLiteral>(ASE->getIdx())->getValue().isZero()) || + isa<ArrayInitIndexExpr>(ASE->getIdx()); + if (isSafeArraySubscript(*ASE, Ctx) || isSafeIndex) + return false; + Result.addNode(ArraySubscrTag, DynTypedNode::create(*ASE)); + return true; } void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler, @@ -1189,29 +1355,40 @@ class PointerArithmeticGadget : public WarningGadget { const Expr *Ptr; // the pointer expression in `PA` public: - PointerArithmeticGadget(const MatchFinder::MatchResult &Result) + PointerArithmeticGadget(const MatchResult &Result) : WarningGadget(Kind::PointerArithmetic), - PA(Result.Nodes.getNodeAs<BinaryOperator>(PointerArithmeticTag)), - Ptr(Result.Nodes.getNodeAs<Expr>(PointerArithmeticPointerTag)) {} + PA((Result.getNodeAs<BinaryOperator>(PointerArithmeticTag))), + Ptr(Result.getNodeAs<Expr>(PointerArithmeticPointerTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::PointerArithmetic; } - static Matcher matcher() { - auto HasIntegerType = anyOf(hasType(isInteger()), hasType(enumType())); - auto PtrAtRight = - allOf(hasOperatorName("+"), - hasRHS(expr(hasPointerType()).bind(PointerArithmeticPointerTag)), - hasLHS(HasIntegerType)); - auto PtrAtLeft = - allOf(anyOf(hasOperatorName("+"), hasOperatorName("-"), - hasOperatorName("+="), hasOperatorName("-=")), - hasLHS(expr(hasPointerType()).bind(PointerArithmeticPointerTag)), - hasRHS(HasIntegerType)); - - return stmt(binaryOperator(anyOf(PtrAtLeft, PtrAtRight)) - .bind(PointerArithmeticTag)); + static bool matches(const Stmt *S, const ASTContext &Ctx, + MatchResult &Result) { + const auto *BO = dyn_cast<BinaryOperator>(S); + if (!BO) + return false; + const auto *LHS = BO->getLHS(); + const auto *RHS = BO->getRHS(); + // ptr at left + if (BO->getOpcode() == BO_Add || BO->getOpcode() == BO_Sub || + BO->getOpcode() == BO_AddAssign || BO->getOpcode() == BO_SubAssign) { + if (hasPointerType(*LHS) && (RHS->getType()->isIntegerType() || + RHS->getType()->isEnumeralType())) { + Result.addNode(PointerArithmeticPointerTag, DynTypedNode::create(*LHS)); + Result.addNode(PointerArithmeticTag, DynTypedNode::create(*BO)); + return true; + } + } + // ptr at right + if (BO->getOpcode() == BO_Add && hasPointerType(*RHS) && + (LHS->getType()->isIntegerType() || LHS->getType()->isEnumeralType())) { + Result.addNode(PointerArithmeticPointerTag, DynTypedNode::create(*RHS)); + Result.addNode(PointerArithmeticTag, DynTypedNode::create(*BO)); + return true; + } + return false; } void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler, @@ -1238,27 +1415,36 @@ class SpanTwoParamConstructorGadget : public WarningGadget { const CXXConstructExpr *Ctor; // the span constructor expression public: - SpanTwoParamConstructorGadget(const MatchFinder::MatchResult &Result) + SpanTwoParamConstructorGadget(const MatchResult &Result) : WarningGadget(Kind::SpanTwoParamConstructor), - Ctor(Result.Nodes.getNodeAs<CXXConstructExpr>( - SpanTwoParamConstructorTag)) {} + Ctor(Result.getNodeAs<CXXConstructExpr>(SpanTwoParamConstructorTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::SpanTwoParamConstructor; } - static Matcher matcher() { - auto HasTwoParamSpanCtorDecl = hasDeclaration( - cxxConstructorDecl(hasDeclContext(isInStdNamespace()), hasName("span"), - parameterCountIs(2))); - - return stmt(cxxConstructExpr(HasTwoParamSpanCtorDecl, - unless(isSafeSpanTwoParamConstruct())) - .bind(SpanTwoParamConstructorTag)); + static bool matches(const Stmt *S, const ASTContext &Ctx, + MatchResult &Result) { + const auto *CE = dyn_cast<CXXConstructExpr>(S); + if (!CE) + return false; + const auto *CDecl = CE->getConstructor(); + const auto *DCtx = CDecl->getDeclContext(); + auto HasTwoParamSpanCtorDecl = + Decl::castFromDeclContext(DCtx)->isInStdNamespace() && + CDecl->getDeclName().getAsString() == "span" && CE->getNumArgs() == 2; + if (!HasTwoParamSpanCtorDecl || isSafeSpanTwoParamConstruct(*CE, Ctx)) + return false; + Result.addNode(SpanTwoParamConstructorTag, DynTypedNode::create(*CE)); + return true; } - static Matcher matcher(const UnsafeBufferUsageHandler *Handler) { - return stmt(unless(ignoreUnsafeBufferInContainer(Handler)), matcher()); + static bool matches(const Stmt *S, const ASTContext &Ctx, + const UnsafeBufferUsageHandler *Handler, + MatchResult &Result) { + if (ignoreUnsafeBufferInContainer(*S, Handler)) + return false; + return matches(S, Ctx, Result); } void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler, @@ -1291,23 +1477,34 @@ class PointerInitGadget : public FixableGadget { const DeclRefExpr *PtrInitRHS; // the RHS pointer expression in `PI` public: - PointerInitGadget(const MatchFinder::MatchResult &Result) + PointerInitGadget(const MatchResult &Result) : FixableGadget(Kind::PointerInit), - PtrInitLHS(Result.Nodes.getNodeAs<VarDecl>(PointerInitLHSTag)), - PtrInitRHS(Result.Nodes.getNodeAs<DeclRefExpr>(PointerInitRHSTag)) {} + PtrInitLHS(Result.getNodeAs<VarDecl>(PointerInitLHSTag)), + PtrInitRHS(Result.getNodeAs<DeclRefExpr>(PointerInitRHSTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::PointerInit; } - static Matcher matcher() { - auto PtrInitStmt = declStmt(hasSingleDecl( - varDecl(hasInitializer(ignoringImpCasts( - declRefExpr(hasPointerType(), toSupportedVariable()) - .bind(PointerInitRHSTag)))) - .bind(PointerInitLHSTag))); - - return stmt(PtrInitStmt); + static bool matches(const Stmt *S, llvm::SmallVector<MatchResult> &Results) { + const DeclStmt *DS = dyn_cast<DeclStmt>(S); + if (!DS || !DS->isSingleDecl()) + return false; + const VarDecl *VD = dyn_cast<VarDecl>(DS->getSingleDecl()); + if (!VD) + return false; + const Expr *Init = VD->getAnyInitializer(); + if (!Init) + return false; + const auto *DRE = dyn_cast<DeclRefExpr>(Init->IgnoreImpCasts()); + if (!DRE || !hasPointerType(*DRE) || !toSupportedVariable(*DRE)) { + return false; + } + MatchResult R; + R.addNode(PointerInitLHSTag, DynTypedNode::create(*VD)); + R.addNode(PointerInitRHSTag, DynTypedNode::create(*DRE)); + Results.emplace_back(R); + return true; } virtual std::optional<FixItList> @@ -1339,25 +1536,40 @@ class PtrToPtrAssignmentGadget : public FixableGadget { const DeclRefExpr *PtrRHS; // the RHS pointer expression in `PA` public: - PtrToPtrAssignmentGadget(const MatchFinder::MatchResult &Result) + PtrToPtrAssignmentGadget(const MatchResult &Result) : FixableGadget(Kind::PtrToPtrAssignment), - PtrLHS(Result.Nodes.getNodeAs<DeclRefExpr>(PointerAssignLHSTag)), - PtrRHS(Result.Nodes.getNodeAs<DeclRefExpr>(PointerAssignRHSTag)) {} + PtrLHS(Result.getNodeAs<DeclRefExpr>(PointerAssignLHSTag)), + PtrRHS(Result.getNodeAs<DeclRefExpr>(PointerAssignRHSTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::PtrToPtrAssignment; } - static Matcher matcher() { - auto PtrAssignExpr = binaryOperator( - allOf(hasOperatorName("="), - hasRHS(ignoringParenImpCasts( - declRefExpr(hasPointerType(), toSupportedVariable()) - .bind(PointerAssignRHSTag))), - hasLHS(declRefExpr(hasPointerType(), toSupportedVariable()) - .bind(PointerAssignLHSTag)))); - - return stmt(isInUnspecifiedUntypedContext(PtrAssignExpr)); + static bool matches(const Stmt *S, llvm::SmallVector<MatchResult> &Results) { + bool Found = false; + findStmtsInUnspecifiedUntypedContext(S, [&Found, &Results](const Stmt *S) { + const auto *BO = dyn_cast<BinaryOperator>(S); + if (!BO || BO->getOpcode() != BO_Assign) + return; + const auto *RHS = BO->getRHS()->IgnoreParenImpCasts(); + if (const auto *RHSRef = dyn_cast<DeclRefExpr>(RHS); + !RHSRef || !hasPointerType(*RHSRef) || + !toSupportedVariable(*RHSRef)) { + return; + } + const auto *LHS = BO->getLHS(); + if (const auto *LHSRef = dyn_cast<DeclRefExpr>(LHS); + !LHSRef || !hasPointerType(*LHSRef) || + !toSupportedVariable(*LHSRef)) { + return; + } + MatchResult R; + R.addNode(PointerAssignLHSTag, DynTypedNode::create(*LHS)); + R.addNode(PointerAssignRHSTag, DynTypedNode::create(*RHS)); + Results.emplace_back(R); + Found = true; + }); + return Found; } virtual std::optional<FixItList> @@ -1388,26 +1600,41 @@ class CArrayToPtrAssignmentGadget : public FixableGadget { const DeclRefExpr *PtrRHS; // the RHS pointer expression in `PA` public: - CArrayToPtrAssignmentGadget(const MatchFinder::MatchResult &Result) + CArrayToPtrAssignmentGadget(const MatchResult &Result) : FixableGadget(Kind::CArrayToPtrAssignment), - PtrLHS(Result.Nodes.getNodeAs<DeclRefExpr>(PointerAssignLHSTag)), - PtrRHS(Result.Nodes.getNodeAs<DeclRefExpr>(PointerAssignRHSTag)) {} + PtrLHS(Result.getNodeAs<DeclRefExpr>(PointerAssignLHSTag)), + PtrRHS(Result.getNodeAs<DeclRefExpr>(PointerAssignRHSTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::CArrayToPtrAssignment; } - static Matcher matcher() { - auto PtrAssignExpr = binaryOperator( - allOf(hasOperatorName("="), - hasRHS(ignoringParenImpCasts( - declRefExpr(hasType(hasCanonicalType(constantArrayType())), - toSupportedVariable()) - .bind(PointerAssignRHSTag))), - hasLHS(declRefExpr(hasPointerType(), toSupportedVariable()) - .bind(PointerAssignLHSTag)))); - - return stmt(isInUnspecifiedUntypedContext(PtrAssignExpr)); + static bool matches(const Stmt *S, llvm::SmallVector<MatchResult> &Results) { + bool Found = false; + findStmtsInUnspecifiedUntypedContext(S, [&Found, &Results](const Stmt *S) { + const auto *BO = dyn_cast<BinaryOperator>(S); + if (!BO || BO->getOpcode() != BO_Assign) + return; + const auto *RHS = BO->getRHS()->IgnoreParenImpCasts(); + if (const auto *RHSRef = dyn_cast<DeclRefExpr>(RHS); + !RHSRef || + !isa<ConstantArrayType>(RHSRef->getType().getCanonicalType()) || + !toSupportedVariable(*RHSRef)) { + return; + } + const auto *LHS = BO->getLHS(); + if (const auto *LHSRef = dyn_cast<DeclRefExpr>(LHS); + !LHSRef || !hasPointerType(*LHSRef) || + !toSupportedVariable(*LHSRef)) { + return; + } + MatchResult R; + R.addNode(PointerAssignLHSTag, DynTypedNode::create(*LHS)); + R.addNode(PointerAssignRHSTag, DynTypedNode::create(*RHS)); + Results.emplace_back(R); + Found = true; + }); + return Found; } virtual std::optional<FixItList> @@ -1431,23 +1658,32 @@ class UnsafeBufferUsageAttrGadget : public WarningGadget { const Expr *Op; public: - UnsafeBufferUsageAttrGadget(const MatchFinder::MatchResult &Result) + UnsafeBufferUsageAttrGadget(const MatchResult &Result) : WarningGadget(Kind::UnsafeBufferUsageAttr), - Op(Result.Nodes.getNodeAs<Expr>(OpTag)) {} + Op(Result.getNodeAs<Expr>(OpTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::UnsafeBufferUsageAttr; } - static Matcher matcher() { - auto HasUnsafeFieldDecl = - member(fieldDecl(hasAttr(attr::UnsafeBufferUsage))); - - auto HasUnsafeFnDecl = - callee(functionDecl(hasAttr(attr::UnsafeBufferUsage))); - - return stmt(anyOf(callExpr(HasUnsafeFnDecl).bind(OpTag), - memberExpr(HasUnsafeFieldDecl).bind(OpTag))); + static bool matches(const Stmt *S, const ASTContext &Ctx, + MatchResult &Result) { + if (auto *CE = dyn_cast<CallExpr>(S)) { + if (CE->getDirectCallee() && + CE->getDirectCallee()->hasAttr<UnsafeBufferUsageAttr>()) { + Result.addNode(OpTag, DynTypedNode::create(*CE)); + return true; + } + } + if (auto *ME = dyn_cast<MemberExpr>(S)) { + if (!isa<FieldDecl>(ME->getMemberDecl())) + return false; + if (ME->getMemberDecl()->hasAttr<UnsafeBufferUsageAttr>()) { + Result.addNode(OpTag, DynTypedNode::create(*ME)); + return true; + } + } + return false; } void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler, @@ -1460,30 +1696,32 @@ class UnsafeBufferUsageAttrGadget : public WarningGadget { DeclUseList getClaimedVarUseSites() const override { return {}; } }; -/// A call of a constructor that performs unchecked buffer operations -/// over one of its pointer parameters, or constructs a class object that will -/// perform buffer operations that depend on the correctness of the parameters. +// A call of a constructor that performs unchecked buffer operations +// over one of its pointer parameters, or constructs a class object that will +// perform buffer operations that depend on the correctness of the parameters. class UnsafeBufferUsageCtorAttrGadget : public WarningGadget { constexpr static const char *const OpTag = "cxx_construct_expr"; const CXXConstructExpr *Op; public: - UnsafeBufferUsageCtorAttrGadget(const MatchFinder::MatchResult &Result) + UnsafeBufferUsageCtorAttrGadget(const MatchResult &Result) : WarningGadget(Kind::UnsafeBufferUsageCtorAttr), - Op(Result.Nodes.getNodeAs<CXXConstructExpr>(OpTag)) {} + Op(Result.getNodeAs<CXXConstructExpr>(OpTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::UnsafeBufferUsageCtorAttr; } - static Matcher matcher() { - auto HasUnsafeCtorDecl = - hasDeclaration(cxxConstructorDecl(hasAttr(attr::UnsafeBufferUsage))); - // std::span(ptr, size) ctor is handled by SpanTwoParamConstructorGadget. - auto HasTwoParamSpanCtorDecl = SpanTwoParamConstructorGadget::matcher(); - return stmt( - cxxConstructExpr(HasUnsafeCtorDecl, unless(HasTwoParamSpanCtorDecl)) - .bind(OpTag)); + static bool matches(const Stmt *S, const ASTContext &Ctx, + MatchResult &Result) { + const auto *CE = dyn_cast<CXXConstructExpr>(S); + if (!CE || !CE->getConstructor()->hasAttr<UnsafeBufferUsageAttr>()) + return false; + MatchResult tmp; + if (SpanTwoParamConstructorGadget::matches(CE, Ctx, tmp)) + return false; + Result.addNode(OpTag, DynTypedNode::create(*CE)); + return true; } void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler, @@ -1505,23 +1743,51 @@ class DataInvocationGadget : public WarningGadget { const ExplicitCastExpr *Op; public: - DataInvocationGadget(const MatchFinder::MatchResult &Result) + DataInvocationGadget(const MatchResult &Result) : WarningGadget(Kind::DataInvocation), - Op(Result.Nodes.getNodeAs<ExplicitCastExpr>(OpTag)) {} + Op(Result.getNodeAs<ExplicitCastExpr>(OpTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::DataInvocation; } - static Matcher matcher() { + static bool checkCallExpr(const CXXMemberCallExpr *call) { + if (!call) + return false; + auto *callee = call->getDirectCallee(); + if (!callee || !isa<CXXMethodDecl>(callee)) + return false; + auto *method = cast<CXXMethodDecl>(callee); + if (method->getNameAsString() == "data" && + (method->getParent()->getQualifiedNameAsString() == "std::span" || + method->getParent()->getQualifiedNameAsString() == "std::array" || + method->getParent()->getQualifiedNameAsString() == "std::vector")) + return true; + return false; + } - Matcher callExpr = cxxMemberCallExpr(callee( - cxxMethodDecl(hasName("data"), - ofClass(anyOf(hasName("std::span"), hasName("std::array"), - hasName("std::vector")))))); - return stmt( - explicitCastExpr(anyOf(has(callExpr), has(parenExpr(has(callExpr))))) - .bind(OpTag)); + static bool matches(const Stmt *S, const ASTContext &Ctx, + MatchResult &Result) { + auto *CE = dyn_cast<ExplicitCastExpr>(S); + if (!CE) + return false; + for (auto *child : CE->children()) { + if (auto *MCE = dyn_cast<CXXMemberCallExpr>(child); + MCE && checkCallExpr(MCE)) { + Result.addNode(OpTag, DynTypedNode::create(*CE)); + return true; + } + if (auto *paren = dyn_cast<ParenExpr>(child)) { + for (auto *grandchild : paren->children()) { + if (auto *MCE = dyn_cast<CXXMemberCallExpr>(grandchild); + MCE && checkCallExpr(MCE)) { + Result.addNode(OpTag, DynTypedNode::create(*CE)); + return true; + } + } + } + } + return false; } void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler, @@ -1561,56 +1827,70 @@ class UnsafeLibcFunctionCallGadget : public WarningGadget { } WarnedFunKind = OTHERS; public: - UnsafeLibcFunctionCallGadget(const MatchFinder::MatchResult &Result) + UnsafeLibcFunctionCallGadget(const MatchResult &Result) : WarningGadget(Kind::UnsafeLibcFunctionCall), - Call(Result.Nodes.getNodeAs<CallExpr>(Tag)) { - if (Result.Nodes.getNodeAs<Decl>(UnsafeSprintfTag)) + Call(Result.getNodeAs<CallExpr>(Tag)) { + if (Result.getNodeAs<Decl>(UnsafeSprintfTag)) WarnedFunKind = SPRINTF; - else if (auto *E = Result.Nodes.getNodeAs<Expr>(UnsafeStringTag)) { + else if (auto *E = Result.getNodeAs<Expr>(UnsafeStringTag)) { WarnedFunKind = STRING; UnsafeArg = E; - } else if (Result.Nodes.getNodeAs<CallExpr>(UnsafeSizedByTag)) { + } else if (Result.getNodeAs<CallExpr>(UnsafeSizedByTag)) { WarnedFunKind = SIZED_BY; UnsafeArg = Call->getArg(0); - } else if (Result.Nodes.getNodeAs<Decl>(UnsafeVaListTag)) + } else if (Result.getNodeAs<Decl>(UnsafeVaListTag)) WarnedFunKind = VA_LIST; } - static Matcher matcher(const UnsafeBufferUsageHandler *Handler) { - return stmt(unless(ignoreUnsafeLibcCall(Handler)), - anyOf( - callExpr( - callee(functionDecl(anyOf( - // Match a predefined unsafe libc - // function: - functionDecl(libc_func_matchers::isPredefinedUnsafeLibcFunc()), - // Match a call to one of the `v*printf` functions - // taking va-list, which cannot be checked at - // compile-time: - functionDecl(libc_func_matchers::isUnsafeVaListPrintfFunc()) - .bind(UnsafeVaListTag), - // Match a call to a `sprintf` function, which is never - // safe: - functionDecl(libc_func_matchers::isUnsafeSprintfFunc()) - .bind(UnsafeSprintfTag)))), - // (unless the call has a sole string literal argument): - unless( - allOf(hasArgument(0, expr(stringLiteral())), hasNumArgs(1)))), - - // The following two cases require checking against actual - // arguments of the call: - - // Match a call to an `snprintf` function. And first two - // arguments of the call (that describe a buffer) are not in - // safe patterns: - callExpr(callee(functionDecl(libc_func_matchers::isNormalPrintfFunc())), - libc_func_matchers::hasUnsafeSnprintfBuffer()) - .bind(UnsafeSizedByTag), - // Match a call to a `printf` function, which can be safe if - // all arguments are null-terminated: - callExpr(callee(functionDecl(libc_func_matchers::isNormalPrintfFunc())), - libc_func_matchers::hasUnsafePrintfStringArg( - expr().bind(UnsafeStringTag))))); + static bool matches(const Stmt *S, ASTContext &Ctx, + const UnsafeBufferUsageHandler *Handler, + MatchResult &Result) { + if (ignoreUnsafeLibcCall(Ctx, *S, Handler)) + return false; + auto *CE = dyn_cast<CallExpr>(S); + if (!CE || !CE->getDirectCallee()) + return false; + const auto *FD = dyn_cast<FunctionDecl>(CE->getDirectCallee()); + if (!FD) + return false; + auto isSingleStringLiteralArg = false; + if (CE->getNumArgs() == 1) { + const auto *const Arg = CE->getArg(0); + if (isa<Expr>(Arg) && !Arg->children().empty()) { + isSingleStringLiteralArg = + isa<clang::StringLiteral>(*Arg->children().begin()); + } + } + if (!isSingleStringLiteralArg) { // (unless the call has a sole string + // literal argument): + if (libc_func_matchers::isPredefinedUnsafeLibcFunc(*FD)) { + Result.addNode(Tag, DynTypedNode::create(*CE)); + return true; + } + if (libc_func_matchers::isUnsafeVaListPrintfFunc(*FD)) { + Result.addNode(Tag, DynTypedNode::create(*CE)); + Result.addNode(UnsafeVaListTag, DynTypedNode::create(*FD)); + return true; + } + if (libc_func_matchers::isUnsafeSprintfFunc(*FD)) { + Result.addNode(Tag, DynTypedNode::create(*CE)); + Result.addNode(UnsafeSprintfTag, DynTypedNode::create(*FD)); + return true; + } + } + if (libc_func_matchers::isNormalPrintfFunc(*FD)) { + if (libc_func_matchers::hasUnsafeSnprintfBuffer(*CE, Ctx)) { + Result.addNode(Tag, DynTypedNode::create(*CE)); + Result.addNode(UnsafeSizedByTag, DynTypedNode::create(*CE)); + return true; + } + if (libc_func_matchers::hasUnsafePrintfStringArg(*CE, Ctx, Result, + UnsafeStringTag)) { + Result.addNode(Tag, DynTypedNode::create(*CE)); + return true; + } + } + return false; } const Stmt *getBaseStmt() const { return Call; } @@ -1627,7 +1907,7 @@ class UnsafeLibcFunctionCallGadget : public WarningGadget { }; // Represents expressions of the form `DRE[*]` in the Unspecified Lvalue -// Context (see `isInUnspecifiedLvalueContext`). +// Context (see `findStmtsInUnspecifiedLvalueContext`). // Note here `[]` is the built-in subscript operator. class ULCArraySubscriptGadget : public FixableGadget { private: @@ -1636,9 +1916,9 @@ class ULCArraySubscriptGadget : public FixableGadget { const ArraySubscriptExpr *Node; public: - ULCArraySubscriptGadget(const MatchFinder::MatchResult &Result) + ULCArraySubscriptGadget(const MatchResult &Result) : FixableGadget(Kind::ULCArraySubscript), - Node(Result.Nodes.getNodeAs<ArraySubscriptExpr>(ULCArraySubscriptTag)) { + Node(Result.getNodeAs<ArraySubscriptExpr>(ULCArraySubscriptTag)) { assert(Node != nullptr && "Expecting a non-null matching result"); } @@ -1646,14 +1926,23 @@ class ULCArraySubscriptGadget : public FixableGadget { return G->getKind() == Kind::ULCArraySubscript; } - static Matcher matcher() { - auto ArrayOrPtr = anyOf(hasPointerType(), hasArrayType()); - auto BaseIsArrayOrPtrDRE = hasBase( - ignoringParenImpCasts(declRefExpr(ArrayOrPtr, toSupportedVariable()))); - auto Target = - arraySubscriptExpr(BaseIsArrayOrPtrDRE).bind(ULCArraySubscriptTag); - - return expr(isInUnspecifiedLvalueContext(Target)); + static bool matches(const Stmt *S, llvm::SmallVector<MatchResult> &Results) { + bool Found = false; + findStmtsInUnspecifiedLvalueContext(S, [&Found, &Results](const Expr *E) { + const auto *ASE = dyn_cast<ArraySubscriptExpr>(E); + if (!ASE) + return; + const auto *DRE = + dyn_cast<DeclRefExpr>(ASE->getBase()->IgnoreParenImpCasts()); + if (!DRE || (!hasPointerType(*DRE) && !hasArrayType(*DRE)) || + !toSupportedVariable(*DRE)) + return; + MatchResult R; + R.addNode(ULCArraySubscriptTag, DynTypedNode::create(*ASE)); + Results.emplace_back(R); + Found = true; + }); + return Found; } virtual std::optional<FixItList> @@ -1670,17 +1959,17 @@ class ULCArraySubscriptGadget : public FixableGadget { }; // Fixable gadget to handle stand alone pointers of the form `UPC(DRE)` in the -// unspecified pointer context (isInUnspecifiedPointerContext). The gadget emits -// fixit of the form `UPC(DRE.data())`. +// unspecified pointer context (findStmtsInUnspecifiedPointerContext). The +// gadget emits fixit of the form `UPC(DRE.data())`. class UPCStandalonePointerGadget : public FixableGadget { private: static constexpr const char *const DeclRefExprTag = "StandalonePointer"; const DeclRefExpr *Node; public: - UPCStandalonePointerGadget(const MatchFinder::MatchResult &Result) + UPCStandalonePointerGadget(const MatchResult &Result) : FixableGadget(Kind::UPCStandalonePointer), - Node(Result.Nodes.getNodeAs<DeclRefExpr>(DeclRefExprTag)) { + Node(Result.getNodeAs<DeclRefExpr>(DeclRefExprTag)) { assert(Node != nullptr && "Expecting a non-null matching result"); } @@ -1688,12 +1977,22 @@ class UPCStandalonePointerGadget : public FixableGadget { return G->getKind() == Kind::UPCStandalonePointer; } - static Matcher matcher() { - auto ArrayOrPtr = anyOf(hasPointerType(), hasArrayType()); - auto target = expr(ignoringParenImpCasts( - declRefExpr(allOf(ArrayOrPtr, toSupportedVariable())) - .bind(DeclRefExprTag))); - return stmt(isInUnspecifiedPointerContext(target)); + static bool matches(const Stmt *S, llvm::SmallVector<MatchResult> &Results) { + bool Found = false; + findStmtsInUnspecifiedPointerContext(S, [&Found, &Results](const Stmt *S) { + auto *E = dyn_cast<Expr>(S); + if (!E) + return; + const auto *DRE = dyn_cast<DeclRefExpr>(E->IgnoreParenImpCasts()); + if (!DRE || (!hasPointerType(*DRE) && !hasArrayType(*DRE)) || + !toSupportedVariable(*DRE)) + return; + MatchResult R; + R.addNode(DeclRefExprTag, DynTypedNode::create(*DRE)); + Results.emplace_back(R); + Found = true; + }); + return Found; } virtual std::optional<FixItList> @@ -1711,25 +2010,38 @@ class PointerDereferenceGadget : public FixableGadget { const UnaryOperator *Op = nullptr; public: - PointerDereferenceGadget(const MatchFinder::MatchResult &Result) + PointerDereferenceGadget(const MatchResult &Result) : FixableGadget(Kind::PointerDereference), - BaseDeclRefExpr( - Result.Nodes.getNodeAs<DeclRefExpr>(BaseDeclRefExprTag)), - Op(Result.Nodes.getNodeAs<UnaryOperator>(OperatorTag)) {} + BaseDeclRefExpr(Result.getNodeAs<DeclRefExpr>(BaseDeclRefExprTag)), + Op(Result.getNodeAs<UnaryOperator>(OperatorTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::PointerDereference; } - static Matcher matcher() { - auto Target = - unaryOperator( - hasOperatorName("*"), - has(expr(ignoringParenImpCasts( - declRefExpr(toSupportedVariable()).bind(BaseDeclRefExprTag))))) - .bind(OperatorTag); - - return expr(isInUnspecifiedLvalueContext(Target)); + static bool matches(const Stmt *S, llvm::SmallVector<MatchResult> &Results) { + bool Found = false; + findStmtsInUnspecifiedLvalueContext(S, [&Found, &Results](const Stmt *S) { + const auto *UO = dyn_cast<UnaryOperator>(S); + if (!UO || UO->getOpcode() != UO_Deref) + return; + for (const auto *Child : UO->children()) { + const auto *CE = dyn_cast<Expr>(Child); + if (!CE) + continue; + CE = CE->IgnoreParenImpCasts(); + const auto *DRE = dyn_cast<DeclRefExpr>(CE); + if (!DRE || !toSupportedVariable(*DRE)) + continue; + MatchResult R; + R.addNode(BaseDeclRefExprTag, DynTypedNode::create(*DRE)); + R.addNode(OperatorTag, DynTypedNode::create(*UO)); + Results.emplace_back(R); + Found = true; + return; + } + }); + return Found; } DeclUseList getClaimedVarUseSites() const override { @@ -1742,7 +2054,7 @@ class PointerDereferenceGadget : public FixableGadget { }; // Represents expressions of the form `&DRE[any]` in the Unspecified Pointer -// Context (see `isInUnspecifiedPointerContext`). +// Context (see `findStmtsInUnspecifiedPointerContext`). // Note here `[]` is the built-in subscript operator. class UPCAddressofArraySubscriptGadget : public FixableGadget { private: @@ -1751,10 +2063,9 @@ class UPCAddressofArraySubscriptGadget : public FixableGadget { const UnaryOperator *Node; // the `&DRE[any]` node public: - UPCAddressofArraySubscriptGadget(const MatchFinder::MatchResult &Result) + UPCAddressofArraySubscriptGadget(const MatchResult &Result) : FixableGadget(Kind::ULCArraySubscript), - Node(Result.Nodes.getNodeAs<UnaryOperator>( - UPCAddressofArraySubscriptTag)) { + Node(Result.getNodeAs<UnaryOperator>(UPCAddressofArraySubscriptTag)) { assert(Node != nullptr && "Expecting a non-null matching result"); } @@ -1762,13 +2073,28 @@ class UPCAddressofArraySubscriptGadget : public FixableGadget { return G->getKind() == Kind::UPCAddressofArraySubscript; } - static Matcher matcher() { - return expr(isInUnspecifiedPointerContext(expr(ignoringImpCasts( - unaryOperator( - hasOperatorName("&"), - hasUnaryOperand(arraySubscriptExpr(hasBase( - ignoringParenImpCasts(declRefExpr(toSupportedVariable())))))) - .bind(UPCAddressofArraySubscriptTag))))); + static bool matches(const Stmt *S, llvm::SmallVector<MatchResult> &Results) { + bool Found = false; + findStmtsInUnspecifiedPointerContext(S, [&Found, &Results](const Stmt *S) { + auto *E = dyn_cast<Expr>(S); + if (!E) + return; + const auto *UO = dyn_cast<UnaryOperator>(E->IgnoreImpCasts()); + if (!UO || UO->getOpcode() != UO_AddrOf) + return; + const auto *ASE = dyn_cast<ArraySubscriptExpr>(UO->getSubExpr()); + if (!ASE) + return; + const auto *DRE = + dyn_cast<DeclRefExpr>(ASE->getBase()->IgnoreParenImpCasts()); + if (!DRE || !toSupportedVariable(*DRE)) + return; + MatchResult R; + R.addNode(UPCAddressofArraySubscriptTag, DynTypedNode::create(*UO)); + Results.emplace_back(R); + Found = true; + }); + return Found; } virtual std::optional<FixItList> @@ -1859,9 +2185,9 @@ class UPCPreIncrementGadget : public FixableGadget { const UnaryOperator *Node; // the `++Ptr` node public: - UPCPreIncrementGadget(const MatchFinder::MatchResult &Result) + UPCPreIncrementGadget(const MatchResult &Result) : FixableGadget(Kind::UPCPreIncrement), - Node(Result.Nodes.getNodeAs<UnaryOperator>(UPCPreIncrementTag)) { + Node(Result.getNodeAs<UnaryOperator>(UPCPreIncrementTag)) { assert(Node != nullptr && "Expecting a non-null matching result"); } @@ -1869,15 +2195,28 @@ class UPCPreIncrementGadget : public FixableGadget { return G->getKind() == Kind::UPCPreIncrement; } - static Matcher matcher() { + static bool matches(const Stmt *S, llvm::SmallVector<MatchResult> &Results) { // Note here we match `++Ptr` for any expression `Ptr` of pointer type. // Although currently we can only provide fix-its when `Ptr` is a DRE, we // can have the matcher be general, so long as `getClaimedVarUseSites` does // things right. - return stmt(isInUnspecifiedPointerContext(expr(ignoringImpCasts( - unaryOperator(isPreInc(), - hasUnaryOperand(declRefExpr(toSupportedVariable()))) - .bind(UPCPreIncrementTag))))); + bool Found = false; + findStmtsInUnspecifiedPointerContext(S, [&Found, &Results](const Stmt *S) { + auto *E = dyn_cast<Expr>(S); + if (!E) + return; + const auto *UO = dyn_cast<UnaryOperator>(E->IgnoreImpCasts()); + if (!UO || UO->getOpcode() != UO_PreInc) + return; + const auto *DRE = dyn_cast<DeclRefExpr>(UO->getSubExpr()); + if (!DRE || !toSupportedVariable(*DRE)) + return; + MatchResult R; + R.addNode(UPCPreIncrementTag, DynTypedNode::create(*UO)); + Results.emplace_back(R); + Found = true; + }); + return Found; } virtual std::optional<FixItList> @@ -1901,10 +2240,10 @@ class UUCAddAssignGadget : public FixableGadget { const Expr *Offset = nullptr; public: - UUCAddAssignGadget(const MatchFinder::MatchResult &Result) + UUCAddAssignGadget(const MatchResult &Result) : FixableGadget(Kind::UUCAddAssign), - Node(Result.Nodes.getNodeAs<BinaryOperator>(UUCAddAssignTag)), - Offset(Result.Nodes.getNodeAs<Expr>(OffsetTag)) { + Node(Result.getNodeAs<BinaryOperator>(UUCAddAssignTag)), + Offset(Result.getNodeAs<Expr>(OffsetTag)) { assert(Node != nullptr && "Expecting a non-null matching result"); } @@ -1912,17 +2251,26 @@ class UUCAddAssignGadget : public FixableGadget { return G->getKind() == Kind::UUCAddAssign; } - static Matcher matcher() { - // clang-format off - return stmt(isInUnspecifiedUntypedContext(expr(ignoringImpCasts( - binaryOperator(hasOperatorName("+="), - hasLHS( - declRefExpr( - hasPointerType(), - toSupportedVariable())), - hasRHS(expr().bind(OffsetTag))) - .bind(UUCAddAssignTag))))); - // clang-format on + static bool matches(const Stmt *S, llvm::SmallVector<MatchResult> &Results) { + bool Found = false; + findStmtsInUnspecifiedUntypedContext(S, [&Found, &Results](const Stmt *S) { + const auto *E = dyn_cast<Expr>(S); + if (!E) + return; + const auto *BO = dyn_cast<BinaryOperator>(E->IgnoreImpCasts()); + if (!BO || BO->getOpcode() != BO_AddAssign) + return; + const auto *DRE = dyn_cast<DeclRefExpr>(BO->getLHS()); + if (!DRE || !hasPointerType(*DRE) || !toSupportedVariable(*DRE) || + !isa<Expr>(BO->getRHS())) + return; + MatchResult R; + R.addNode(UUCAddAssignTag, DynTypedNode::create(*BO)); + R.addNode(OffsetTag, DynTypedNode::create(*BO->getRHS())); + Results.emplace_back(R); + Found = true; + }); + return Found; } virtual std::optional<FixItList> @@ -1948,31 +2296,59 @@ class DerefSimplePtrArithFixableGadget : public FixableGadget { const IntegerLiteral *Offset = nullptr; public: - DerefSimplePtrArithFixableGadget(const MatchFinder::MatchResult &Result) + DerefSimplePtrArithFixableGadget(const MatchResult &Result) : FixableGadget(Kind::DerefSimplePtrArithFixable), - BaseDeclRefExpr( - Result.Nodes.getNodeAs<DeclRefExpr>(BaseDeclRefExprTag)), - DerefOp(Result.Nodes.getNodeAs<UnaryOperator>(DerefOpTag)), - AddOp(Result.Nodes.getNodeAs<BinaryOperator>(AddOpTag)), - Offset(Result.Nodes.getNodeAs<IntegerLiteral>(OffsetTag)) {} - - static Matcher matcher() { - // clang-format off - auto ThePtr = expr(hasPointerType(), - ignoringImpCasts(declRefExpr(toSupportedVariable()). - bind(BaseDeclRefExprTag))); - auto PlusOverPtrAndInteger = expr(anyOf( - binaryOperator(hasOperatorName("+"), hasLHS(ThePtr), - hasRHS(integerLiteral().bind(OffsetTag))) - .bind(AddOpTag), - binaryOperator(hasOperatorName("+"), hasRHS(ThePtr), - hasLHS(integerLiteral().bind(OffsetTag))) - .bind(AddOpTag))); - return isInUnspecifiedLvalueContext(unaryOperator( - hasOperatorName("*"), - hasUnaryOperand(ignoringParens(PlusOverPtrAndInteger))) - .bind(DerefOpTag)); - // clang-format on + BaseDeclRefExpr(Result.getNodeAs<DeclRefExpr>(BaseDeclRefExprTag)), + DerefOp(Result.getNodeAs<UnaryOperator>(DerefOpTag)), + AddOp(Result.getNodeAs<BinaryOperator>(AddOpTag)), + Offset(Result.getNodeAs<IntegerLiteral>(OffsetTag)) {} + + static bool matches(const Stmt *S, llvm::SmallVector<MatchResult> &Results) { + bool Found = false; + auto IsPtr = [](const Expr *E, MatchResult &R) { + if (!E || !hasPointerType(*E)) + return false; + const auto *DRE = dyn_cast<DeclRefExpr>(E->IgnoreImpCasts()); + if (!DRE || !toSupportedVariable(*DRE)) + return false; + R.addNode(BaseDeclRefExprTag, DynTypedNode::create(*DRE)); + return true; + }; + const auto PlusOverPtrAndInteger = [&IsPtr](const Expr *E, MatchResult &R) { + const auto *BO = dyn_cast<BinaryOperator>(E); + if (!BO || BO->getOpcode() != BO_Add) + return false; + + const auto *LHS = BO->getLHS(); + const auto *RHS = BO->getRHS(); + if (isa<IntegerLiteral>(RHS) && IsPtr(LHS, R)) { + R.addNode(OffsetTag, DynTypedNode::create(*RHS)); + R.addNode(AddOpTag, DynTypedNode::create(*BO)); + return true; + } + if (isa<IntegerLiteral>(LHS) && IsPtr(RHS, R)) { + R.addNode(OffsetTag, DynTypedNode::create(*LHS)); + R.addNode(AddOpTag, DynTypedNode::create(*BO)); + return true; + } + return false; + }; + const auto InnerMatcher = [&PlusOverPtrAndInteger, &Found, + &Results](const Expr *E) { + const auto *UO = dyn_cast<UnaryOperator>(E); + if (!UO || UO->getOpcode() != UO_Deref) + return; + + const auto *Operand = UO->getSubExpr()->IgnoreParens(); + MatchResult R; + if (PlusOverPtrAndInteger(Operand, R)) { + R.addNode(DerefOpTag, DynTypedNode::create(*UO)); + Results.emplace_back(R); + Found = true; + } + }; + findStmtsInUnspecifiedLvalueContext(S, InnerMatcher); + return Found; } virtual std::optional<FixItList> @@ -1986,112 +2362,119 @@ class DerefSimplePtrArithFixableGadget : public FixableGadget { } }; -/// Scan the function and return a list of gadgets found with provided kits. -static void findGadgets(const Stmt *S, ASTContext &Ctx, - const UnsafeBufferUsageHandler &Handler, - bool EmitSuggestions, FixableGadgetList &FixableGadgets, - WarningGadgetList &WarningGadgets, - DeclUseTracker &Tracker) { +class EvaluatedStmtMatcher : public FastMatcher { - struct GadgetFinderCallback : MatchFinder::MatchCallback { - GadgetFinderCallback(FixableGadgetList &FixableGadgets, - WarningGadgetList &WarningGadgets, - DeclUseTracker &Tracker) - : FixableGadgets(FixableGadgets), WarningGadgets(WarningGadgets), - Tracker(Tracker) {} - - void run(const MatchFinder::MatchResult &Result) override { - // In debug mode, assert that we've found exactly one gadget. - // This helps us avoid conflicts in .bind() tags. -#if NDEBUG -#define NEXT return -#else - [[maybe_unused]] int numFound = 0; -#define NEXT ++numFound -#endif - - if (const auto *DRE = Result.Nodes.getNodeAs<DeclRefExpr>("any_dre")) { - Tracker.discoverUse(DRE); - NEXT; - } +public: + EvaluatedStmtMatcher(WarningGadgetList &WarningGadgets) + : WarningGadgets(WarningGadgets) {} - if (const auto *DS = Result.Nodes.getNodeAs<DeclStmt>("any_ds")) { - Tracker.discoverDecl(DS); - NEXT; - } + bool matches(const DynTypedNode &DynNode, ASTContext &Ctx, + const UnsafeBufferUsageHandler &Handler) override { + const Stmt *S = DynNode.get<Stmt>(); + if (!S) + return false; - // Figure out which matcher we've found, and call the appropriate - // subclass constructor. - // FIXME: Can we do this more logarithmically? -#define FIXABLE_GADGET(name) \ - if (Result.Nodes.getNodeAs<Stmt>(#name)) { \ - FixableGadgets.push_back(std::make_unique<name##Gadget>(Result)); \ - NEXT; \ - } -#include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def" + MatchResult Result; #define WARNING_GADGET(name) \ - if (Result.Nodes.getNodeAs<Stmt>(#name)) { \ + if (name##Gadget::matches(S, Ctx, Result) && \ + notInSafeBufferOptOut(*S, &Handler)) { \ + WarningGadgets.push_back(std::make_unique<name##Gadget>(Result)); \ + return true; \ + } +#define WARNING_OPTIONAL_GADGET(name) \ + if (name##Gadget::matches(S, Ctx, &Handler, Result) && \ + notInSafeBufferOptOut(*S, &Handler)) { \ WarningGadgets.push_back(std::make_unique<name##Gadget>(Result)); \ - NEXT; \ + return true; \ } #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def" + return false; + } - assert(numFound >= 1 && "Gadgets not found in match result!"); - assert(numFound <= 1 && "Conflicting bind tags in gadgets!"); - } +private: + WarningGadgetList &WarningGadgets; +}; - FixableGadgetList &FixableGadgets; - WarningGadgetList &WarningGadgets; - DeclUseTracker &Tracker; - }; +class StmtMatcher : public FastMatcher { - MatchFinder M; - GadgetFinderCallback CB{FixableGadgets, WarningGadgets, Tracker}; - - // clang-format off - M.addMatcher( - stmt( - forEachDescendantEvaluatedStmt(stmt(anyOf( - // Add Gadget::matcher() for every gadget in the registry. -#define WARNING_GADGET(x) \ - allOf(x ## Gadget::matcher().bind(#x), \ - notInSafeBufferOptOut(&Handler)), -#define WARNING_OPTIONAL_GADGET(x) \ - allOf(x ## Gadget::matcher(&Handler).bind(#x), \ - notInSafeBufferOptOut(&Handler)), -#include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def" - // Avoid a hanging comma. - unless(stmt()) - ))) - ), - &CB - ); - // clang-format on +public: + StmtMatcher(FixableGadgetList &FixableGadgets, DeclUseTracker &Tracker) + : FixableGadgets(FixableGadgets), Tracker(Tracker) {} + + // Match all DeclRefExprs so that to find out + // whether there are any uncovered by gadgets. + bool matchDeclRefExprs(const Stmt *S, MatchResult &Result) { + const auto *DRE = dyn_cast<DeclRefExpr>(S); + if (!DRE || (!hasPointerType(*DRE) && !hasArrayType(*DRE))) + return false; + const Decl *D = DRE->getDecl(); + if (!D || (!isa<VarDecl>(D) && !isa<BindingDecl>(D))) + return false; + Result.addNode("any_dre", DynTypedNode::create(*DRE)); + return true; + } - if (EmitSuggestions) { - // clang-format off - M.addMatcher( - stmt( - forEachDescendantStmt(stmt(eachOf( -#define FIXABLE_GADGET(x) \ - x ## Gadget::matcher().bind(#x), + // Also match DeclStmts because we'll need them when fixing + // their underlying VarDecls that otherwise don't have + // any backreferences to DeclStmts. + bool matchDeclStmt(const Stmt *S, MatchResult &Result) { + const auto *DS = dyn_cast<DeclStmt>(S); + if (!DS) + return false; + Result.addNode("any_ds", DynTypedNode::create(*DS)); + return true; + } + + bool matches(const DynTypedNode &DynNode, ASTContext &Ctx, + const UnsafeBufferUsageHandler &Handler) override { + bool matchFound = false; + const Stmt *S = DynNode.get<Stmt>(); + if (!S) { + return matchFound; + } + + llvm::SmallVector<MatchResult> Results; +#define FIXABLE_GADGET(name) \ + if (name##Gadget::matches(S, Results)) { \ + for (const auto &R : Results) { \ + FixableGadgets.push_back(std::make_unique<name##Gadget>(R)); \ + matchFound = true; \ + } \ + Results = {}; \ + } #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def" - // In parallel, match all DeclRefExprs so that to find out - // whether there are any uncovered by gadgets. - declRefExpr(anyOf(hasPointerType(), hasArrayType()), - to(anyOf(varDecl(), bindingDecl()))).bind("any_dre"), - // Also match DeclStmts because we'll need them when fixing - // their underlying VarDecls that otherwise don't have - // any backreferences to DeclStmts. - declStmt().bind("any_ds") - ))) - ), - &CB - ); - // clang-format on + + MatchResult Result; + if (matchDeclRefExprs(S, Result)) { + const auto *DRE = Result.getNodeAs<DeclRefExpr>("any_dre"); + Tracker.discoverUse(DRE); + matchFound = true; + } + if (matchDeclStmt(S, Result)) { + const auto *DS = Result.getNodeAs<DeclStmt>("any_ds"); + Tracker.discoverDecl(DS); + matchFound = true; + } + return matchFound; } - M.match(*S, Ctx); +private: + FixableGadgetList &FixableGadgets; + DeclUseTracker &Tracker; +}; + +// Scan the function and return a list of gadgets found with provided kits. +static void findGadgets(const Stmt *S, ASTContext &Ctx, + const UnsafeBufferUsageHandler &Handler, + bool EmitSuggestions, FixableGadgetList &FixableGadgets, + WarningGadgetList &WarningGadgets, + DeclUseTracker &Tracker) { + EvaluatedStmtMatcher ESMatcher{WarningGadgets}; + forEachDescendantEvaluatedStmt(S, Ctx, Handler, ESMatcher); + if (EmitSuggestions) { + StmtMatcher SMatcher{FixableGadgets, Tracker}; + forEachDescendantStmt(S, Ctx, Handler, SMatcher); + } } // Compares AST nodes by source locations. @@ -2672,7 +3055,7 @@ static inline std::optional<FixItList> createDataFixit(const ASTContext &Ctx, // `DRE.data()` std::optional<FixItList> UPCStandalonePointerGadget::getFixits(const FixitStrategy &S) const { - const auto VD = cast<VarDecl>(Node->getDecl()); + const auto *const VD = cast<VarDecl>(Node->getDecl()); switch (S.lookup(VD)) { case FixitStrategy::Kind::Array: case FixitStrategy::Kind::Span: { @@ -3636,9 +4019,11 @@ class VariableGroupsManagerImpl : public VariableGroupsManager { } }; -void applyGadgets(const Decl *D, FixableGadgetList FixableGadgets, - WarningGadgetList WarningGadgets, DeclUseTracker Tracker, - UnsafeBufferUsageHandler &Handler, bool EmitSuggestions) { +static void applyGadgets(const Decl *D, FixableGadgetList FixableGadgets, + WarningGadgetList WarningGadgets, + DeclUseTracker Tracker, + UnsafeBufferUsageHandler &Handler, + bool EmitSuggestions) { if (!EmitSuggestions) { // Our job is very easy without suggestions. Just warn about // every problematic operation and consider it done. No need to deal @@ -3650,7 +4035,7 @@ void applyGadgets(const Decl *D, FixableGadgetList FixableGadgets, // This return guarantees that most of the machine doesn't run when // suggestions aren't requested. - assert(FixableGadgets.size() == 0 && + assert(FixableGadgets.empty() && "Fixable gadgets found but suggestions not requested!"); return; } @@ -3749,7 +4134,7 @@ void applyGadgets(const Decl *D, FixableGadgetList FixableGadgets, DepMapTy DependenciesMap{}; DepMapTy PtrAssignmentGraph{}; - for (auto it : FixablesForAllVars.byVar) { + for (const auto &it : FixablesForAllVars.byVar) { for (const FixableGadget *fixable : it.second) { std::optional<std::pair<const VarDecl *, const VarDecl *>> ImplPair = fixable->getStrategyImplications(); >From 0475beb28231bf182e4d0f2d1f8ec7c8e872f4b4 Mon Sep 17 00:00:00 2001 From: Ivana Ivanovska <iivanov...@google.com> Date: Thu, 6 Feb 2025 12:05:53 +0000 Subject: [PATCH 2/2] Apply review comments --- clang/lib/Analysis/UnsafeBufferUsage.cpp | 44 +++++++++++------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/clang/lib/Analysis/UnsafeBufferUsage.cpp b/clang/lib/Analysis/UnsafeBufferUsage.cpp index 4520d28d9e94522..9d8ee96532e70db 100644 --- a/clang/lib/Analysis/UnsafeBufferUsage.cpp +++ b/clang/lib/Analysis/UnsafeBufferUsage.cpp @@ -95,16 +95,14 @@ class MatchResult { public: template <typename T> const T *getNodeAs(StringRef ID) const { - auto It = Nodes.find(std::string(ID)); + auto It = Nodes.find(ID); if (It == Nodes.end()) { return nullptr; } return It->second.get<T>(); } - void addNode(StringRef ID, const DynTypedNode &Node) { - Nodes[std::string(ID)] = Node; - } + void addNode(StringRef ID, const DynTypedNode &Node) { Nodes[ID] = Node; } private: llvm::StringMap<DynTypedNode> Nodes; @@ -257,8 +255,8 @@ static void forEachDescendantEvaluatedStmt(const Stmt *S, ASTContext &Ctx, const UnsafeBufferUsageHandler &Handler, FastMatcher &Matcher) { - MatchDescendantVisitor Visitor(Matcher, /* FindAll */ true, - /*ignoreUnevaluatedContext*/ true); + MatchDescendantVisitor Visitor(Matcher, /*FindAll=*/true, + /*ignoreUnevaluatedContext=*/true); Visitor.setASTContext(Ctx); Visitor.setHandler(Handler); Visitor.findMatch(DynTypedNode::create(*S)); @@ -267,8 +265,8 @@ forEachDescendantEvaluatedStmt(const Stmt *S, ASTContext &Ctx, static void forEachDescendantStmt(const Stmt *S, ASTContext &Ctx, const UnsafeBufferUsageHandler &Handler, FastMatcher &Matcher) { - MatchDescendantVisitor Visitor(Matcher, /* FindAll */ true, - /*ignoreUnevaluatedContext*/ false); + MatchDescendantVisitor Visitor(Matcher, /*FindAll=*/true, + /*ignoreUnevaluatedContext=*/false); Visitor.setASTContext(Ctx); Visitor.setHandler(Handler); Visitor.findMatch(DynTypedNode::create(*S)); @@ -1197,7 +1195,7 @@ class FixableGadget : public Gadget { } }; -static auto toSupportedVariable(const DeclRefExpr &Node) { +static bool isSupportedVariable(const DeclRefExpr &Node) { const Decl *D = Node.getDecl(); return D != nullptr && isa<VarDecl>(D); } @@ -1357,7 +1355,7 @@ class PointerArithmeticGadget : public WarningGadget { public: PointerArithmeticGadget(const MatchResult &Result) : WarningGadget(Kind::PointerArithmetic), - PA((Result.getNodeAs<BinaryOperator>(PointerArithmeticTag))), + PA(Result.getNodeAs<BinaryOperator>(PointerArithmeticTag)), Ptr(Result.getNodeAs<Expr>(PointerArithmeticPointerTag)) {} static bool classof(const Gadget *G) { @@ -1497,7 +1495,7 @@ class PointerInitGadget : public FixableGadget { if (!Init) return false; const auto *DRE = dyn_cast<DeclRefExpr>(Init->IgnoreImpCasts()); - if (!DRE || !hasPointerType(*DRE) || !toSupportedVariable(*DRE)) { + if (!DRE || !hasPointerType(*DRE) || !isSupportedVariable(*DRE)) { return false; } MatchResult R; @@ -1554,13 +1552,13 @@ class PtrToPtrAssignmentGadget : public FixableGadget { const auto *RHS = BO->getRHS()->IgnoreParenImpCasts(); if (const auto *RHSRef = dyn_cast<DeclRefExpr>(RHS); !RHSRef || !hasPointerType(*RHSRef) || - !toSupportedVariable(*RHSRef)) { + !isSupportedVariable(*RHSRef)) { return; } const auto *LHS = BO->getLHS(); if (const auto *LHSRef = dyn_cast<DeclRefExpr>(LHS); !LHSRef || !hasPointerType(*LHSRef) || - !toSupportedVariable(*LHSRef)) { + !isSupportedVariable(*LHSRef)) { return; } MatchResult R; @@ -1619,13 +1617,13 @@ class CArrayToPtrAssignmentGadget : public FixableGadget { if (const auto *RHSRef = dyn_cast<DeclRefExpr>(RHS); !RHSRef || !isa<ConstantArrayType>(RHSRef->getType().getCanonicalType()) || - !toSupportedVariable(*RHSRef)) { + !isSupportedVariable(*RHSRef)) { return; } const auto *LHS = BO->getLHS(); if (const auto *LHSRef = dyn_cast<DeclRefExpr>(LHS); !LHSRef || !hasPointerType(*LHSRef) || - !toSupportedVariable(*LHSRef)) { + !isSupportedVariable(*LHSRef)) { return; } MatchResult R; @@ -1935,7 +1933,7 @@ class ULCArraySubscriptGadget : public FixableGadget { const auto *DRE = dyn_cast<DeclRefExpr>(ASE->getBase()->IgnoreParenImpCasts()); if (!DRE || (!hasPointerType(*DRE) && !hasArrayType(*DRE)) || - !toSupportedVariable(*DRE)) + !isSupportedVariable(*DRE)) return; MatchResult R; R.addNode(ULCArraySubscriptTag, DynTypedNode::create(*ASE)); @@ -1985,7 +1983,7 @@ class UPCStandalonePointerGadget : public FixableGadget { return; const auto *DRE = dyn_cast<DeclRefExpr>(E->IgnoreParenImpCasts()); if (!DRE || (!hasPointerType(*DRE) && !hasArrayType(*DRE)) || - !toSupportedVariable(*DRE)) + !isSupportedVariable(*DRE)) return; MatchResult R; R.addNode(DeclRefExprTag, DynTypedNode::create(*DRE)); @@ -2031,7 +2029,7 @@ class PointerDereferenceGadget : public FixableGadget { continue; CE = CE->IgnoreParenImpCasts(); const auto *DRE = dyn_cast<DeclRefExpr>(CE); - if (!DRE || !toSupportedVariable(*DRE)) + if (!DRE || !isSupportedVariable(*DRE)) continue; MatchResult R; R.addNode(BaseDeclRefExprTag, DynTypedNode::create(*DRE)); @@ -2087,7 +2085,7 @@ class UPCAddressofArraySubscriptGadget : public FixableGadget { return; const auto *DRE = dyn_cast<DeclRefExpr>(ASE->getBase()->IgnoreParenImpCasts()); - if (!DRE || !toSupportedVariable(*DRE)) + if (!DRE || !isSupportedVariable(*DRE)) return; MatchResult R; R.addNode(UPCAddressofArraySubscriptTag, DynTypedNode::create(*UO)); @@ -2209,7 +2207,7 @@ class UPCPreIncrementGadget : public FixableGadget { if (!UO || UO->getOpcode() != UO_PreInc) return; const auto *DRE = dyn_cast<DeclRefExpr>(UO->getSubExpr()); - if (!DRE || !toSupportedVariable(*DRE)) + if (!DRE || !isSupportedVariable(*DRE)) return; MatchResult R; R.addNode(UPCPreIncrementTag, DynTypedNode::create(*UO)); @@ -2261,7 +2259,7 @@ class UUCAddAssignGadget : public FixableGadget { if (!BO || BO->getOpcode() != BO_AddAssign) return; const auto *DRE = dyn_cast<DeclRefExpr>(BO->getLHS()); - if (!DRE || !hasPointerType(*DRE) || !toSupportedVariable(*DRE) || + if (!DRE || !hasPointerType(*DRE) || !isSupportedVariable(*DRE) || !isa<Expr>(BO->getRHS())) return; MatchResult R; @@ -2309,7 +2307,7 @@ class DerefSimplePtrArithFixableGadget : public FixableGadget { if (!E || !hasPointerType(*E)) return false; const auto *DRE = dyn_cast<DeclRefExpr>(E->IgnoreImpCasts()); - if (!DRE || !toSupportedVariable(*DRE)) + if (!DRE || !isSupportedVariable(*DRE)) return false; R.addNode(BaseDeclRefExprTag, DynTypedNode::create(*DRE)); return true; @@ -3055,7 +3053,7 @@ static inline std::optional<FixItList> createDataFixit(const ASTContext &Ctx, // `DRE.data()` std::optional<FixItList> UPCStandalonePointerGadget::getFixits(const FixitStrategy &S) const { - const auto *const VD = cast<VarDecl>(Node->getDecl()); + auto *VD = cast<VarDecl>(Node->getDecl()); switch (S.lookup(VD)) { case FixitStrategy::Kind::Array: case FixitStrategy::Kind::Span: { _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits