junaire updated this revision to Diff 435782.
junaire added a comment.

Allow the undo command to fail gracefully in edge cases, such as undoing too many times.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D126682/new/

https://reviews.llvm.org/D126682

Files:
  clang/include/clang/Interpreter/Interpreter.h
  clang/lib/Interpreter/IncrementalExecutor.cpp
  clang/lib/Interpreter/IncrementalExecutor.h
  clang/lib/Interpreter/IncrementalParser.cpp
  clang/lib/Interpreter/IncrementalParser.h
  clang/lib/Interpreter/Interpreter.cpp
  clang/tools/clang-repl/ClangRepl.cpp
  clang/unittests/Interpreter/InterpreterTest.cpp

Index: clang/unittests/Interpreter/InterpreterTest.cpp
===================================================================
--- clang/unittests/Interpreter/InterpreterTest.cpp
+++ clang/unittests/Interpreter/InterpreterTest.cpp
@@ -248,4 +248,38 @@
   EXPECT_EQ(42, fn(NewA));
 }
 
+TEST(InterpreterTest, UndoBasic) {
+  Args ExtraArgs = {"-Xclang", "-diagnostic-log-file", "-Xclang", "-"};
+  // Exercises Interpreter::Undo() interleaved with further Parse() calls.
+  // Create the diagnostic engine with unowned consumer.
+  std::string DiagnosticOutput;
+  llvm::raw_string_ostream DiagnosticsOS(DiagnosticOutput);
+  auto DiagPrinter = std::make_unique<TextDiagnosticPrinter>(
+      DiagnosticsOS, new DiagnosticOptions());
+
+  auto Interp = createInterpreter(ExtraArgs, DiagPrinter.get());
+
+  auto R1 = Interp->Parse("int x = 42;");
+  EXPECT_TRUE(!!R1);
+  // Drop the PTU that declared x so that x can be declared afresh below.
+  llvm::cantFail(Interp->Undo());
+
+  auto R2 = Interp->Parse("int x = 24;");
+  EXPECT_TRUE(!!R2);
+
+  auto R3 = Interp->Parse("int foo() { return 42;}");
+  EXPECT_TRUE(!!R3);
+
+  auto R4 = Interp->Parse("int bar = foo();");
+  EXPECT_TRUE(!!R4);
+  // Default N = 1: only the PTU declaring bar is undone; foo() remains.
+  llvm::cantFail(Interp->Undo());
+  // NOTE(review): x from R2 was not undone; confirm redeclaring it here parses.
+  auto R5 = Interp->Parse("int x = 24;");
+  EXPECT_TRUE(!!R5);
+
+  auto R6 = Interp->Parse("int baz = foo();");
+  EXPECT_TRUE(!!R6);
+}
+
 } // end anonymous namespace
Index: clang/tools/clang-repl/ClangRepl.cpp
===================================================================
--- clang/tools/clang-repl/ClangRepl.cpp
+++ clang/tools/clang-repl/ClangRepl.cpp
@@ -92,8 +92,14 @@
     llvm::LineEditor LE("clang-repl");
     // FIXME: Add LE.setListCompleter
     while (llvm::Optional<std::string> Line = LE.readLine()) {
-      if (*Line == "quit")
+      if (*Line == R"(%quit)")
         break;
+      if (*Line == R"(%undo)") {
+        if (auto Err = Interp->Undo())
+          llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(), "error: ");
+        continue;
+      }
+
       if (auto Err = Interp->ParseAndExecute(*Line))
         llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(), "error: ");
     }
Index: clang/lib/Interpreter/Interpreter.cpp
===================================================================
--- clang/lib/Interpreter/Interpreter.cpp
+++ clang/lib/Interpreter/Interpreter.cpp
@@ -218,8 +218,7 @@
     if (Err)
       return Err;
   }
-  // FIXME: Add a callback to retain the llvm::Module once the JIT is done.
-  if (auto Err = IncrExecutor->addModule(std::move(T.TheModule)))
+  if (auto Err = IncrExecutor->addModule(T))
     return Err;
 
   if (auto Err = IncrExecutor->runCtors())
@@ -228,6 +227,10 @@
   return llvm::Error::success();
 }
 
+void Interpreter::Restore(PartialTranslationUnit &PTU) {
+  IncrParser->Restore(PTU); // Pure delegation to the incremental parser.
+}
+
 llvm::Expected<llvm::JITTargetAddress>
 Interpreter::getSymbolAddress(GlobalDecl GD) const {
   if (!IncrExecutor)
@@ -257,3 +260,18 @@
 
   return IncrExecutor->getSymbolAddress(Name, IncrementalExecutor::LinkerName);
 }
+
+llvm::Error Interpreter::Undo(unsigned N) {
+  auto &Units = IncrParser->getPTUs();
+  if (N > Units.size()) // Reject before any state has been touched.
+    return llvm::make_error<llvm::StringError>(
+        "Operation failed, too many undos", std::error_code());
+  for (; N != 0; --N) {
+    PartialTranslationUnit &Newest = Units.back();
+    if (llvm::Error Err = IncrExecutor->removeModule(Newest))
+      return Err;
+    Restore(Newest); // Module is gone; now drop the AST lookup entries.
+    Units.pop_back();
+  }
+  return llvm::Error::success();
+}
Index: clang/lib/Interpreter/IncrementalParser.h
===================================================================
--- clang/lib/Interpreter/IncrementalParser.h
+++ clang/lib/Interpreter/IncrementalParser.h
@@ -72,6 +72,10 @@
   ///\returns the mangled name of a \c GD.
   llvm::StringRef GetMangledName(GlobalDecl GD) const;
 
+  /// Removes the declarations that \p PTU introduced from the lookup table.
+  void Restore(PartialTranslationUnit &PTU);
+  std::list<PartialTranslationUnit> &getPTUs() { return PTUs; } // Mutable.
+
 private:
   llvm::Expected<PartialTranslationUnit &> ParseOrWrapTopLevelDecl();
 };
Index: clang/lib/Interpreter/IncrementalParser.cpp
===================================================================
--- clang/lib/Interpreter/IncrementalParser.cpp
+++ clang/lib/Interpreter/IncrementalParser.cpp
@@ -293,6 +293,24 @@
   return PTU;
 }
 
+void IncrementalParser::Restore(PartialTranslationUnit &PTU) {
+  TranslationUnitDecl *MostRecentTU = PTU.TUPart; // TU part being undone.
+  TranslationUnitDecl *FirstTU = MostRecentTU->getFirstDecl();
+  if (StoredDeclsMap *Map = FirstTU->getLookupPtr()) { // Nothing to do if null.
+    for (auto I = Map->begin(); I != Map->end(); ++I) {
+      StoredDeclsList &List = I->second;
+      DeclContextLookupResult R = List.getLookupResult();
+      for (NamedDecl *D : R) {
+        if (D->getTranslationUnitDecl() == MostRecentTU) { // Belongs to PTU.
+          List.remove(D);
+        }
+      }
+      if (List.isNull()) // Drop map entries whose list became null.
+        Map->erase(I);
+    }
+  }
+}
+
 llvm::StringRef IncrementalParser::GetMangledName(GlobalDecl GD) const {
   CodeGenerator *CG = getCodeGen(Act.get());
   assert(CG);
Index: clang/lib/Interpreter/IncrementalExecutor.h
===================================================================
--- clang/lib/Interpreter/IncrementalExecutor.h
+++ clang/lib/Interpreter/IncrementalExecutor.h
@@ -13,6 +13,7 @@
 #ifndef LLVM_CLANG_LIB_INTERPRETER_INCREMENTALEXECUTOR_H
 #define LLVM_CLANG_LIB_INTERPRETER_INCREMENTALEXECUTOR_H
 
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Triple.h"
 #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
@@ -29,11 +30,17 @@
 } // namespace llvm
 
 namespace clang {
+
+struct PartialTranslationUnit;
+
 class IncrementalExecutor {
   using CtorDtorIterator = llvm::orc::CtorDtorIterator;
   std::unique_ptr<llvm::orc::LLJIT> Jit;
   llvm::orc::ThreadSafeContext &TSCtx;
 
+  llvm::DenseMap<const PartialTranslationUnit *, llvm::orc::ResourceTrackerSP>
+      ResourceTrackers;
+
 public:
   enum SymbolNameKind { IRName, LinkerName };
 
@@ -41,7 +48,8 @@
                       const llvm::Triple &Triple);
   ~IncrementalExecutor();
 
-  llvm::Error addModule(std::unique_ptr<llvm::Module> M);
+  llvm::Error addModule(PartialTranslationUnit &PTU);
+  llvm::Error removeModule(PartialTranslationUnit &PTU);
   llvm::Error runCtors() const;
   llvm::Expected<llvm::JITTargetAddress>
   getSymbolAddress(llvm::StringRef Name, SymbolNameKind NameKind) const;
Index: clang/lib/Interpreter/IncrementalExecutor.cpp
===================================================================
--- clang/lib/Interpreter/IncrementalExecutor.cpp
+++ clang/lib/Interpreter/IncrementalExecutor.cpp
@@ -12,6 +12,7 @@
 
 #include "IncrementalExecutor.h"
 
+#include "clang/Interpreter/PartialTranslationUnit.h"
 #include "llvm/ExecutionEngine/ExecutionEngine.h"
 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
 #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
@@ -52,8 +53,24 @@
 
 IncrementalExecutor::~IncrementalExecutor() {}
 
-llvm::Error IncrementalExecutor::addModule(std::unique_ptr<llvm::Module> M) {
-  return Jit->addIRModule(llvm::orc::ThreadSafeModule(std::move(M), TSCtx));
+llvm::Error IncrementalExecutor::addModule(PartialTranslationUnit &PTU) {
+  llvm::orc::ResourceTrackerSP RT =
+      Jit->getMainJITDylib().createResourceTracker();
+  // Key the tracker by PTU so that removeModule() can find it again.
+  ResourceTrackers[&PTU] = RT;
+  return Jit->addIRModule(RT, {std::move(PTU.TheModule), TSCtx});
+}
+
+llvm::Error IncrementalExecutor::removeModule(PartialTranslationUnit &PTU) {
+  // Drops the JIT resources recorded for PTU. Uses find() rather than
+  // operator[] so a PTU that never had a module added is a no-op instead
+  // of leaving a null entry behind in the map.
+  auto It = ResourceTrackers.find(&PTU);
+  if (It == ResourceTrackers.end())
+    return llvm::Error::success();
+  llvm::orc::ResourceTrackerSP RT = std::move(It->second);
+  ResourceTrackers.erase(It);
+  return RT ? RT->remove() : llvm::Error::success();
 }
 
 llvm::Error IncrementalExecutor::runCtors() const {
Index: clang/include/clang/Interpreter/Interpreter.h
===================================================================
--- clang/include/clang/Interpreter/Interpreter.h
+++ clang/include/clang/Interpreter/Interpreter.h
@@ -69,6 +69,12 @@
     return llvm::Error::success();
   }
 
+  /// Removes the declarations introduced by \p PTU from name lookup.
+  void Restore(PartialTranslationUnit &PTU);
+  /// Undo the N most recent parse results, newest first. Fails up front if N
+  /// exceeds the number of results; stops at the first error while undoing.
+  llvm::Error Undo(unsigned N = 1);
+
   /// \returns the \c JITTargetAddress of a \c GlobalDecl. This interface uses
   /// the CodeGenModule's internal mangling cache to avoid recomputing the
   /// mangled name.
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to