johannes updated this revision to Diff 112845.
johannes retitled this revision from "Add 
include/clang/Tooling/ASTDiff/ASTPatch.h" to "[clang-diff] Initial 
implementation of patching".
johannes edited the summary of this revision.
johannes added a comment.

Use clang's Rewriter to patch a third AST.


https://reviews.llvm.org/D37005

Files:
  include/clang/Tooling/ASTDiff/ASTDiff.h
  lib/Tooling/ASTDiff/ASTDiff.cpp
  lib/Tooling/ASTDiff/CMakeLists.txt
  tools/clang-diff/CMakeLists.txt
  tools/clang-diff/ClangDiff.cpp
  unittests/Tooling/ASTDiffTest.cpp
  unittests/Tooling/CMakeLists.txt

Index: unittests/Tooling/CMakeLists.txt
===================================================================
--- unittests/Tooling/CMakeLists.txt
+++ unittests/Tooling/CMakeLists.txt
@@ -11,6 +11,7 @@
 endif()
 
 add_clang_unittest(ToolingTests
+  ASTDiffTest.cpp
   ASTSelectionTest.cpp
   CastExprTest.cpp
   CommentHandlerTest.cpp
@@ -43,4 +44,5 @@
   clangTooling
   clangToolingCore
   clangToolingRefactor
+  clangToolingASTDiff
   )
Index: unittests/Tooling/ASTDiffTest.cpp
===================================================================
--- /dev/null
+++ unittests/Tooling/ASTDiffTest.cpp
@@ -0,0 +1,85 @@
+//===- unittest/Tooling/ASTDiffTest.cpp -----------------------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Tooling/ASTDiff/ASTDiff.h"
+#include "clang/Tooling/Tooling.h"
+#include "gtest/gtest.h"
+
+using namespace clang;
+using namespace tooling;
+
+static std::string patchResult(const std::array<std::string, 3> &Codes) {
+  // The ASTs must outlive the SyntaxTrees, which keep references into them.
+  std::unique_ptr<ASTUnit> ASTs[3];
+  std::vector<diff::SyntaxTree> Trees; // Built in place, never moved.
+  Trees.reserve(3); // No reallocation, so element addresses stay stable.
+  for (int I = 0; I < 3; ++I) {
+    ASTs[I] = buildASTFromCode(Codes[I]);
+    if (!ASTs[I]) {
+      llvm::errs() << "Failed to build AST from code:\n" << Codes[I] << "\n";
+      return "";
+    }
+    Trees.emplace_back(*ASTs[I]);
+  }
+  diff::ComparisonOptions Options;
+  std::string TargetDstCode;
+  llvm::raw_string_ostream OS(TargetDstCode);
+  if (!diff::patch(/*ModelSrc=*/Trees[0], /*ModelDst=*/Trees[1],
+                   /*TargetSrc=*/Trees[2], Options, OS))
+    return "";
+  return OS.str();
+}
+
+// Wrap EXPECT_EQ so that the code snippets in each test line up. This is a
+// macro rather than a function so that failures report the caller's line.
+#define PATCH(Preamble, ModelSrc, ModelDst, Target, Expected)                  \
+  EXPECT_EQ(patchResult({{std::string(Preamble) + (ModelSrc),                  \
+                          std::string(Preamble) + (ModelDst),                  \
+                          std::string(Preamble) + (Target)}}),                 \
+            std::string(Preamble) + (Expected))
+
+TEST(ASTDiff, TestDeleteArguments) { // Deleting call arguments in Target.
+  PATCH(R"(void printf(const char *, ...);)", // Drop the last argument.
+        R"(void foo(int x) { printf("%d", x, x); })",
+        R"(void foo(int x) { printf("%d", x); })",
+        R"(void foo(int x) { printf("different string %d", x, x); })",
+        R"(void foo(int x) { printf("different string %d", x); })");
+
+  PATCH(R"(void foo(...);)", // Only argument: there is no comma to remove.
+        R"(void test1() { foo ( 1 + 1); })",
+        R"(void test1() { foo ( ); })",
+        R"(void test2() { foo ( 1 + 1 ); })",
+        R"(void test2() { foo (  ); })");
+
+  PATCH(R"(void foo(...);)", // First of two: the comma after it goes too.
+        R"(void test1() { foo (1, 2 + 2); })",
+        R"(void test1() { foo (2 + 2); })",
+        R"(void test2() { foo (/*L*/ 0 /*R*/ , 2 + 2); })",
+        R"(void test2() { foo (/*L*/  2 + 2); })");
+
+  PATCH(R"(void foo(...);)", // Last of two: the comma before it goes too.
+        R"(void test1() { foo (1, 2); })",
+        R"(void test1() { foo (1); })",
+        R"(void test2() { foo (0, /*L*/ 0 /*R*/); })",
+        R"(void test2() { foo (0 /*R*/); })");
+}
+
+TEST(ASTDiff, TestDeleteDecls) { // Deleting top-level declarations.
+  PATCH(R"()", // Empty inputs give empty output.
+        R"()",
+        R"()",
+        R"()",
+        R"()");
+
+  PATCH(R"()", // Deleting a function definition keeps its neighbors.
+        R"(void foo(){})",
+        R"()",
+        R"(int x; void foo() {;;} int y;)",
+        R"(int x;  int y;)");
+}
Index: tools/clang-diff/ClangDiff.cpp
===================================================================
--- tools/clang-diff/ClangDiff.cpp
+++ tools/clang-diff/ClangDiff.cpp
@@ -42,6 +42,12 @@
                               cl::desc("Output a side-by-side diff in HTML."),
                               cl::init(false), cl::cat(ClangDiffCategory));
 
+static cl::opt<std::string> // Path of the file the model edits are applied to.
+    Patch("patch",
+          cl::desc("Try to apply the edit actions between the two input "
+                   "files to the specified target."),
+          cl::value_desc("target"), cl::cat(ClangDiffCategory));
+
 static cl::opt<std::string> SourcePath(cl::Positional, cl::desc("<source>"),
                                        cl::Required,
                                        cl::cat(ClangDiffCategory));
@@ -563,6 +569,16 @@
   }
   diff::SyntaxTree SrcTree(*Src);
   diff::SyntaxTree DstTree(*Dst);
+
+  if (!Patch.empty()) { // -patch: rewrite the target instead of diffing.
+    auto Target = getAST(CommonCompilations, Patch);
+    if (!Target)
+      return 1;
+    diff::SyntaxTree TargetTree(*Target);
+    bool Ok = diff::patch(SrcTree, DstTree, TargetTree, Options, llvm::outs());
+    return Ok ? 0 : 1; // Report a failed patch through the exit status.
+  }
+
   diff::ASTDiff Diff(SrcTree, DstTree, Options);
 
   if (HtmlDiff) {
Index: tools/clang-diff/CMakeLists.txt
===================================================================
--- tools/clang-diff/CMakeLists.txt
+++ tools/clang-diff/CMakeLists.txt
@@ -9,6 +9,7 @@
 target_link_libraries(clang-diff
   clangBasic
   clangFrontend
+  clangRewrite
   clangTooling
   clangToolingASTDiff
   )
Index: lib/Tooling/ASTDiff/CMakeLists.txt
===================================================================
--- lib/Tooling/ASTDiff/CMakeLists.txt
+++ lib/Tooling/ASTDiff/CMakeLists.txt
@@ -8,4 +8,6 @@
   clangBasic
   clangAST
   clangLex
+  clangRewrite
+  clangToolingCore
   )
Index: lib/Tooling/ASTDiff/ASTDiff.cpp
===================================================================
--- lib/Tooling/ASTDiff/ASTDiff.cpp
+++ lib/Tooling/ASTDiff/ASTDiff.cpp
@@ -18,6 +18,8 @@
 #include "clang/AST/LexicallyOrderedRecursiveASTVisitor.h"
 #include "clang/AST/StmtVisitor.h"
 #include "clang/Lex/Lexer.h"
+#include "clang/Rewrite/Core/Rewriter.h"
+#include "clang/Tooling/Core/Replacement.h"
 #include "llvm/ADT/PriorityQueue.h"
 #include "llvm/Support/MD5.h"
 
@@ -27,6 +29,7 @@
 
 using namespace llvm;
 using namespace clang;
+using namespace tooling;
 
 namespace clang {
 namespace diff {
@@ -139,6 +142,7 @@
        typename std::enable_if<std::is_base_of<Decl, T>::value, T>::type *Node,
        ASTUnit &AST)
       : Impl(Parent, dyn_cast<Decl>(Node), AST) {}
+  explicit Impl(SyntaxTree *Parent, const Impl &Other);
 
   SyntaxTree *Parent;
   ASTUnit &AST;
@@ -175,6 +179,8 @@
 
   HashType hashNode(const Node &N) const;
 
+  SourceRange getSourceRange(const Node &N) const;
+
 private:
   void initTree();
   void setLeftMostDescendants();
@@ -337,6 +343,15 @@
   initTree();
 }
 
+SyntaxTree::Impl::Impl(SyntaxTree *Parent, const Impl &Other)
+    : Impl(Parent, Other.AST) { // Same ASTUnit; Parent is the new owner.
+  Nodes = Other.Nodes;
+  Leaves = Other.Leaves;
+  PostorderIds = Other.PostorderIds;
+  NodesBfs = Other.NodesBfs;
+  TemplateArgumentLocations = Other.TemplateArgumentLocations;
+}
+
 static std::vector<NodeId> getSubtreePostorder(const SyntaxTree::Impl &Tree,
                                                NodeId Root) {
   std::vector<NodeId> Postorder;
@@ -638,6 +653,26 @@
   return HashResult;
 }
 
+SourceRange SyntaxTree::Impl::getSourceRange(const Node &N) const {
+  // Template arguments use the range recorded in TemplateArgumentLocations.
+  if (N.ASTNode.get<TemplateArgument>())
+    return getSourceExtent(AST, TemplateArgumentLocations.at(&N - &Nodes[0]));
+  SourceRange Range = N.ASTNode.getSourceRange();
+  if (auto *ThisExpr = N.ASTNode.get<CXXThisExpr>())
+    if (ThisExpr->isImplicit()) // Collapse the range of an implicit `this`.
+      Range.setEnd(Range.getBegin());
+  // If it is a CXXConstructExpr that is not a temporary, then there is
+  // probably an identifier of an initialization that is included in the
+  // range. This identifier belongs to the parent node, so stick to the
+  // ctor arguments only. Default initialization (`T x;`) may have no paren
+  // range at all; keep the full range then instead of an invalid one.
+  if (auto *CE = N.ASTNode.get<CXXConstructExpr>())
+    if (!isa<CXXTemporaryObjectExpr>(CE) &&
+        CE->getParenOrBraceRange().isValid())
+      Range = CE->getParenOrBraceRange();
+  return getSourceExtent(AST, Range);
+}
+
 /// Identifies a node in a subtree by its postorder offset, starting at 1.
 struct SNodeId {
   int Id = 0;
@@ -1210,10 +1245,19 @@
   return DiffImpl->getMapped(SourceTree.TreeImpl, Id);
 }
 
+SyntaxTree::SyntaxTree() : TreeImpl(nullptr) {}
+SyntaxTree::SyntaxTree(SyntaxTree &&Other) { *this = std::move(Other); }
 SyntaxTree::SyntaxTree(ASTUnit &AST)
     : TreeImpl(llvm::make_unique<SyntaxTree::Impl>(
           this, AST.getASTContext().getTranslationUnitDecl(), AST)) {}
 
+SyntaxTree &SyntaxTree::operator=(SyntaxTree &&Other) { // Keep Parent valid.
+  if ((TreeImpl = std::move(Other.TreeImpl)))
+    TreeImpl->Parent = this;
+  return *this;
+}
+SyntaxTree::SyntaxTree(const SyntaxTree &Other)
+    : TreeImpl(llvm::make_unique<SyntaxTree::Impl>(this, *Other.TreeImpl)) {}
 SyntaxTree::~SyntaxTree() = default;
 
 ASTUnit &SyntaxTree::getASTUnit() const { return TreeImpl->AST; }
@@ -1237,19 +1281,14 @@
   return TreeImpl->findPositionInParent(Id);
 }
 
+SourceRange SyntaxTree::getSourceRange(const Node &TheNode) const {
+  return this->TreeImpl->getSourceRange(TheNode); // Forward to the pimpl.
+}
+
 std::pair<unsigned, unsigned>
 SyntaxTree::getSourceRangeOffsets(const Node &N) const {
   const SourceManager &SrcMgr = TreeImpl->AST.getSourceManager();
-  SourceRange Range;
-  if (auto *Arg = N.ASTNode.get<TemplateArgument>())
-    Range = TreeImpl->TemplateArgumentLocations.at(&N - &TreeImpl->Nodes[0]);
-  else {
-    Range = N.ASTNode.getSourceRange();
-    if (auto *ThisExpr = N.ASTNode.get<CXXThisExpr>())
-      if (ThisExpr->isImplicit())
-        Range.setEnd(Range.getBegin());
-  }
-  Range = getSourceExtent(TreeImpl->AST, Range);
+  SourceRange Range = TreeImpl->getSourceRange(N); // Same as getSourceRange().
   unsigned Begin = SrcMgr.getFileOffset(Range.getBegin());
   unsigned End = SrcMgr.getFileOffset(Range.getEnd());
   return {Begin, End};
@@ -1263,5 +1302,89 @@
   return TreeImpl->getNodeValue(N);
 }
 
+struct Patcher {
+  SyntaxTree::Impl &ModelSrc, &ModelDst, &Target;
+  const ComparisonOptions &Options;
+  raw_ostream &OS;
+  SourceManager &SrcMgr;
+  const LangOptions &LangOpts;
+  Replacements Replaces;
+  SyntaxTree ModelSrcCopy; // ModelTargetDiff gets its own copy of ModelSrc.
+  ASTDiff ModelDiff, ModelTargetDiff;
+
+  Patcher(SyntaxTree &ModelSrc, SyntaxTree &ModelDst, SyntaxTree &Target,
+          const ComparisonOptions &Options, raw_ostream &OS)
+      : ModelSrc(*ModelSrc.TreeImpl), ModelDst(*ModelDst.TreeImpl),
+        Target(*Target.TreeImpl), Options(Options), OS(OS),
+        SrcMgr(this->Target.AST.getSourceManager()),
+        LangOpts(this->Target.AST.getLangOpts()), ModelSrcCopy(ModelSrc),
+        ModelDiff(ModelSrc, ModelDst, Options),
+        ModelTargetDiff(ModelSrcCopy, Target, Options) {}
+
+  bool apply() { // On success, prints Target's rewritten main file to OS.
+    addDeletions();
+    Rewriter Rewrite(SrcMgr, LangOpts);
+    if (!applyAllReplacements(Replaces, Rewrite)) {
+      llvm::errs() << "failed to apply replacements\n";
+      return false;
+    }
+    Rewrite.getEditBuffer(SrcMgr.getMainFileID()).write(OS);
+    return true;
+  }
+
+private:
+  void addDeletions() { // Delete Target nodes mapped to deleted model nodes.
+    for (NodeId Id = ModelSrc.getRootId(), E = ModelSrc.getSize(); Id < E;
+         ++Id) {
+      const Node &ModelNode = ModelSrc.getNode(Id);
+      if (ModelNode.Change != Delete)
+        continue;
+      NodeId TargetId = ModelTargetDiff.getMapped(ModelSrcCopy, Id);
+      if (TargetId.isInvalid())
+        continue;
+      Replacement R(SrcMgr, findRangeForDeletion(TargetId), "", LangOpts);
+      if (llvm::Error Err = Replaces.add(R)) // An Error must be consumed.
+        llvm::errs() << "Info: Failed to add replacement: "
+                     << llvm::toString(std::move(Err)) << "\n";
+      Id = ModelNode.RightMostDescendant; // Skip this node's descendants.
+    }
+  }
+
+  CharSourceRange findRangeForDeletion(NodeId Id) {
+    const Node &N = Target.getNode(Id);
+    SourceRange Range = Target.getSourceRange(N);
+    if (N.Parent.isInvalid())
+      return {Range, false};
+    const Node &Parent = Target.getNode(N.Parent);
+    size_t SiblingIndex = Target.findPositionInParent(Id);
+    const auto &Siblings = Parent.Children;
+    // Remove the comma if the location is within a comma-separated list of at
+    // least size 2 (minus the callee for CallExpr).
+    if (Parent.ASTNode.get<CallExpr>() && Siblings.size() > 2) {
+      bool LastSibling = SiblingIndex == Siblings.size() - 1;
+      // The last argument has no comma after it, so use the one after the
+      // previous argument. Ranges end one past their last token; lexing from
+      // that last character also finds a comma directly after it, e.g. f(0, 1).
+      NodeId AnchorId = LastSibling ? Siblings[SiblingIndex - 1] : Id;
+      SourceRange AnchorRange = Target.getSourceRange(Target.getNode(AnchorId));
+      SourceLocation CommaLoc = Lexer::findLocationAfterToken(
+          AnchorRange.getEnd().getLocWithOffset(-1), tok::comma, SrcMgr,
+          LangOpts, /*SkipTrailingWhitespaceAndNewLine=*/false);
+      if (CommaLoc.isValid()) { // Otherwise delete only the node itself.
+        if (LastSibling)
+          Range.setBegin(CommaLoc.getLocWithOffset(-1));
+        else
+          Range.setEnd(CommaLoc);
+      }
+    }
+    return {Range, false};
+  }
+};
+
+bool patch(SyntaxTree &ModelSrc, SyntaxTree &ModelDst, SyntaxTree &TargetSrc,
+           const ComparisonOptions &Options, raw_ostream &OS) {
+  return Patcher{ModelSrc, ModelDst, TargetSrc, Options, OS}.apply();
+}
+
 } // end namespace diff
 } // end namespace clang
Index: include/clang/Tooling/ASTDiff/ASTDiff.h
===================================================================
--- include/clang/Tooling/ASTDiff/ASTDiff.h
+++ include/clang/Tooling/ASTDiff/ASTDiff.h
@@ -21,6 +21,7 @@
 #define LLVM_CLANG_TOOLING_ASTDIFF_ASTDIFF_H
 
 #include "clang/Frontend/ASTUnit.h"
+#include "clang/Rewrite/Core/Rewriter.h"
 #include "clang/Tooling/ASTDiff/ASTDiffInternal.h"
 
 namespace clang {
@@ -51,6 +52,9 @@
   llvm::Optional<std::string> getQualifiedIdentifier() const;
 };
 
+/// Applies ModelSrc->ModelDst deletions to TargetSrc, writing the result to OS.
+bool patch(SyntaxTree &ModelSrc, SyntaxTree &ModelDst, SyntaxTree &TargetSrc,
+           const ComparisonOptions &Options, raw_ostream &OS);
 class ASTDiff {
 public:
   ASTDiff(SyntaxTree &Src, SyntaxTree &Dst, const ComparisonOptions &Options);
@@ -69,13 +73,17 @@
 /// They can be constructed from any Decl or Stmt.
 class SyntaxTree {
 public:
+  /// Empty (invalid) SyntaxTree.
+  SyntaxTree();
   /// Constructs a tree from a translation unit.
   SyntaxTree(ASTUnit &AST);
   /// Constructs a tree from any AST node.
   template <class T>
   SyntaxTree(T *Node, ASTUnit &AST)
       : TreeImpl(llvm::make_unique<Impl>(this, Node, AST)) {}
-  SyntaxTree(SyntaxTree &&Other) = default;
+  SyntaxTree(SyntaxTree &&Other);
+  SyntaxTree &operator=(SyntaxTree &&Other);
+  explicit SyntaxTree(const SyntaxTree &Other);
   ~SyntaxTree();
 
   ASTUnit &getASTUnit() const;
@@ -93,7 +101,7 @@
 
   /// Returns the range that contains the text that is associated with this
   /// node.
-  /* SourceRange getSourceRange(const Node &N) const; */
+  SourceRange getSourceRange(const Node &N) const;
   /// Returns the offsets for the range returned by getSourceRange.
   std::pair<unsigned, unsigned> getSourceRangeOffsets(const Node &N) const;
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to