ymandel created this revision.
ymandel added a reviewer: gribozavr2.
Herald added a project: clang.

The new combinator, `rewriteDescendants`, applies a rewrite rule to all
descendants of a specified bound node. That rewrite rule can refer to nodes
bound by the enclosing (calling) rule, both in its matcher and in its edits.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D84409

Files:
  clang/include/clang/Tooling/Transformer/RewriteRule.h
  clang/lib/Tooling/Transformer/RewriteRule.cpp
  clang/lib/Tooling/Transformer/Transformer.cpp
  clang/unittests/Tooling/TransformerTest.cpp

Index: clang/unittests/Tooling/TransformerTest.cpp
===================================================================
--- clang/unittests/Tooling/TransformerTest.cpp
+++ clang/unittests/Tooling/TransformerTest.cpp
@@ -114,7 +114,9 @@
       if (C) {
         Changes.push_back(std::move(*C));
       } else {
-        consumeError(C.takeError());
+        // FIXME: stash this error rather than printing.
+        llvm::errs() << "Error generating changes: "
+                     << llvm::toString(C.takeError()) << "\n";
         ++ErrorCount;
       }
     };
@@ -414,6 +416,71 @@
            Input, Expected);
 }
 
+// Rewrites Stmts nested anywhere inside a bound FunctionDecl.
+TEST_F(TransformerTest, RewriteDescendantsDeclChangeStmt) {
+  const std::string Input =
+      "int f(int x) { int y = x; { int z = x * x; } return x; }";
+  const std::string Expected =
+      "int f(int x) { int y = 3; { int z = 3 * 3; } return 3; }";
+  auto ReplaceXWith3 =
+      makeRule(declRefExpr(to(varDecl(hasName("x")))), changeTo(cat("3")));
+  auto Outer = makeRule(functionDecl(hasName("f")).bind("fun"),
+                        rewriteDescendants("fun", ReplaceXWith3));
+  testRule(std::move(Outer), Input, Expected);
+}
+
+// Rewrites every integer TypeLoc (return and parameter types) under a Decl.
+TEST_F(TransformerTest, RewriteDescendantsDeclChangeTypeLoc) {
+  const std::string Input = "int f(int *x) { return *x; }";
+  const std::string Expected = "char f(char *x) { return *x; }";
+  auto IntTypeLoc = typeLoc(loc(qualType(isInteger(), builtinType())));
+  auto IntToChar = makeRule(IntTypeLoc, changeTo(cat("char")));
+  auto Outer = makeRule(functionDecl(hasName("f")).bind("fun"),
+                        rewriteDescendants("fun", IntToChar));
+  testRule(std::move(Outer), Input, Expected);
+}
+
+// Rewrites descendants of a bound Stmt (here, the function body).
+TEST_F(TransformerTest, RewriteDescendantsStmt) {
+  const std::string Input =
+      "int f(int x) { int y = x; { int z = x * x; } return x; }";
+  const std::string Expected =
+      "int f(int x) { int y = 3; { int z = 3 * 3; } return 3; }";
+  auto ToThree =
+      makeRule(declRefExpr(to(varDecl(hasName("x")))), changeTo(cat("3")));
+  auto Fun = functionDecl(hasName("f"), hasBody(stmt().bind("body")));
+  testRule(makeRule(Fun, rewriteDescendants("body", ToThree)), Input, Expected);
+}
+
+// Rewrites descendants of a bound TypeLoc (here, the parameter's type).
+TEST_F(TransformerTest, RewriteDescendantsTypeLoc) {
+  const std::string Input = "int f(int *x) { return *x; }";
+  const std::string Expected = "int f(char *x) { return *x; }";
+  auto IntToChar =
+      makeRule(typeLoc(loc(qualType(isInteger(), builtinType()))).bind("loc"),
+               changeTo(cat("char")));
+  auto ParmTL = typeLoc().bind("parmType");
+  auto Fun =
+      functionDecl(hasName("f"), hasParameter(0, varDecl(hasTypeLoc(ParmTL))));
+  testRule(makeRule(Fun, rewriteDescendants("parmType", IntToChar)), Input,
+           Expected);
+}
+
+// The inner rule's matcher refers to `var`, which only the outer rule binds.
+TEST_F(TransformerTest, RewriteDescendantsReferToParentBinding) {
+  const std::string Input =
+      "int f(int p) { int y = p; { int z = p * p; } return p; }";
+  const std::string Expected =
+      "int f(int p) { int y = 3; { int z = 3 * 3; } return 3; }";
+  const std::string VarId = "var";
+  auto RefToVar = declRefExpr(to(varDecl(equalsBoundNode(VarId))));
+  auto InlineVar = makeRule(RefToVar, changeTo(cat("3")));
+  auto Fun =
+      functionDecl(hasName("f"), hasParameter(0, varDecl().bind(VarId)));
+  testRule(makeRule(Fun.bind("fun"), rewriteDescendants("fun", InlineVar)),
+           Input, Expected);
+}
+
 TEST_F(TransformerTest, InsertBeforeEdit) {
   std::string Input = R"cc(
     int f() {
@@ -1064,4 +1131,35 @@
   EXPECT_EQ(format(*UpdatedCode), format(R"cc(#include "input.h"
                         ;)cc"));
 }
+
+TEST_F(TransformerTest, RewriteDescendantsUnboundNode) {
+  std::string Input =
+      "int f(int x) { int y = x; { int z = x * x; } return x; }";
+  auto InlineX =
+      makeRule(declRefExpr(to(varDecl(hasName("x")))), changeTo(cat("3")));
+  Transformer T(makeRule(functionDecl(hasName("f")),
+                         rewriteDescendants("UNBOUND", InlineX)),
+                consumer());
+  T.registerMatchers(&MatchFinder);
+  EXPECT_FALSE(rewrite(Input));
+  EXPECT_THAT(Changes, IsEmpty());
+  EXPECT_EQ(ErrorCount, 1);
+}
+
+// Only Decl, Stmt and TypeLoc bindings can be rewritten; binding a QualType
+// instead makes the edit fail, which is reported as a single error.
+TEST_F(TransformerTest, RewriteDescendantsInvalidNodeType) {
+  const std::string Input =
+      "int f(int x) { int y = x; { int z = x * x; } return x; }";
+  auto IntToChar =
+      makeRule(qualType(isInteger(), builtinType()), changeTo(cat("char")));
+  auto Fun = functionDecl(
+      hasName("f"), hasParameter(0, varDecl(hasType(qualType().bind("type")))));
+  Transformer T(makeRule(Fun, rewriteDescendants("type", IntToChar)),
+                consumer());
+  T.registerMatchers(&MatchFinder);
+  EXPECT_FALSE(rewrite(Input));
+  EXPECT_EQ(ErrorCount, 1);
+  EXPECT_THAT(Changes, IsEmpty());
+}
 } // namespace
Index: clang/lib/Tooling/Transformer/Transformer.cpp
===================================================================
--- clang/lib/Tooling/Transformer/Transformer.cpp
+++ clang/lib/Tooling/Transformer/Transformer.cpp
@@ -38,6 +38,7 @@
     return;
   }
 
+  // FIXME: some combinators legitimately return no changes.
   if (Transformations->empty()) {
     // No rewrite applied (but no error encountered either).
     transformer::detail::getRuleMatchLoc(Result).print(
Index: clang/lib/Tooling/Transformer/RewriteRule.cpp
===================================================================
--- clang/lib/Tooling/Transformer/RewriteRule.cpp
+++ clang/lib/Tooling/Transformer/RewriteRule.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "clang/Tooling/Transformer/RewriteRule.h"
+#include "clang/AST/ASTTypeTraits.h"
+#include "clang/AST/Stmt.h"
 #include "clang/ASTMatchers/ASTMatchFinder.h"
 #include "clang/ASTMatchers/ASTMatchers.h"
 #include "clang/Basic/SourceLocation.h"
@@ -115,15 +117,153 @@
   return change(std::move(S), std::make_shared<SimpleTextGenerator>(""));
 }
 
-RewriteRule transformer::makeRule(ast_matchers::internal::DynTypedMatcher M,
-                                  EditGenerator Edits,
+RewriteRule transformer::makeRule(DynTypedMatcher M, EditGenerator Edits,
                                   TextGenerator Explanation) {
   return RewriteRule{{RewriteRule::Case{
       std::move(M), std::move(Edits), std::move(Explanation), {}}}};
 }
 
+namespace {
+
+/// Unconditionally binds the given node set before trying `InnerMatcher` and
+/// keeps the bound nodes on a successful match.
+template <typename T>
+class BindingsMatcher : public ast_matchers::internal::MatcherInterface<T> {
+  ast_matchers::BoundNodes Nodes;
+  const ast_matchers::internal::Matcher<T> InnerMatcher;
+
+public:
+  explicit BindingsMatcher(ast_matchers::BoundNodes Nodes,
+                           ast_matchers::internal::Matcher<T> InnerMatcher)
+      : Nodes(std::move(Nodes)), InnerMatcher(std::move(InnerMatcher)) {}
+
+  bool matches(
+      const T &Node, ast_matchers::internal::ASTMatchFinder *Finder,
+      ast_matchers::internal::BoundNodesTreeBuilder *Builder) const override {
+    // Bind into a copy so that `*Builder` is left untouched on a failed match.
+    ast_matchers::internal::BoundNodesTreeBuilder Result(*Builder);
+    for (const auto &N : Nodes.getMap())
+      Result.setBinding(N.first, N.second);
+    if (!InnerMatcher.matches(Node, Finder, &Result))
+      return false;
+    *Builder = std::move(Result);
+    return true;
+  }
+};
+
+/// Matches nodes of type T that have at least one descendant node for which
+/// the given inner matcher matches, once for each matching descendant. Like
+/// ForEachDescendantMatcher, but takes a dynamic matcher instead of a static
+/// one, because RewriteRule carries (only top-level) dynamic matchers.
+template <typename T>
+class DynamicForEachDescendantMatcher
+    : public ast_matchers::internal::MatcherInterface<T> {
+  const DynTypedMatcher DescendantMatcher;
+
+public:
+  explicit DynamicForEachDescendantMatcher(DynTypedMatcher DescendantMatcher)
+      : DescendantMatcher(std::move(DescendantMatcher)) {}
+
+  bool matches(
+      const T &Node, ast_matchers::internal::ASTMatchFinder *Finder,
+      ast_matchers::internal::BoundNodesTreeBuilder *Builder) const override {
+    return Finder->matchesDescendantOf(
+        Node, this->DescendantMatcher, Builder,
+        ast_matchers::internal::ASTMatchFinder::BK_All);
+  }
+};
+
+/// Seeds each match with `Nodes`, then matches `M` against each descendant.
+template <typename T>
+ast_matchers::internal::Matcher<T>
+forEachDescendantDynamically(ast_matchers::BoundNodes Nodes,
+                             DynTypedMatcher M) {
+  return ast_matchers::internal::makeMatcher(new BindingsMatcher<T>(
+      std::move(Nodes),
+      ast_matchers::internal::makeMatcher(
+          new DynamicForEachDescendantMatcher<T>(std::move(M)))));
+}
+
+/// Match callback that applies `Rule` and accumulates the resulting edits.
+class ApplyRuleCallback : public MatchFinder::MatchCallback {
+public:
+  explicit ApplyRuleCallback(RewriteRule Rule) : Rule(std::move(Rule)) {}
+  /// Registers `Rule`'s matchers, each seeded with its own copy of `Nodes`.
+  template <typename T>
+  void registerMatchers(const ast_matchers::BoundNodes &Nodes,
+                        MatchFinder *MF) {
+    for (auto &Matcher : transformer::detail::buildMatchers(Rule))
+      MF->addMatcher(forEachDescendantDynamically<T>(Nodes, Matcher), this);
+  }
+
+  void run(const MatchFinder::MatchResult &Result) override {
+    if (!Edits) // Keep the first error and ignore any later matches.
+      return;
+    const transformer::RewriteRule::Case &Case =
+        transformer::detail::findSelectedCase(Result, Rule);
+    auto Transformations = Case.Edits(Result);
+    if (!Transformations) {
+      Edits = Transformations.takeError();
+      return;
+    }
+    // FIXME: some combinators legitimately return no changes.
+    if (Transformations->empty()) {
+      // No rewrite applied (but no error encountered either).
+      transformer::detail::getRuleMatchLoc(Result).print(
+          llvm::errs() << "note: skipping match at loc ",
+          *Result.SourceManager);
+      llvm::errs() << "\n";
+      return;
+    }
+    Edits->append(Transformations->begin(), Transformations->end());
+  }
+
+  RewriteRule Rule;
+  // Edits accumulated so far, or the first error; starts as a non-error.
+  Expected<SmallVector<Edit, 1>> Edits = SmallVector<Edit, 1>();
+};
+} // namespace
+
+/// Runs `Rule` on the descendants of `Node`, seeded with `Result`'s bindings.
+template <typename T>
+static llvm::Expected<SmallVector<Edit, 1>> rewriteDescendantsImpl(
+    const T &Node, RewriteRule Rule, const MatchResult &Result) {
+  ApplyRuleCallback Callback(std::move(Rule));
+  MatchFinder Finder;
+  Callback.registerMatchers<T>(Result.Nodes, &Finder);
+  Finder.match(Node, *Result.Context);
+  return std::move(Callback.Edits); // A member, so it must be moved out.
+}
+
+EditGenerator transformer::rewriteDescendants(std::string NodeId,
+                                              RewriteRule Rule) {
+  // FIXME: warn or return error if `Rule` contains any `AddedIncludes`, since
+  // these will be dropped.
+  return [NodeId = std::move(NodeId),
+          Rule = std::move(Rule)](const MatchResult &Result)
+             -> llvm::Expected<SmallVector<clang::transformer::Edit, 1>> {
+    const ast_matchers::BoundNodes::IDToNodeMap &NodesMap =
+        Result.Nodes.getMap();
+    auto It = NodesMap.find(NodeId);
+    if (It == NodesMap.end())
+      return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
+                                                 "ID not bound: " + NodeId);
+    // `Rule` is const in this non-mutable lambda, so each call copies it.
+    if (const auto *Node = It->second.get<Decl>())
+      return rewriteDescendantsImpl(*Node, Rule, Result);
+    if (const auto *Node = It->second.get<Stmt>())
+      return rewriteDescendantsImpl(*Node, Rule, Result);
+    if (const auto *Node = It->second.get<TypeLoc>())
+      return rewriteDescendantsImpl(*Node, Rule, Result);
+    return llvm::make_error<llvm::StringError>(
+        llvm::errc::invalid_argument,
+        "type unsupported for recursive rewriting, ID=\"" + NodeId +
+            "\", Kind=" + It->second.getNodeKind().asStringRef());
+  };
+}
+
 void transformer::addInclude(RewriteRule &Rule, StringRef Header,
-                         IncludeFormat Format) {
+                             IncludeFormat Format) {
   for (auto &Case : Rule.Cases)
     Case.AddedIncludes.emplace_back(Header.str(), Format);
 }
Index: clang/include/clang/Tooling/Transformer/RewriteRule.h
===================================================================
--- clang/include/clang/Tooling/Transformer/RewriteRule.h
+++ clang/include/clang/Tooling/Transformer/RewriteRule.h
@@ -332,6 +332,11 @@
                    remove(enclose(after(inner), after(outer)))});
 }
 
+/// Applies `Rule` to all descendants (not the node itself) of the Decl, Stmt
+/// or TypeLoc bound to `NodeId`; other or unbound nodes yield an error. `Rule`
+/// can refer to nodes bound by the calling rule.
+EditGenerator rewriteDescendants(std::string NodeId, RewriteRule Rule);
+
 /// The following three functions are a low-level part of the RewriteRule
 /// API. We expose them for use in implementing the fixtures that interpret
 /// RewriteRule, like Transformer and TransfomerTidy, or for more advanced
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to