ymandel created this revision.
ymandel added a reviewer: ilya-biryukov.
Herald added subscribers: cfe-commits, jdoerfert, jfb, mgorny.
Herald added a project: clang.
ymandel added a parent revision: D59329: [LibTooling] Add NodeId, a strong type for AST-matcher node identifiers.

Adds a basic version of Transformer, a library supporting the concise 
specification of clang-based source-to-source transformations.  A full 
discussion of the end goal can be found on the cfe-dev list with subject "[RFC] 
Easier source-to-source transformations with clang tooling".


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D59376

Files:
  clang/include/clang/Tooling/Refactoring/Transformer.h
  clang/lib/Tooling/Refactoring/CMakeLists.txt
  clang/lib/Tooling/Refactoring/Transformer.cpp
  clang/unittests/Tooling/CMakeLists.txt
  clang/unittests/Tooling/TransformerTest.cpp

Index: clang/unittests/Tooling/TransformerTest.cpp
===================================================================
--- /dev/null
+++ clang/unittests/Tooling/TransformerTest.cpp
@@ -0,0 +1,428 @@
+//===- unittest/Tooling/TransformerTest.cpp -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Tooling/Refactoring/Transformer.h"
+
+#include "clang/ASTMatchers/ASTMatchers.h"
+#include "clang/Tooling/Tooling.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace clang {
+namespace tooling {
+namespace {
+using ::clang::ast_matchers::anyOf;
+using ::clang::ast_matchers::argumentCountIs;
+using ::clang::ast_matchers::callee;
+using ::clang::ast_matchers::callExpr;
+using ::clang::ast_matchers::cxxMemberCallExpr;
+using ::clang::ast_matchers::cxxMethodDecl;
+using ::clang::ast_matchers::cxxRecordDecl;
+using ::clang::ast_matchers::declRefExpr;
+using ::clang::ast_matchers::expr;
+using ::clang::ast_matchers::functionDecl;
+using ::clang::ast_matchers::hasAnyName;
+using ::clang::ast_matchers::hasArgument;
+using ::clang::ast_matchers::hasDeclaration;
+using ::clang::ast_matchers::hasElse;
+using ::clang::ast_matchers::hasName;
+using ::clang::ast_matchers::hasType;
+using ::clang::ast_matchers::ifStmt;
+using ::clang::ast_matchers::member;
+using ::clang::ast_matchers::memberExpr;
+using ::clang::ast_matchers::namedDecl;
+using ::clang::ast_matchers::on;
+using ::clang::ast_matchers::pointsTo;
+using ::clang::ast_matchers::to;
+using ::clang::ast_matchers::unless;
+
+// Declarations visible to every test snippet: TransformerTest's constructor
+// appends this text to the virtual file "header.h", which `rewrite()` includes
+// at the top of each input. They are minimal stand-ins for the string and
+// proto APIs targeted by the rules under test.
+constexpr const char KHeaderContents[] = R"cc(
+  struct string {
+    string(const char*);
+    char* c_str();
+    int size();
+  };
+  int strlen(const char*);
+
+  namespace proto {
+  struct PCFProto {
+    int foo();
+  };
+  struct ProtoCommandLineFlag : PCFProto {
+    PCFProto& GetProto();
+  };
+  }  // namespace proto
+)cc";
+} // namespace
+
+// Matches a type whose declaration matches `TypeMatcher`, or a pointer to
+// such a type.
+static clang::ast_matchers::internal::Matcher<clang::QualType>
+isOrPointsTo(const DeclarationMatcher &TypeMatcher) {
+  auto Direct = hasDeclaration(TypeMatcher);
+  auto ThroughPointer = pointsTo(TypeMatcher);
+  return anyOf(Direct, ThroughPointer);
+}
+
+// Reformats all of `Code` with the LLVM style so that snippets can be compared
+// independently of whitespace. On failure, records a non-fatal test failure
+// and returns an empty string.
+static std::string format(llvm::StringRef Code) {
+  const auto Style = format::getLLVMStyle();
+  const std::vector<Range> WholeInput = {Range(0, Code.size())};
+  auto Formatted =
+      applyAllReplacements(Code, format::reformat(Style, Code, WholeInput));
+  if (Formatted)
+    return *Formatted;
+  ADD_FAILURE() << "Could not format code: "
+                << llvm::toString(Formatted.takeError());
+  return std::string();
+}
+
+// Asserts that the rewrite succeeded, strips the injected `#include
+// "header.h"` line from its output, and compares the result with `Expected`
+// after formatting both.
+void compareSnippets(llvm::StringRef Expected,
+                     const llvm::Optional<std::string> &MaybeActual) {
+  ASSERT_TRUE(MaybeActual) << "Rewrite failed. Expecting: " << Expected;
+  const std::string IncludeLine = "#include \"header.h\"\n";
+  std::string Actual = *MaybeActual;
+  const auto Pos = Actual.find(IncludeLine);
+  if (Pos != std::string::npos)
+    Actual.erase(Pos, IncludeLine.size());
+  EXPECT_EQ(format(Expected), format(Actual));
+}
+
+// Fixture that runs `MatchFinder` over a code snippet and then applies the
+// `Changes` recorded during that run to produce the rewritten code.
+// FIXME: consider separating this class into its own file(s).
+class ClangRefactoringTestBase : public testing::Test {
+protected:
+  // Appends `S` to the virtual "header.h" that every snippet includes.
+  void appendToHeader(llvm::StringRef S) { FileContents[0].second += S; }
+
+  // Makes an additional virtual file available to the tool.
+  void addFile(llvm::StringRef Filename, llvm::StringRef Content) {
+    FileContents.emplace_back(Filename, Content);
+  }
+
+  // Runs the registered matchers over `Input` (prefixed with an include of
+  // "header.h") and returns the code with all recorded changes applied, or
+  // None if either the tool run or the change application fails.
+  llvm::Optional<std::string> rewrite(llvm::StringRef Input) {
+    const std::string Code = ("#include \"header.h\"\n" + Input).str();
+    auto Factory = newFrontendActionFactory(&MatchFinder);
+    const bool ToolSucceeded = runToolOnCodeWithArgs(
+        Factory->create(), Code, std::vector<std::string>(), "input.cc",
+        "clang-tool", std::make_shared<PCHContainerOperations>(),
+        FileContents);
+    if (!ToolSucceeded)
+      return None;
+
+    auto ChangedCodeOrErr =
+        applyAtomicChanges("input.cc", Code, Changes, ApplyChangesSpec());
+    if (!ChangedCodeOrErr) {
+      llvm::errs() << "Change failed: "
+                   << llvm::toString(ChangedCodeOrErr.takeError()) << "\n";
+      return None;
+    }
+    return *ChangedCodeOrErr;
+  }
+
+  clang::ast_matchers::MatchFinder MatchFinder;
+  AtomicChanges Changes;
+
+private:
+  FileContentMappings FileContents = {{"header.h", ""}};
+};
+
+// Fixture for Transformer tests. It preloads the shared header declarations
+// and provides a consumer that records every emitted change in `Changes`.
+class TransformerTest : public ClangRefactoringTestBase {
+protected:
+  TransformerTest() { appendToHeader(KHeaderContents); }
+
+  Transformer::ChangeConsumer changeRecorder() {
+    auto Record = [this](const AtomicChange &Change) {
+      Changes.push_back(Change);
+    };
+    return Record;
+  }
+};
+
+// Returns a TextGenerator that ignores the match and always yields `M`.
+static TextGenerator text(const std::string &M) {
+  auto Constant = [M](const clang::ast_matchers::MatchFinder::MatchResult &) {
+    return M;
+  };
+  return Constant;
+}
+
+// Matches `strlen(s.c_str())`, where `s` is a string (or a pointer to one),
+// and replaces the whole call with the placeholder text "REPLACED". The string
+// operand is bound to `StringExpr`, but the replacement does not use it yet.
+//
+// FIXME: `.as<clang::Expr>()` is required because the target kind is taken
+// from the matcher's supported kind, and `callExpr` produces a statement
+// matcher. Without it, the call would be treated as a statement and the
+// replaced range would be extended to include a trailing semicolon.
+RewriteRule ruleStrlenSize() {
+  ExprId StringExpr;
+  auto StringType = namedDecl(hasAnyName("::basic_string", "::string"));
+  return RewriteRule(
+             callExpr(
+                 callee(functionDecl(hasName("strlen"))),
+                 hasArgument(0, cxxMemberCallExpr(
+                                    on(expr(StringExpr.bind(),
+                                            hasType(isOrPointsTo(StringType)))),
+                                    callee(cxxMethodDecl(hasName("c_str")))))))
+      .as<clang::Expr>()
+      .replaceWith(text("REPLACED"))
+      .because(text("Use size() method directly on string."));
+}
+
+// Rewrites a direct `strlen(s.c_str())` call.
+TEST_F(TransformerTest, StrlenSize) {
+  Transformer T(ruleStrlenSize(), changeRecorder());
+  T.registerMatchers(&MatchFinder);
+
+  const std::string Input = "int f(string s) { return strlen(s.c_str()); }";
+  const std::string Expected = "int f(string s) { return REPLACED; }";
+  compareSnippets(Expected, rewrite(Input));
+}
+
+// Tests that no change is applied when a match is not expected.
+TEST_F(TransformerTest, NoMatch) {
+  Transformer T(ruleStrlenSize(), changeRecorder());
+  T.registerMatchers(&MatchFinder);
+
+  // The snippet never calls strlen, so it must come back unchanged.
+  const std::string Input = "int f(string s) { return s.size(); }";
+  compareSnippets(Input, rewrite(Input));
+}
+
+// Tests that expressions in macro arguments are rewritten (when applicable).
+TEST_F(TransformerTest, StrlenSizeMacro) {
+  Transformer T(ruleStrlenSize(), changeRecorder());
+  T.registerMatchers(&MatchFinder);
+
+  // The call is written by the user as the argument of ID(), not inside the
+  // macro's body, so it remains eligible for rewriting.
+  const std::string Input = R"cc(
+#define ID(e) e
+    int f(string s) { return ID(strlen(s.c_str())); })cc";
+  const std::string Expected = R"cc(
+#define ID(e) e
+    int f(string s) { return ID(REPLACED); })cc";
+  compareSnippets(Expected, rewrite(Input));
+}
+
+// Builds the rule through the lvalue-ref overloads of the RewriteRule builder
+// methods, i.e. by calling them on a named rule rather than on a temporary.
+TEST_F(TransformerTest, LvalueRefOverloads) {
+  StmtId ElseBranch;
+  RewriteRule Rule(ifStmt(hasElse(ElseBranch.bind())));
+  Rule.change(ElseBranch).replaceWith(text("bar();"));
+  Transformer T(Rule, changeRecorder());
+  T.registerMatchers(&MatchFinder);
+
+  const std::string Input = R"cc(
+    void foo() {
+      if (10 > 1.0)
+        return;
+      else
+        foo();
+    }
+  )cc";
+  const std::string Expected = R"cc(
+    void foo() {
+      if (10 > 1.0)
+        return;
+      else
+        bar();
+    }
+  )cc";
+  compareSnippets(Expected, rewrite(Input));
+}
+
+// Tests replacing an expression: the object of a member call on a
+// ProtoCommandLineFlag. Calls to GetProto() itself are excluded by the matcher.
+TEST_F(TransformerTest, Flag) {
+  ExprId FlagObj;
+  auto Rule =
+      RewriteRule(cxxMemberCallExpr(
+                      on(expr(FlagObj.bind(),
+                              hasType(cxxRecordDecl(
+                                  hasName("proto::ProtoCommandLineFlag"))))),
+                      unless(callee(cxxMethodDecl(hasName("GetProto"))))))
+          .change(FlagObj)
+          .replaceWith(text("EXPR"))
+          .because(text("Use GetProto() to access proto fields."));
+  Transformer T(Rule, changeRecorder());
+  T.registerMatchers(&MatchFinder);
+
+  const std::string Input = R"cc(
+    proto::ProtoCommandLineFlag flag;
+    int x = flag.foo();
+    int y = flag.GetProto().foo();
+  )cc";
+  const std::string Expected = R"cc(
+    proto::ProtoCommandLineFlag flag;
+    int x = EXPR.foo();
+    int y = flag.GetProto().foo();
+  )cc";
+  compareSnippets(Expected, rewrite(Input));
+}
+
+// Renames a function at both its declaration and its definition by targeting
+// NodePart::Name of the matched FunctionDecl.
+TEST_F(TransformerTest, NodePartNameNamedDecl) {
+  DeclId Fn;
+  auto Rule = RewriteRule(functionDecl(hasName("bad"), Fn.bind()))
+                  .change(Fn, NodePart::Name)
+                  .replaceWith(text("good"));
+  Transformer T(Rule, changeRecorder());
+  T.registerMatchers(&MatchFinder);
+
+  const std::string Input = R"cc(
+    int bad(int x);
+    int bad(int x) { return x * x; }
+  )cc";
+  const std::string Expected = R"cc(
+    int good(int x);
+    int good(int x) { return x * x; }
+  )cc";
+  compareSnippets(Expected, rewrite(Input));
+}
+
+// Renames the callee at a template call site by targeting NodePart::Name of a
+// DeclRefExpr. The template's own declaration is not a DeclRefExpr, so it is
+// left untouched.
+TEST_F(TransformerTest, NodePartNameDeclRef) {
+  ExprId Ref;
+  auto Rule =
+      RewriteRule(declRefExpr(to(functionDecl(hasName("bad"))), Ref.bind()))
+          .change(Ref, NodePart::Name)
+          .replaceWith(text("good"));
+  Transformer T(Rule, changeRecorder());
+  T.registerMatchers(&MatchFinder);
+
+  const std::string Input = R"cc(
+    template <typename T>
+    T bad(T x) {
+      return x;
+    }
+    int neutral(int x) { return bad<int>(x) * x; }
+  )cc";
+  const std::string Expected = R"cc(
+    template <typename T>
+    T bad(T x) {
+      return x;
+    }
+    int neutral(int x) { return good<int>(x) * x; }
+  )cc";
+  compareSnippets(Expected, rewrite(Input));
+}
+
+// The reference to `operator*` is a DeclRefExpr whose name is not an
+// identifier, so NodePart::Name is rejected for it and the input is left
+// unchanged.
+TEST_F(TransformerTest, NodePartNameDeclRefFailure) {
+  ExprId Ref;
+  auto Rule = RewriteRule(declRefExpr(to(functionDecl()), Ref.bind()))
+                  .change(Ref, NodePart::Name)
+                  .replaceWith(text("good"));
+  Transformer T(Rule, changeRecorder());
+  T.registerMatchers(&MatchFinder);
+
+  const std::string Input = R"cc(
+    struct Y {};
+    int operator*(const Y&);
+    int neutral(int x) {
+      Y y;
+      return *y + x;
+    }
+  )cc";
+  compareSnippets(Input, rewrite(Input));
+}
+
+// Replaces only the member token of a MemberExpr (NodePart::Member); the
+// field's declaration is left untouched.
+TEST_F(TransformerTest, NodePartMember) {
+  ExprId Access;
+  auto Rule = RewriteRule(memberExpr(member(hasName("bad")), Access.bind()))
+                  .change(Access, NodePart::Member)
+                  .replaceWith(text("good"));
+  Transformer T(Rule, changeRecorder());
+  T.registerMatchers(&MatchFinder);
+
+  const std::string Input = R"cc(
+    struct S {
+      int bad;
+    };
+    int g() {
+      S s;
+      return s.bad;
+    }
+  )cc";
+  const std::string Expected = R"cc(
+    struct S {
+      int bad;
+    };
+    int g() {
+      S s;
+      return s.good;
+    }
+  )cc";
+  compareSnippets(Expected, rewrite(Input));
+}
+
+// A rule that finds function calls with two arguments where both arguments are
+// references to the same declaration, and replaces the whole call with `42`.
+RewriteRule ruleDuplicateArgs() {
+  ExprId Arg0, Arg1;
+  auto Pattern = callExpr(argumentCountIs(2), hasArgument(0, Arg0.bind()),
+                          hasArgument(1, Arg1.bind()));
+  // Accepts the match only if both arguments are DeclRefExprs that refer to
+  // the same declaration.
+  auto SameDecl =
+      [Arg0, Arg1](const clang::ast_matchers::MatchFinder::MatchResult &Match) {
+        const auto *First = Arg0.getNodeAs<clang::DeclRefExpr>(Match);
+        const auto *Second = Arg1.getNodeAs<clang::DeclRefExpr>(Match);
+        if (First == nullptr || Second == nullptr)
+          return false;
+        return First->getDecl() == Second->getDecl();
+      };
+  return RewriteRule(Pattern)
+      .where(SameDecl)
+      .as<clang::Expr>()
+      .replaceWith(text("42"));
+}
+
+// Both arguments of `foo(x, x)` refer to the same variable, so the filter
+// accepts the match and the call is replaced.
+TEST_F(TransformerTest, FilterPassed) {
+  Transformer T(ruleDuplicateArgs(), changeRecorder());
+  T.registerMatchers(&MatchFinder);
+
+  const std::string Input = R"cc(
+    int foo(int x, int y);
+    int x = 3;
+    int z = foo(x, x);
+  )cc";
+  const std::string Expected = R"cc(
+    int foo(int x, int y);
+    int x = 3;
+    int z = 42;
+  )cc";
+  compareSnippets(Expected, rewrite(Input));
+}
+
+//
+// Negative tests (where we expect no transformation to occur).
+//
+
+// The filter rejects both calls, so the input is left unchanged.
+TEST_F(TransformerTest, FilterFailed) {
+  Transformer T(ruleDuplicateArgs(), changeRecorder());
+  T.registerMatchers(&MatchFinder);
+
+  const std::string Input = R"cc(
+    int foo(int x, int y);
+    int x = 3;
+    int y = 17;
+    // Different identifiers.
+    int z = foo(x, y);
+    // One identifier, one not.
+    int w = foo(x, 3);
+  )cc";
+  compareSnippets(Input, rewrite(Input));
+}
+
+// The matched call is spelled inside MACRO's body, so it is not rewritten.
+TEST_F(TransformerTest, NoTransformationInMacro) {
+  Transformer T(ruleStrlenSize(), changeRecorder());
+  T.registerMatchers(&MatchFinder);
+
+  const std::string Input = R"cc(
+#define MACRO(str) strlen((str).c_str())
+    int f(string s) { return MACRO(s); })cc";
+  // The macro should be ignored.
+  compareSnippets(Input, rewrite(Input));
+}
+
+// Corner case: a macro invoked inside another macro's body expands to matching
+// code, and that code is an *argument* of the nested macro. Comparing only
+// isMacroArgExpansion() against isMacroBodyExpansion() would misclassify it
+// as user-written argument text and transform it. This test verifies that no
+// such transformation occurs.
+TEST_F(TransformerTest, NoTransformationInNestedMacro) {
+  Transformer T(ruleStrlenSize(), changeRecorder());
+  T.registerMatchers(&MatchFinder);
+
+  const std::string Input = R"cc(
+#define NESTED(e) e
+#define MACRO(str) NESTED(strlen((str).c_str()))
+    int f(string s) { return MACRO(s); })cc";
+  // The macro should be ignored.
+  compareSnippets(Input, rewrite(Input));
+}
+} // namespace tooling
+} // namespace clang
Index: clang/unittests/Tooling/CMakeLists.txt
===================================================================
--- clang/unittests/Tooling/CMakeLists.txt
+++ clang/unittests/Tooling/CMakeLists.txt
@@ -50,6 +50,7 @@
   ReplacementsYamlTest.cpp
   RewriterTest.cpp
   ToolingTest.cpp
+  TransformerTest.cpp
   )
 
 target_link_libraries(ToolingTests
Index: clang/lib/Tooling/Refactoring/Transformer.cpp
===================================================================
--- /dev/null
+++ clang/lib/Tooling/Refactoring/Transformer.cpp
@@ -0,0 +1,239 @@
+//===--- Transformer.cpp - Transformer library implementation ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Tooling/Refactoring/Transformer.h"
+#include "clang/AST/Expr.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/ASTMatchers/ASTMatchers.h"
+#include "clang/Basic/Diagnostic.h"
+#include "clang/Basic/SourceLocation.h"
+#include "clang/Rewrite/Core/Rewriter.h"
+#include "clang/Tooling/FixIt.h"
+#include "clang/Tooling/Refactoring.h"
+#include "clang/Tooling/Refactoring/AtomicChange.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Errc.h"
+#include "llvm/Support/Error.h"
+#include <deque>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace clang {
+namespace tooling {
+namespace {
+using ::clang::ast_matchers::MatchFinder;
+using ::clang::ast_matchers::stmt;
+using ::clang::ast_type_traits::ASTNodeKind;
+using ::clang::ast_type_traits::DynTypedNode;
+using ::llvm::Error;
+using ::llvm::Expected;
+using ::llvm::Optional;
+using ::llvm::StringError;
+using ::llvm::StringRef;
+
+using MatchResult = MatchFinder::MatchResult;
+} // namespace
+
+// Returns true if `Loc` was produced, at some level of macro expansion, by the
+// body of a macro definition, rather than only by text the user passed as
+// macro arguments.
+static bool isOriginMacroBody(const clang::SourceManager &SM,
+                              clang::SourceLocation Loc) {
+  for (; Loc.isMacroID(); Loc = SM.getImmediateMacroCallerLoc(Loc)) {
+    if (SM.isMacroBodyExpansion(Loc))
+      return true;
+    // Otherwise `Loc` lies in a macro argument. getImmediateMacroCallerLoc()
+    // steps out to the argument's text inside the enclosing invocation, and
+    // the search continues from there.
+  }
+  return false;
+}
+
+// Builds a StringError with `errc::invalid_argument`. Used to report
+// malformed rules and targets. Twine arguments are taken by const reference,
+// as LLVM requires for Twine parameters.
+static llvm::Error invalidArgumentError(const llvm::Twine &Message) {
+  return llvm::make_error<StringError>(llvm::errc::invalid_argument, Message);
+}
+
+// Reports that the node id `Id`, used in the given `Role`, is not bound in the
+// match result.
+static llvm::Error unboundNodeError(StringRef Role, StringRef Id) {
+  return invalidArgumentError(Role + " (=" + Id + ") references unbound node");
+}
+
+// Reports a node of an unexpected kind; the kind's name is appended to
+// `Message`.
+static llvm::Error typeError(const llvm::Twine &Message,
+                             const ASTNodeKind &Kind) {
+  return invalidArgumentError(Message + " (node kind is " + Kind.asStringRef() +
+                              ")");
+}
+
+// Reports that `Description` applies only to nodes with `Property`.
+static llvm::Error missingPropertyError(const llvm::Twine &Description,
+                                        StringRef Property) {
+  return invalidArgumentError(Description + " requires property '" + Property +
+                              "'");
+}
+
+// Verifies that `TargetPart` can be applied to `Node`:
+//  * NodePart::Member requires a MemberExpr;
+//  * NodePart::Name requires a NamedDecl or DeclRefExpr whose name is an
+//    identifier, or a CXXCtorInitializer that is a member initializer.
+static Error verifyTarget(const DynTypedNode &Node, NodePart TargetPart) {
+  switch (TargetPart) {
+  case NodePart::Node:
+    return Error::success();
+  case NodePart::Member:
+    if (Node.get<clang::MemberExpr>() == nullptr)
+      return typeError("NodePart::Member applied to non-MemberExpr",
+                       Node.getNodeKind());
+    return Error::success();
+  case NodePart::Name:
+    if (const auto *D = Node.get<clang::NamedDecl>()) {
+      if (!D->getDeclName().isIdentifier())
+        return missingPropertyError("NodePart::Name", "identifier");
+      return Error::success();
+    }
+    if (const auto *E = Node.get<clang::DeclRefExpr>()) {
+      if (!E->getNameInfo().getName().isIdentifier())
+        return missingPropertyError("NodePart::Name", "identifier");
+      return Error::success();
+    }
+    if (const auto *I = Node.get<clang::CXXCtorInitializer>()) {
+      if (!I->isMemberInitializer())
+        return missingPropertyError("NodePart::Name", "member initializer");
+      return Error::success();
+    }
+    return typeError(
+        "NodePart::Name applied to neither DeclRefExpr, NamedDecl nor "
+        "CXXCtorInitializer",
+        Node.getNodeKind());
+  }
+  llvm_unreachable("Unexpected case in NodePart type.");
+}
+
+// Returns the source range to replace for `Node`, according to `TargetPart`.
+// `Kind` is the kind the rule intends the target to have. When a whole
+// statement is targeted and `Kind` is not an expression kind, the range is
+// extended to include a trailing semicolon.
+//
+// Requires: verifyTarget(Node, TargetPart) succeeded.
+static CharSourceRange getTarget(const DynTypedNode &Node, ASTNodeKind Kind,
+                                 NodePart TargetPart, ASTContext &Context) {
+  SourceLocation TokenLoc;
+  switch (TargetPart) {
+  case NodePart::Node: {
+    // For non-expression statements, associate any trailing semicolon with the
+    // statement text.  However, if the target was intended as an expression (as
+    // indicated by its kind) then we do not associate any trailing semicolon
+    // with it.  We only associate the exact expression text.
+    if (Node.get<Stmt>() != nullptr) {
+      auto ExprKind = ASTNodeKind::getFromNodeKind<clang::Expr>();
+      if (!ExprKind.isBaseOf(Kind))
+        return fixit::getExtendedRange(Node, tok::TokenKind::semi, Context);
+    }
+    return CharSourceRange::getTokenRange(Node.getSourceRange());
+  }
+  case NodePart::Member:
+    TokenLoc = Node.get<clang::MemberExpr>()->getMemberLoc();
+    break;
+  case NodePart::Name:
+    if (const auto *D = Node.get<clang::NamedDecl>()) {
+      TokenLoc = D->getLocation();
+      break;
+    }
+    if (const auto *E = Node.get<clang::DeclRefExpr>()) {
+      TokenLoc = E->getLocation();
+      break;
+    }
+    if (const auto *I = Node.get<clang::CXXCtorInitializer>()) {
+      TokenLoc = I->getMemberLocation();
+      break;
+    }
+    // This should be unreachable if the target was already verified.
+    llvm_unreachable("NodePart::Name applied to neither DeclRefExpr, NamedDecl "
+                     "nor CXXCtorInitializer");
+  }
+  return CharSourceRange::getTokenRange(TokenLoc, TokenLoc);
+}
+
+// Applies `Rule` to one match.
+//
+// Returns an empty Transformation (whose Range is invalid) when the match is
+// skipped. This happens when the TU has errors, the rule's filter rejects the
+// match, or the target's text originates in a macro body.
+//
+// Returns an error when the target id is unbound, the target part does not fit
+// the bound node, or the replacement generator fails.
+Expected<Transformation> applyRewriteRule(const RewriteRule &Rule,
+                                          const MatchResult &Result) {
+  // Ignore results in failing TUs or those rejected by the where clause.
+  if (Result.Context->getDiagnostics().hasErrorOccurred() ||
+      !Rule.filter().matches(Result))
+    return Transformation();
+
+  auto &NodesMap = Result.Nodes.getMap();
+  auto It = NodesMap.find(Rule.target());
+  if (It == NodesMap.end())
+    return unboundNodeError("rule.target", Rule.target());
+  if (auto Err = llvm::handleErrors(
+          verifyTarget(It->second, Rule.targetPart()), [&Rule](StringError &E) {
+            return invalidArgumentError("Failure targeting node " +
+                                        Rule.target() + ": " + E.getMessage());
+          })) {
+    return std::move(Err);
+  }
+  CharSourceRange Target = getTarget(It->second, Rule.targetKind(),
+                                     Rule.targetPart(), *Result.Context);
+  // Skip unusable ranges and text that comes from a macro definition's body.
+  if (Target.isInvalid() ||
+      isOriginMacroBody(*Result.SourceManager, Target.getBegin()))
+    return Transformation();
+
+  auto ReplacementOrErr = Rule.replacement(Result);
+  if (auto Err = ReplacementOrErr.takeError())
+    return std::move(Err);
+  return Transformation{Target, std::move(*ReplacementOrErr)};
+}
+
+// Namespace-scope definition of the static constexpr member declared in
+// RewriteRule. Before C++17 this definition is required because `RootId` is
+// odr-used (the array decays to a pointer when it is used as a string).
+constexpr char RewriteRule::RootId[];
+
+// Installs `FilterFn` as the rule's "where" clause, replacing any previous one.
+RewriteRule &
+RewriteRule::where(std::function<bool(const MatchResult &Result)> FilterFn) & {
+  Filter = MatchFilter{std::move(FilterFn)};
+  return *this;
+}
+
+// Targets the node bound to `TargetId`. The id is untyped, so the intended
+// kind of the target is reset to the empty kind.
+RewriteRule &RewriteRule::change(const NodeId &TargetId, NodePart Part) & {
+  TargetPart = Part;
+  TargetKind = ASTNodeKind();
+  Target = std::string(TargetId.id());
+  return *this;
+}
+
+// Sets the generator for the replacement text.
+RewriteRule &RewriteRule::replaceWith(TextGenerator TG) & {
+  this->Replacement = std::move(TG);
+  return *this;
+}
+
+// Sets the generator for the explanation text.
+RewriteRule &RewriteRule::because(TextGenerator TG) & {
+  this->Explanation = std::move(TG);
+  return *this;
+}
+
+// Builds a rule from a statement matcher, a replacement generator, and a fixed
+// explanation string.
+//
+// `Explanation` is taken as `const std::string &` for two reasons. A captured
+// `StringRef` could dangle once the caller's string is gone. A by-value
+// `std::string` could cost an extra copy, into the parameter and again into the
+// capture. Once C++14 init-captures are available, the parameter could instead
+// be taken by value and moved into the lambda.
+RewriteRule makeRule(StatementMatcher Matcher, TextGenerator Replacement,
+                     const std::string &Explanation) {
+  return RewriteRule(std::move(Matcher))
+      .replaceWith(std::move(Replacement))
+      .because([Explanation](const MatchResult &) { return Explanation; });
+}
+
+// Registers the rule's matcher with `Finder`, with `this` as the callback.
+// addDynamicMatcher() returns false for a matcher that is not a valid
+// top-level matcher. Report that case instead of ignoring it, since such a rule
+// would otherwise silently never fire.
+void Transformer::registerMatchers(MatchFinder *Finder) {
+  if (!Finder->addDynamicMatcher(Rule.matcher(), this))
+    llvm::errs() << "Transformer: the rule's matcher is not a valid top-level "
+                    "matcher; the rule will never be applied.\n";
+}
+
+// Applies the rule to one match. Errors are logged and the match is dropped.
+// Otherwise, if the rule produced a range, the change is wrapped in an
+// AtomicChange and handed to the consumer. If the replacement could not be
+// recorded, the change carries that error.
+void Transformer::run(const MatchResult &Result) {
+  auto ChangeOrErr = applyRewriteRule(Rule, Result);
+  if (!ChangeOrErr) {
+    llvm::errs() << "Rewrite failed: "
+                 << llvm::toString(ChangeOrErr.takeError()) << "\n";
+    return;
+  }
+  const Transformation &Change = *ChangeOrErr;
+  // An invalid range means no rewrite applies (but no error was encountered).
+  if (Change.Range.isInvalid())
+    return;
+
+  const clang::SourceManager &SM = *Result.SourceManager;
+  AtomicChange AC(SM, Change.Range.getBegin());
+  if (auto Err = AC.replace(SM, Change.Range, Change.Replacement))
+    AC.setError(llvm::toString(std::move(Err)));
+  Consumer(AC);
+}
+} // namespace tooling
+} // namespace clang
Index: clang/lib/Tooling/Refactoring/CMakeLists.txt
===================================================================
--- clang/lib/Tooling/Refactoring/CMakeLists.txt
+++ clang/lib/Tooling/Refactoring/CMakeLists.txt
@@ -13,6 +13,7 @@
   Rename/USRFindingAction.cpp
   Rename/USRLocFinder.cpp
   NodeId.cpp
+  Transformer.cpp
 
   LINK_LIBS
   clangAST
Index: clang/include/clang/Tooling/Refactoring/Transformer.h
===================================================================
--- /dev/null
+++ clang/include/clang/Tooling/Refactoring/Transformer.h
@@ -0,0 +1,285 @@
+//===--- Transformer.h - Clang source-rewriting library ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+///  \file
+///  Defines a library supporting the concise specification of clang-based
+///  source-to-source transformations.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_TOOLING_REFACTOR_TRANSFORMER_H_
+#define LLVM_CLANG_TOOLING_REFACTOR_TRANSFORMER_H_
+
+#include "NodeId.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/ASTMatchers/ASTMatchers.h"
+#include "clang/ASTMatchers/ASTMatchersInternal.h"
+#include "clang/Tooling/Refactoring/AtomicChange.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Error.h"
+#include <deque>
+#include <functional>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+namespace clang {
+namespace tooling {
+
+/// \name Matcher-type abbreviations for all top-level classes in the
+/// AST class hierarchy.
+/// @{
+using ast_matchers::CXXCtorInitializerMatcher;
+using ast_matchers::DeclarationMatcher;
+using ast_matchers::NestedNameSpecifierLocMatcher;
+using ast_matchers::NestedNameSpecifierMatcher;
+using ast_matchers::StatementMatcher;
+using ast_matchers::TypeLocMatcher;
+using ast_matchers::TypeMatcher;
+using TemplateArgumentMatcher =
+    ast_matchers::internal::Matcher<TemplateArgument>;
+using TemplateNameMatcher = ast_matchers::internal::Matcher<TemplateName>;
+using ast_matchers::internal::DynTypedMatcher;
+/// @}
+
+/// A filter over match results, consulted before a rule is applied to a match.
+/// For now it simply wraps a predicate. It may later grow into a small boolean
+/// expression language for composing filters.
+class MatchFilter {
+public:
+  using Predicate =
+      std::function<bool(const ast_matchers::MatchFinder::MatchResult &Result)>;
+
+  /// Creates a filter that accepts every match.
+  MatchFilter() = default;
+  /// Creates a filter that accepts exactly the matches satisfying \p P.
+  explicit MatchFilter(Predicate P) : Filter(std::move(P)) {}
+
+  /// Returns whether \p Result passes this filter.
+  bool matches(const ast_matchers::MatchFinder::MatchResult &Result) const {
+    return Filter(Result);
+  }
+
+private:
+  Predicate Filter = [](const ast_matchers::MatchFinder::MatchResult &) {
+    return true;
+  };
+};
+
+/// Determines the part of the AST node to replace.  This works around the fact
+/// that the AST does not give every syntactic element its own node; instead,
+/// users can select such elements relative to a node.
+enum class NodePart {
+  /// The node itself.
+  Node,
+  /// Given a \c MemberExpr, selects the member's token.
+  Member,
+  /// Given a \c NamedDecl, \c DeclRefExpr or \c CXXCtorInitializer, selects the
+  /// token of the relevant name, not including qualifiers.  The name must be an
+  /// identifier (for a \c CXXCtorInitializer, it must be a member
+  /// initializer).
+  Name,
+};
+
+/// Produces text (for example, a replacement or an explanation) from a match
+/// result, or fails with an error.
+using TextGenerator = std::function<llvm::Expected<std::string>(
+    const ast_matchers::MatchFinder::MatchResult &)>;
+
+/// Description of a source-code transformation.
+///
+/// A *rewrite rule* describes a transformation of source code. It has the
+/// following components:
+///
+/// * Matcher: the pattern term, expressed as clang matchers (with Transformer
+///   extensions).
+///
+/// * Where: a "where clause" -- that is, a predicate over (matched) AST nodes
+///   that restricts matches beyond what is (easily) expressible as a pattern.
+///
+/// * Target: the source code impacted by the rule. This identifies an AST node,
+///   or part thereof, whose source range indicates the extent of the
+///   replacement applied by the replacement term.  By default, the extent is
+///   the node matched by the pattern term.
+///
+/// * Replacement: a function that produces a replacement string for the
+///   target, based on the match result.
+///
+/// * Explanation: explanation of the rewrite.
+///
+/// Rules have an additional, implicit, component: the parameters. These are
+/// portions of the pattern which are left unspecified, yet named so that we can
+/// reference them in the replacement term.  The structure of parameters can be
+/// partially or even fully specified, in which case they serve just to identify
+/// matched nodes for later reference rather than abstract over portions of the
+/// AST.  However, in all cases, we refer to named portions of the pattern as
+/// parameters.
+///
+/// Parameters can be declared explicitly using the NodeId type and its
+/// derivatives or left implicit by using the native support for binding ids in
+/// the clang matchers.
+///
+/// RewriteRule is constructed in a "fluent" style, by chaining setters of
+/// individual components.  We provide ref-qualified overloads of the setters to
+/// avoid an unnecessary copy when a RewriteRule is initialized from a
+/// temporary, like:
+/// \code
+///   RewriteRule R = RewriteRule(functionDecl(...)).replaceWith(...);
+/// \endcode
+class RewriteRule {
+public:
+  /// Creates a rule from a type-erased matcher.  The node matched by \p M is
+  /// bound to \c RootId, so it is the rule's default target.  (Previously this
+  /// constructor did not bind \c RootId, so rules built from a
+  /// \c DynTypedMatcher always failed with an unbound target.)
+  RewriteRule(DynTypedMatcher M)
+      : Matcher(bindRoot(std::move(M))),
+        TargetKind(Matcher.getSupportedKind()) {}
+  /// Creates a rule from a typed matcher; see the \c DynTypedMatcher overload.
+  template <typename T>
+  RewriteRule(ast_matchers::internal::Matcher<T> M)
+      : RewriteRule(static_cast<DynTypedMatcher>(M)) {}
+
+  RewriteRule(const RewriteRule &) = default;
+  RewriteRule(RewriteRule &&) = default;
+  RewriteRule &operator=(const RewriteRule &) = default;
+  RewriteRule &operator=(RewriteRule &&) = default;
+
+  RewriteRule &where(MatchFilter::Predicate Filter) &;
+  RewriteRule &&where(MatchFilter::Predicate Filter) && {
+    return std::move(where(std::move(Filter)));
+  }
+
+  template <typename T> RewriteRule &as() &;
+  template <typename T> RewriteRule &&as() && { return std::move(as<T>()); }
+
+  RewriteRule &change(const NodeId &Target, NodePart Part = NodePart::Node) &;
+  RewriteRule &&change(const NodeId &Target,
+                       NodePart Part = NodePart::Node) && {
+    return std::move(change(Target, Part));
+  }
+  template <typename T>
+  RewriteRule &change(const TypedNodeId<T> &Target,
+                      NodePart Part = NodePart::Node) &;
+  template <typename T>
+  RewriteRule &&change(const TypedNodeId<T> &Target,
+                       NodePart Part = NodePart::Node) && {
+    return std::move(change(Target, Part));
+  }
+
+  RewriteRule &replaceWith(TextGenerator Replacement) &;
+  RewriteRule &&replaceWith(TextGenerator Replacement) && {
+    return std::move(replaceWith(std::move(Replacement)));
+  }
+
+  RewriteRule &because(TextGenerator Explanation) &;
+  RewriteRule &&because(TextGenerator Explanation) && {
+    return std::move(because(std::move(Explanation)));
+  }
+
+  const DynTypedMatcher &matcher() const { return Matcher; }
+  const MatchFilter &filter() const { return Filter; }
+  llvm::StringRef target() const { return Target; }
+  ast_type_traits::ASTNodeKind targetKind() const { return TargetKind; }
+  NodePart targetPart() const { return TargetPart; }
+
+  /// Generates the replacement text for the match \p R.  Fails, instead of
+  /// invoking an empty function, if no replacement was set.
+  llvm::Expected<std::string>
+  replacement(const ast_matchers::MatchFinder::MatchResult &R) const {
+    if (!Replacement)
+      return llvm::make_error<llvm::StringError>(
+          "RewriteRule has no replacement",
+          std::make_error_code(std::errc::invalid_argument));
+    return Replacement(R);
+  }
+
+  /// Generates the explanation for the match \p R.  Fails, instead of invoking
+  /// an empty function, if no explanation was set.
+  llvm::Expected<std::string>
+  explanation(const ast_matchers::MatchFinder::MatchResult &R) const {
+    if (!Explanation)
+      return llvm::make_error<llvm::StringError>(
+          "RewriteRule has no explanation",
+          std::make_error_code(std::errc::invalid_argument));
+    return Explanation(R);
+  }
+
+private:
+  /// Returns a copy of \p M that also binds its matched node to \c RootId.
+  static DynTypedMatcher bindRoot(DynTypedMatcher M) {
+    M.setAllowBind(true);
+    // RewriteRule guarantees that the node described by the matcher will always
+    // be accessible as `RootId`, so we bind it here. `tryBind` is guaranteed to
+    // succeed, because `AllowBind` is true.
+    return *M.tryBind(RootId);
+  }
+
+  // Id used as the default target of each match.
+  static constexpr char RootId[] = "___root___";
+
+  // Supports any (top-level node) matcher type.
+  DynTypedMatcher Matcher;
+  MatchFilter Filter;
+  // The (bound) id of the node whose source will be replaced.  This id should
+  // never be the empty string. By default, refers to the node matched by
+  // `Matcher`.
+  std::string Target = RootId;
+  ast_type_traits::ASTNodeKind TargetKind;
+  NodePart TargetPart = NodePart::Node;
+  TextGenerator Replacement;
+  TextGenerator Explanation;
+};
+
+/// Declares that the rule's target should be treated as a node of kind `T`,
+/// without changing which node is targeted.
+template <typename T> RewriteRule &RewriteRule::as() & {
+  using ast_type_traits::ASTNodeKind;
+  TargetKind = ASTNodeKind::getFromNodeKind<T>();
+  return *this;
+}
+
+/// Targets the node bound to `TargetId`; the id's static node type `T` becomes
+/// the intended kind of the target.
+template <typename T>
+RewriteRule &RewriteRule::change(const TypedNodeId<T> &TargetId,
+                                 NodePart Part) & {
+  using ast_type_traits::ASTNodeKind;
+  TargetPart = Part;
+  TargetKind = ASTNodeKind::getFromNodeKind<T>();
+  Target = std::string(TargetId.id());
+  return *this;
+}
+
+/// Convenience factory function for the common case where a rule has a
+/// statement matcher, a replacement generator, and a fixed explanation string.
+RewriteRule makeRule(StatementMatcher Matcher, TextGenerator Replacement,
+                     const std::string &Explanation);
+
+/// A source "transformation," represented by a character range in the source to
+/// be replaced and a corresponding replacement string.
+struct Transformation {
+  /// The source text to replace. An invalid (default-constructed) range means
+  /// that no rewrite should be applied.
+  CharSourceRange Range;
+  /// The text to substitute for `Range`.
+  std::string Replacement;
+};
+
+/// Attempts to apply a rule to a match.
+///
+/// If the match is ineligible for rewriting, returns a \c Transformation whose
+/// range is invalid. A match is ineligible when the translation unit has
+/// errors, the rule's filter rejects it, or the target originates in a macro
+/// body.
+///
+/// Fails if, for example, any invariants are violated relating to bound nodes
+/// in the match, or if the replacement text cannot be generated.
+llvm::Expected<Transformation>
+applyRewriteRule(const RewriteRule &Rule,
+                 const ast_matchers::MatchFinder::MatchResult &Match);
+
+/// Registers the matcher of a single rewrite rule and, for each match, applies
+/// the rule and passes the resulting change to a consumer.
+class Transformer : public ast_matchers::MatchFinder::MatchCallback {
+public:
+  using ChangeConsumer =
+      std::function<void(const clang::tooling::AtomicChange &Change)>;
+
+  /// \param Rule The rule applied to every match.
+  /// \param Consumer Receives the \c AtomicChange built for each match the rule
+  /// rewrites. The change carries an error if the replacement could not be
+  /// recorded.
+  Transformer(RewriteRule Rule, ChangeConsumer Consumer)
+      : Rule(std::move(Rule)), Consumer(std::move(Consumer)) {}
+
+  /// Adds the rule's matcher to \p MatchFinder with `this` as the callback.
+  /// The finder keeps that pointer, so this object must not be moved after
+  /// this call.
+  void registerMatchers(ast_matchers::MatchFinder *MatchFinder);
+
+  /// Match callback. It is invoked by the framework through the base-class
+  /// pointer, not directly by users.
+  void run(const ast_matchers::MatchFinder::MatchResult &Result) override;
+
+private:
+  RewriteRule Rule;
+  /// Receives the change built for each rewritten match.
+  ChangeConsumer Consumer;
+};
+} // namespace tooling
+} // namespace clang
+
+#endif // LLVM_CLANG_TOOLING_REFACTOR_TRANSFORMER_H_
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to