capfredf created this revision.
Herald added subscribers: ChuanqiXu, kadircet, arphaman.
Herald added a project: All.
capfredf requested review of this revision.
Herald added projects: clang, clang-tools-extra.
Herald added a subscriber: cfe-commits.

This patch enables code completion for ClangREPL. The feature is built upon
three existing Clang components: a list completer for `LineEditor`, a
`CodeCompleteConsumer` from `SemaCodeCompletion`, and the `ASTUnit::CodeComplete`
method, with the first component serving as the main entry point for handling
interactive input.

Because a compiler instance's completion point cannot be changed once it is
set, a new incremental compiler instance is created for each code completion.
That compiler instance uses an external AST source backed by the main
interpreter's AST context, so it can see declarations and bindings from
previous inputs in the same REPL session.

The most important API, `codeComplete` in `Interpreter/CodeCompletion`, is a
thin wrapper that calls `ASTUnit::CodeComplete` with the necessary arguments,
such as a code completion point and a `ReplCompletionConsumer`, which
communicates completion results from `SemaCodeCompletion` back to the list
completer for the REPL.

In addition, the `PCC_TopLevelOrExpression` and `CCC_TopLevelOrExpression`
completion contexts were added so that `SemaCodeCompletion` can treat top-level
statements like expression statements at the REPL. For example,

  clang-repl> int foo = 42;
  clang-repl> f<tab>

From a parser's perspective, the cursor is at the top level. If we used code
completion without any changes, `PCC_Namespace` would be supplied to
`Sema::CodeCompleteOrdinaryName`, and thus the completion results would not
include `foo`.

Currently, the way we use `PCC_TopLevelOrExpression` and
`CCC_TopLevelOrExpression` is no different from the way we use `PCC_Statement`
and `CCC_Statement` respectively.

Previous Differential Revision: https://reviews.llvm.org/D154382


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D158629

Files:
  clang-tools-extra/clangd/CodeComplete.cpp
  clang/include/clang/Frontend/ASTUnit.h
  clang/include/clang/Interpreter/CodeCompletion.h
  clang/include/clang/Sema/CodeCompleteConsumer.h
  clang/include/clang/Sema/Sema.h
  clang/lib/Frontend/ASTUnit.cpp
  clang/lib/Interpreter/CMakeLists.txt
  clang/lib/Interpreter/CodeCompletion.cpp
  clang/lib/Interpreter/IncrementalParser.cpp
  clang/lib/Interpreter/IncrementalParser.h
  clang/lib/Interpreter/Interpreter.cpp
  clang/lib/Parse/ParseDecl.cpp
  clang/lib/Parse/Parser.cpp
  clang/lib/Sema/CodeCompleteConsumer.cpp
  clang/lib/Sema/SemaCodeComplete.cpp
  clang/test/CodeCompletion/incrememal-mode-completion-no-error.cpp
  clang/test/CodeCompletion/incremental-top-level.cpp
  clang/tools/clang-repl/ClangRepl.cpp
  clang/tools/libclang/CIndexCodeCompletion.cpp
  clang/unittests/Interpreter/CMakeLists.txt
  clang/unittests/Interpreter/CodeCompletionTest.cpp

Index: clang/unittests/Interpreter/CodeCompletionTest.cpp
===================================================================
--- /dev/null
+++ clang/unittests/Interpreter/CodeCompletionTest.cpp
@@ -0,0 +1,101 @@
+#include "clang/Interpreter/CodeCompletion.h"
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Interpreter/Interpreter.h"
+#include "clang/Sema/CodeCompleteConsumer.h"
+#include "llvm/LineEditor/LineEditor.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace clang;
+namespace {
+auto CB = clang::IncrementalCompilerBuilder();
+
+static std::unique_ptr<Interpreter> createInterpreter() {
+  std::unique_ptr<CompilerInstance> CI = cantFail(CB.CreateCpp());
+  return cantFail(Interpreter::create(std::move(CI)));
+}
+
+static std::vector<std::string> runComp(clang::Interpreter &MainInterp,
+                                        llvm::StringRef Prefix,
+                                        llvm::Error &ErrR) {
+  // Completion runs in a freshly created interpreter; MainInterp is the parent.
+  auto CI = CB.CreateCpp();
+  if (auto Err = CI.takeError()) {
+    ErrR = std::move(Err);
+    return {};
+  }
+
+  auto Interp = clang::Interpreter::create(std::move(*CI));
+  if (auto Err = Interp.takeError()) {
+    // Hand the error to the caller and report no completions.
+    ErrR = std::move(Err);
+    return {};
+  }
+
+  std::vector<clang::CodeCompletionResult> Results;
+  codeComplete(
+      const_cast<clang::CompilerInstance *>((*Interp)->getCompilerInstance()),
+      Prefix, 1, Prefix.size(), MainInterp.getCompilerInstance(), Results);
+
+  // Keep candidates that extend Prefix and return only the missing suffix.
+  std::vector<std::string> Comps;
+  for (const std::string &C : convertToCodeCompleteStrings(Results)) {
+    if (C.find(Prefix) == 0)
+      Comps.push_back(C.substr(Prefix.size()));
+  }
+
+  return Comps;
+}
+
+TEST(CodeCompletionTest, Sanity) {
+  auto Interp = createInterpreter();
+  if (auto R = Interp->ParseAndExecute("int foo = 12;")) {
+    consumeError(std::move(R));
+    return;
+  }
+  auto Err = llvm::Error::success();
+  auto comps = runComp(*Interp, "f", Err);
+  ASSERT_EQ((size_t)2, comps.size()); // foo and float; guards comps[0]
+  EXPECT_EQ(comps[0], std::string("oo"));
+  EXPECT_FALSE(llvm::errorToBool(std::move(Err)));
+}
+
+TEST(CodeCompletionTest, SanityNoneValid) {
+  auto Interp = createInterpreter();
+  if (auto R = Interp->ParseAndExecute("int foo = 12;")) {
+    consumeError(std::move(R));
+    return;
+  }
+  auto Err = llvm::Error::success();
+  auto comps = runComp(*Interp, "babanana", Err);
+  EXPECT_EQ((size_t)0, comps.size()); // nothing starts with "babanana"
+  EXPECT_FALSE(llvm::errorToBool(std::move(Err)));
+}
+
+TEST(CodeCompletionTest, TwoDecls) {
+  auto Interp = createInterpreter();
+  if (auto R = Interp->ParseAndExecute("int application = 12;")) {
+    consumeError(std::move(R));
+    return;
+  }
+  if (auto R = Interp->ParseAndExecute("int apple = 12;")) {
+    consumeError(std::move(R));
+    return;
+  }
+  auto Err = llvm::Error::success();
+  auto comps = runComp(*Interp, "app", Err);
+  EXPECT_EQ((size_t)2, comps.size()); // application and apple
+  EXPECT_FALSE(llvm::errorToBool(std::move(Err)));
+}
+
+TEST(CodeCompletionTest, CompFunDeclsNoError) {
+  auto Interp = createInterpreter();
+  auto Err = llvm::Error::success();
+  runComp(*Interp, "void app(", Err);
+  EXPECT_FALSE(llvm::errorToBool(std::move(Err)));
+}
+
+} // anonymous namespace
Index: clang/unittests/Interpreter/CMakeLists.txt
===================================================================
--- clang/unittests/Interpreter/CMakeLists.txt
+++ clang/unittests/Interpreter/CMakeLists.txt
@@ -9,6 +9,7 @@
 add_clang_unittest(ClangReplInterpreterTests
   IncrementalProcessingTest.cpp
   InterpreterTest.cpp
+  CodeCompletionTest.cpp
   )
 target_link_libraries(ClangReplInterpreterTests PUBLIC
   clangAST
Index: clang/tools/libclang/CIndexCodeCompletion.cpp
===================================================================
--- clang/tools/libclang/CIndexCodeCompletion.cpp
+++ clang/tools/libclang/CIndexCodeCompletion.cpp
@@ -11,8 +11,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "CIndexer.h"
 #include "CIndexDiagnostic.h"
+#include "CIndexer.h"
 #include "CLog.h"
 #include "CXCursor.h"
 #include "CXSourceLocation.h"
@@ -25,6 +25,7 @@
 #include "clang/Basic/SourceManager.h"
 #include "clang/Frontend/ASTUnit.h"
 #include "clang/Frontend/CompilerInstance.h"
+#include "clang/Frontend/FrontendActions.h"
 #include "clang/Sema/CodeCompleteConsumer.h"
 #include "clang/Sema/Sema.h"
 #include "llvm/ADT/SmallString.h"
@@ -41,7 +42,6 @@
 #include <cstdlib>
 #include <string>
 
-
 #ifdef UDP_CODE_COMPLETION_LOGGER
 #include "clang/Basic/Version.h"
 #include <arpa/inet.h>
@@ -543,6 +543,7 @@
     case CodeCompletionContext::CCC_PreprocessorExpression:
     case CodeCompletionContext::CCC_PreprocessorDirective:
     case CodeCompletionContext::CCC_Attribute:
+    case CodeCompletionContext::CCC_TopLevelOrExpression:
     case CodeCompletionContext::CCC_TypeQualifiers: {
       //Only Clang results should be accepted, so we'll set all of the other
       //context bits to 0 (i.e. the empty set)
Index: clang/tools/clang-repl/ClangRepl.cpp
===================================================================
--- clang/tools/clang-repl/ClangRepl.cpp
+++ clang/tools/clang-repl/ClangRepl.cpp
@@ -13,7 +13,9 @@
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/Frontend/FrontendDiagnostic.h"
+#include "clang/Interpreter/CodeCompletion.h"
 #include "clang/Interpreter/Interpreter.h"
+#include "clang/Sema/CodeCompleteConsumer.h"
 
 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
 #include "llvm/LineEditor/LineEditor.h"
@@ -70,6 +72,70 @@
   return (Errs || HasError) ? EXIT_FAILURE : EXIT_SUCCESS;
 }
 
+struct ReplListCompleter {
+  clang::IncrementalCompilerBuilder &CB;
+  clang::Interpreter &MainInterp;
+  ReplListCompleter(clang::IncrementalCompilerBuilder &CB,
+                    clang::Interpreter &Interp)
+      : CB(CB), MainInterp(Interp) {}
+  // Completions for Buffer at cursor Pos; the 2-arg form logs errors itself.
+  std::vector<llvm::LineEditor::Completion> operator()(llvm::StringRef Buffer,
+                                                       size_t Pos) const;
+  std::vector<llvm::LineEditor::Completion>
+  operator()(llvm::StringRef Buffer, size_t Pos, llvm::Error &ErrRes) const;
+};
+
+std::vector<llvm::LineEditor::Completion>
+ReplListCompleter::operator()(llvm::StringRef Buffer, size_t Pos) const {
+  llvm::Error Err = llvm::Error::success();
+  auto Completions = operator()(Buffer, Pos, Err);
+  if (Err)
+    llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(), "error: ");
+  return Completions;
+}
+
+std::vector<llvm::LineEditor::Completion>
+ReplListCompleter::operator()(llvm::StringRef Buffer, size_t Pos,
+                              llvm::Error &ErrRes) const {
+  std::vector<llvm::LineEditor::Completion> Comps;
+  std::vector<clang::CodeCompletionResult> Results;
+
+  auto CI = CB.CreateCpp();
+  if (auto Err = CI.takeError()) {
+    ErrRes = std::move(Err);
+    return {};
+  }
+
+  // Only the text before the cursor determines the completion point.
+  llvm::StringRef Before = Buffer.take_front(Pos);
+  size_t Lines = Before.count('\n') + 1;
+  size_t LineStart = Before.rfind('\n') + 1; // npos + 1 wraps to 0.
+  size_t Col = Before.size() - LineStart + 1; // 1-based, line-relative.
+  auto Interp = clang::Interpreter::create(std::move(*CI));
+  if (auto Err = Interp.takeError()) {
+    ErrRes = std::move(Err);
+    return {};
+  }
+
+  codeComplete(
+      const_cast<clang::CompilerInstance *>((*Interp)->getCompilerInstance()),
+      Buffer, Lines, Col, MainInterp.getCompilerInstance(), Results);
+
+  // The partially typed identifier right before the cursor: candidates must
+  // start with it, and only the remaining suffix is inserted.
+  size_t IdentEnd = Before.find_last_not_of(
+      "_0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ");
+  llvm::StringRef S = IdentEnd == llvm::StringRef::npos
+                          ? Before
+                          : Before.drop_front(IdentEnd + 1);
+
+  for (const std::string &C : convertToCodeCompleteStrings(Results)) {
+    if (C.find(S) == 0)
+      Comps.push_back(llvm::LineEditor::Completion(C.substr(S.size()), C));
+  }
+  return Comps;
+}
+
 llvm::ExitOnError ExitOnErr;
 int main(int argc, const char **argv) {
   ExitOnErr.setBanner("clang-repl: ");
@@ -133,6 +199,7 @@
     DeviceCI->LoadRequestedPlugins();
 
   std::unique_ptr<clang::Interpreter> Interp;
+
   if (CudaEnabled) {
     Interp = ExitOnErr(
         clang::Interpreter::createWithCUDA(std::move(CI), std::move(DeviceCI)));
@@ -155,8 +222,8 @@
 
   if (OptInputs.empty()) {
     llvm::LineEditor LE("clang-repl");
-    // FIXME: Add LE.setListCompleter
     std::string Input;
+    LE.setListCompleter(ReplListCompleter(CB, *Interp));
     while (std::optional<std::string> Line = LE.readLine()) {
       llvm::StringRef L = *Line;
       L = L.trim();
@@ -168,10 +235,10 @@
       }
 
       Input += L;
-
       if (Input == R"(%quit)") {
         break;
-      } else if (Input == R"(%undo)") {
+      }
+      if (Input == R"(%undo)") {
         if (auto Err = Interp->Undo()) {
           llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(), "error: ");
           HasError = true;
Index: clang/test/CodeCompletion/incremental-top-level.cpp
===================================================================
--- /dev/null
+++ clang/test/CodeCompletion/incremental-top-level.cpp
@@ -0,0 +1,4 @@
+int foo = 10;
+f
+// RUN: %clang_cc1 -fincremental-extensions -fsyntax-only -code-completion-at=%s:%(line-1):1 %s | FileCheck %s
+// CHECK: COMPLETION: foo : [#int#]foo
Index: clang/test/CodeCompletion/incrememal-mode-completion-no-error.cpp
===================================================================
--- /dev/null
+++ clang/test/CodeCompletion/incrememal-mode-completion-no-error.cpp
@@ -0,0 +1,3 @@
+void foo(
+// RUN: %clang_cc1 -fincremental-extensions -fsyntax-only -code-completion-at=%s:%(line-1):9 %s | wc -c | FileCheck %s
+// CHECK: 0
Index: clang/lib/Sema/SemaCodeComplete.cpp
===================================================================
--- clang/lib/Sema/SemaCodeComplete.cpp
+++ clang/lib/Sema/SemaCodeComplete.cpp
@@ -225,6 +225,7 @@
     case CodeCompletionContext::CCC_ObjCMessageReceiver:
     case CodeCompletionContext::CCC_ParenthesizedExpression:
     case CodeCompletionContext::CCC_Statement:
+    case CodeCompletionContext::CCC_TopLevelOrExpression:
     case CodeCompletionContext::CCC_Recovery:
       if (ObjCMethodDecl *Method = SemaRef.getCurMethodDecl())
         if (Method->isInstanceMethod())
@@ -1850,6 +1851,7 @@
   case Sema::PCC_ObjCInstanceVariableList:
   case Sema::PCC_Expression:
   case Sema::PCC_Statement:
+  case Sema::PCC_TopLevelOrExpression:
   case Sema::PCC_ForInit:
   case Sema::PCC_Condition:
   case Sema::PCC_RecoveryInFunction:
@@ -1907,6 +1909,7 @@
   case Sema::PCC_Type:
   case Sema::PCC_ParenthesizedExpression:
   case Sema::PCC_LocalDeclarationSpecifiers:
+  case Sema::PCC_TopLevelOrExpression:
     return true;
 
   case Sema::PCC_Expression:
@@ -2219,6 +2222,7 @@
     break;
 
   case Sema::PCC_RecoveryInFunction:
+  case Sema::PCC_TopLevelOrExpression:
   case Sema::PCC_Statement: {
     if (SemaRef.getLangOpts().CPlusPlus11)
       AddUsingAliasResult(Builder, Results);
@@ -4208,6 +4212,8 @@
 
   case Sema::PCC_LocalDeclarationSpecifiers:
     return CodeCompletionContext::CCC_Type;
+  case Sema::PCC_TopLevelOrExpression:
+    return CodeCompletionContext::CCC_TopLevelOrExpression;
   }
 
   llvm_unreachable("Invalid ParserCompletionContext!");
@@ -4348,6 +4354,7 @@
     break;
 
   case PCC_Statement:
+  case PCC_TopLevelOrExpression:
   case PCC_ParenthesizedExpression:
   case PCC_Expression:
   case PCC_ForInit:
@@ -4385,6 +4392,7 @@
   case PCC_ParenthesizedExpression:
   case PCC_Expression:
   case PCC_Statement:
+  case PCC_TopLevelOrExpression:
   case PCC_RecoveryInFunction:
     if (S->getFnParent())
       AddPrettyFunctionResults(getLangOpts(), Results);
Index: clang/lib/Sema/CodeCompleteConsumer.cpp
===================================================================
--- clang/lib/Sema/CodeCompleteConsumer.cpp
+++ clang/lib/Sema/CodeCompleteConsumer.cpp
@@ -51,6 +51,7 @@
   case CCC_ParenthesizedExpression:
   case CCC_Symbol:
   case CCC_SymbolOrNewName:
+  case CCC_TopLevelOrExpression:
     return true;
 
   case CCC_TopLevel:
@@ -169,6 +170,8 @@
     return "Recovery";
   case CCKind::CCC_ObjCClassForwardDecl:
     return "ObjCClassForwardDecl";
+  case CCKind::CCC_TopLevelOrExpression:
+    return "ReplTopLevel";
   }
   llvm_unreachable("Invalid CodeCompletionContext::Kind!");
 }
Index: clang/lib/Parse/Parser.cpp
===================================================================
--- clang/lib/Parse/Parser.cpp
+++ clang/lib/Parse/Parser.cpp
@@ -923,9 +923,16 @@
                                          /*IsInstanceMethod=*/std::nullopt,
                                          /*ReturnType=*/nullptr);
     }
-    Actions.CodeCompleteOrdinaryName(
-        getCurScope(),
-        CurParsedObjCImpl ? Sema::PCC_ObjCImplementation : Sema::PCC_Namespace);
+
+    Sema::ParserCompletionContext PCC;
+    if (CurParsedObjCImpl) {
+      PCC = Sema::PCC_ObjCImplementation;
+    } else if (PP.isIncrementalProcessingEnabled()) {
+      PCC = Sema::PCC_TopLevelOrExpression;
+    } else {
+      PCC = Sema::PCC_Namespace;
+    };
+    Actions.CodeCompleteOrdinaryName(getCurScope(), PCC);
     return nullptr;
   case tok::kw_import: {
     Sema::ModuleImportState IS = Sema::ModuleImportState::NotACXX20Module;
Index: clang/lib/Parse/ParseDecl.cpp
===================================================================
--- clang/lib/Parse/ParseDecl.cpp
+++ clang/lib/Parse/ParseDecl.cpp
@@ -18,6 +18,7 @@
 #include "clang/Basic/Attributes.h"
 #include "clang/Basic/CharInfo.h"
 #include "clang/Basic/TargetInfo.h"
+#include "clang/Basic/TokenKinds.h"
 #include "clang/Parse/ParseDiagnostic.h"
 #include "clang/Parse/Parser.h"
 #include "clang/Parse/RAIIObjectsForParser.h"
Index: clang/lib/Interpreter/Interpreter.cpp
===================================================================
--- clang/lib/Interpreter/Interpreter.cpp
+++ clang/lib/Interpreter/Interpreter.cpp
@@ -11,13 +11,11 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "clang/Interpreter/Interpreter.h"
-
 #include "DeviceOffload.h"
 #include "IncrementalExecutor.h"
 #include "IncrementalParser.h"
-
 #include "InterpreterUtils.h"
+
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Mangle.h"
 #include "clang/AST/TypeVisitor.h"
@@ -33,6 +31,7 @@
 #include "clang/Driver/Tool.h"
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/Frontend/TextDiagnosticBuffer.h"
+#include "clang/Interpreter/Interpreter.h"
 #include "clang/Interpreter/Value.h"
 #include "clang/Lex/PreprocessorOptions.h"
 #include "clang/Sema/Lookup.h"
@@ -127,7 +126,6 @@
 
   Clang->getFrontendOpts().DisableFree = false;
   Clang->getCodeGenOpts().DisableFree = false;
-
   return std::move(Clang);
 }
 
@@ -276,6 +274,7 @@
       std::unique_ptr<Interpreter>(new Interpreter(std::move(CI), Err));
   if (Err)
     return std::move(Err);
+
   auto PTU = Interp->Parse(Runtimes);
   if (!PTU)
     return PTU.takeError();
Index: clang/lib/Interpreter/IncrementalParser.h
===================================================================
--- clang/lib/Interpreter/IncrementalParser.h
+++ clang/lib/Interpreter/IncrementalParser.h
@@ -13,9 +13,9 @@
 #ifndef LLVM_CLANG_LIB_INTERPRETER_INCREMENTALPARSER_H
 #define LLVM_CLANG_LIB_INTERPRETER_INCREMENTALPARSER_H
 
+#include "clang/AST/GlobalDecl.h"
 #include "clang/Interpreter/PartialTranslationUnit.h"
 
-#include "clang/AST/GlobalDecl.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Error.h"
@@ -24,7 +24,7 @@
 #include <memory>
 namespace llvm {
 class LLVMContext;
-}
+} // namespace llvm
 
 namespace clang {
 class ASTConsumer;
Index: clang/lib/Interpreter/IncrementalParser.cpp
===================================================================
--- clang/lib/Interpreter/IncrementalParser.cpp
+++ clang/lib/Interpreter/IncrementalParser.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "IncrementalParser.h"
+
 #include "clang/AST/DeclContextInternals.h"
 #include "clang/CodeGen/BackendUtil.h"
 #include "clang/CodeGen/CodeGenAction.h"
@@ -157,16 +158,11 @@
   TranslationUnitKind getTranslationUnitKind() override {
     return TU_Incremental;
   }
+
   void ExecuteAction() override {
     CompilerInstance &CI = getCompilerInstance();
     assert(CI.hasPreprocessor() && "No PP!");
 
-    // FIXME: Move the truncation aspect of this into Sema, we delayed this till
-    // here so the source manager would be initialized.
-    if (hasCodeCompletionSupport() &&
-        !CI.getFrontendOpts().CodeCompletionAt.FileName.empty())
-      CI.createCodeCompletionConsumer();
-
     // Use a code completion consumer?
     CodeCompleteConsumer *CompletionConsumer = nullptr;
     if (CI.hasCodeCompletionConsumer())
@@ -398,5 +394,4 @@
   assert(CG);
   return CG->GetMangledName(GD);
 }
-
 } // end namespace clang
Index: clang/lib/Interpreter/CodeCompletion.cpp
===================================================================
--- /dev/null
+++ clang/lib/Interpreter/CodeCompletion.cpp
@@ -0,0 +1,231 @@
+//===------ CodeCompletion.cpp - Code Completion for ClangRepl -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the classes which performs code completion at the REPL.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Interpreter/CodeCompletion.h"
+#include "clang/AST/ASTImporter.h"
+#include "clang/AST/DeclarationName.h"
+#include "clang/AST/ExternalASTSource.h"
+#include "clang/Basic/IdentifierTable.h"
+#include "clang/Frontend/ASTUnit.h"
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Frontend/FrontendActions.h"
+#include "clang/Interpreter/Interpreter.h"
+#include "clang/Lex/PreprocessorOptions.h"
+#include "clang/Sema/CodeCompleteConsumer.h"
+#include "clang/Sema/CodeCompleteOptions.h"
+#include "clang/Sema/Sema.h"
+
+namespace clang {
+
+const std::string CodeCompletionFileName = "input_line_[Completion]";
+
+static clang::CodeCompleteOptions getClangCompleteOpts() {
+  clang::CodeCompleteOptions Opts; // Options for ReplCompletionConsumer.
+  Opts.IncludeCodePatterns = true;
+  Opts.IncludeMacros = true;
+  Opts.IncludeGlobals = true;
+  Opts.IncludeBriefComments = true;
+  return Opts;
+}
+
+class ReplCompletionConsumer : public CodeCompleteConsumer {
+public:
+  explicit ReplCompletionConsumer(std::vector<CodeCompletionResult> &Results)
+      : CodeCompleteConsumer(getClangCompleteOpts()),
+        CCAllocator(std::make_shared<GlobalCodeCompletionAllocator>()),
+        CCTUInfo(CCAllocator), Results(Results) {}
+  /// Appends named declarations and keywords from \p InResults to Results.
+  void ProcessCodeCompleteResults(class Sema &S, CodeCompletionContext Context,
+                                  CodeCompletionResult *InResults,
+                                  unsigned NumResults) final;
+  /// Storage required by the CodeCompleteConsumer interface.
+  CodeCompletionAllocator &getAllocator() override { return *CCAllocator; }
+  CodeCompletionTUInfo &getCodeCompletionTUInfo() override { return CCTUInfo; }
+
+private:
+  std::shared_ptr<GlobalCodeCompletionAllocator> CCAllocator;
+  CodeCompletionTUInfo CCTUInfo;
+  /// Caller-owned output vector; it must outlive this consumer.
+  std::vector<CodeCompletionResult> &Results;
+};
+
+void ReplCompletionConsumer::ProcessCodeCompleteResults(
+    class Sema &S, CodeCompletionContext Context,
+    CodeCompletionResult *InResults, unsigned NumResults) {
+  // Keep named declarations and keywords; all other result kinds are dropped.
+  for (unsigned I = 0; I < NumResults; ++I) {
+    const CodeCompletionResult &Result = InResults[I];
+    switch (Result.Kind) {
+    case CodeCompletionResult::RK_Declaration:
+      if (Result.Declaration->getIdentifier())
+        Results.push_back(Result);
+      break;
+    case CodeCompletionResult::RK_Keyword:
+      Results.push_back(Result);
+      break;
+    default:
+      break;
+    }
+  }
+}
+
+std::vector<std::string> convertToCodeCompleteStrings(
+    const std::vector<clang::CodeCompletionResult> &Results) {
+  std::vector<std::string> CompletionStrings;
+  CompletionStrings.reserve(Results.size());
+  for (const auto &Res : Results) { // By reference: no per-result copy.
+    switch (Res.Kind) {
+    case clang::CodeCompletionResult::RK_Declaration:
+      if (auto *ID = Res.Declaration->getIdentifier())
+        CompletionStrings.push_back(ID->getName().str());
+      break;
+    case clang::CodeCompletionResult::RK_Keyword:
+      CompletionStrings.push_back(Res.Keyword);
+      break;
+    default: // Other kinds (e.g. macros, patterns) are skipped.
+      break;
+    }
+  }
+  return CompletionStrings;
+}
+
+class IncrementalSyntaxOnlyAction : public SyntaxOnlyAction {
+  const CompilerInstance *ParentCI; // Interpreter whose decls are imported.
+
+public:
+  explicit IncrementalSyntaxOnlyAction(const CompilerInstance *ParentCI)
+      : ParentCI(ParentCI) {}
+
+protected:
+  void ExecuteAction() override; // Installs ExternalSource, then parses.
+};
+
+class ExternalSource : public clang::ExternalASTSource {
+  ASTContext &ChildASTCtxt;
+  TranslationUnitDecl *ChildTUDeclCtxt;
+  ASTContext &ParentASTCtxt;
+  TranslationUnitDecl *ParentTUDeclCtxt;
+  // Copies declarations from the parent AST into the child AST.
+  std::unique_ptr<ASTImporter> Importer;
+
+public:
+  ExternalSource(ASTContext &ChildASTCtxt, FileManager &ChildFM,
+                 ASTContext &ParentASTCtxt, FileManager &ParentFM);
+  bool FindExternalVisibleDeclsByName(const DeclContext *DC,
+                                      DeclarationName Name) override;
+  void
+  completeVisibleDeclsMap(const clang::DeclContext *ChildDeclContext) override;
+};
+
+// Installs an ExternalSource backed by the parent interpreter's AST before the
+// base SyntaxOnlyAction::ExecuteAction starts parsing the completion buffer.
+void IncrementalSyntaxOnlyAction::ExecuteAction() {
+  CompilerInstance &CI = getCompilerInstance();
+  ASTContext &ChildCtx = CI.getASTContext();
+
+  auto Source = llvm::makeIntrusiveRefCnt<ExternalSource>(
+      ChildCtx, CI.getFileManager(), ParentCI->getASTContext(),
+      ParentCI->getFileManager());
+  ChildCtx.setExternalSource(std::move(Source));
+
+  // Make name lookups in the child TU consult the external source.
+  ChildCtx.getTranslationUnitDecl()->setHasExternalVisibleStorage(true);
+  SyntaxOnlyAction::ExecuteAction();
+}
+
+ExternalSource::ExternalSource(ASTContext &ChildASTCtxt, FileManager &ChildFM,
+                               ASTContext &ParentASTCtxt, FileManager &ParentFM)
+    : ChildASTCtxt(ChildASTCtxt),
+      ChildTUDeclCtxt(ChildASTCtxt.getTranslationUnitDecl()),
+      ParentASTCtxt(ParentASTCtxt),
+      ParentTUDeclCtxt(ParentASTCtxt.getTranslationUnitDecl()) {
+  // Import from the parent AST into the child AST; MinimalImport asks the
+  // importer to copy as little as possible.
+  Importer = std::make_unique<ASTImporter>(ChildASTCtxt, ChildFM, ParentASTCtxt,
+                                           ParentFM, /*MinimalImport=*/true);
+}
+
+bool ExternalSource::FindExternalVisibleDeclsByName(const DeclContext *DC,
+                                                    DeclarationName Name) {
+  // Only identifier names can be matched across ASTContexts by spelling; do
+  // not turn operator/constructor names into bogus identifiers.
+  if (!Name.isIdentifier())
+    return false;
+  // find() rather than get(): probing must not add identifiers to the live
+  // parent interpreter's identifier table.
+  IdentifierTable &ParentIdTable = ParentASTCtxt.Idents;
+  auto It = ParentIdTable.find(Name.getAsIdentifierInfo()->getName());
+  if (It == ParentIdTable.end())
+    return false;
+  // Report whether the parent TU declares anything with this name.
+  return !ParentTUDeclCtxt->lookup(DeclarationName(It->getValue())).empty();
+}
+
+void ExternalSource::completeVisibleDeclsMap(
+    const DeclContext *ChildDeclContext) {
+  assert(ChildDeclContext && ChildDeclContext == ChildTUDeclCtxt &&
+         "No child decl context!");
+
+  if (!ChildDeclContext->hasExternalVisibleStorage())
+    return;
+  // Import each named decl of the parent TU, walking the chain of previous
+  // TranslationUnitDecls, and register it as visible in the child TU.
+  for (auto *DeclCtxt = ParentTUDeclCtxt; DeclCtxt != nullptr;
+       DeclCtxt = DeclCtxt->getPreviousDecl()) {
+    for (Decl *D : DeclCtxt->decls()) {
+      auto *ND = llvm::dyn_cast<NamedDecl>(D);
+      if (!ND)
+        continue;
+      llvm::Expected<Decl *> DeclOrErr = Importer->Import(ND);
+      if (!DeclOrErr) {
+        // Best effort: a decl that fails to import is simply not offered.
+        llvm::consumeError(DeclOrErr.takeError());
+        continue;
+      }
+      if (auto *Imported = llvm::dyn_cast<NamedDecl>(*DeclOrErr))
+        SetExternalVisibleDeclsForName(ChildDeclContext,
+                                       Imported->getDeclName(), Imported);
+    }
+  }
+  ChildDeclContext->setHasExternalLexicalStorage(false);
+}
+
+void codeComplete(CompilerInstance *InterpCI, llvm::StringRef Content,
+                  unsigned Line, unsigned Col, const CompilerInstance *ParentCI,
+                  std::vector<CodeCompletionResult> &CCResults) {
+  // Expose Content to the completion run as an in-memory file.
+  std::unique_ptr<llvm::MemoryBuffer> MB =
+      llvm::MemoryBuffer::getMemBufferCopy(Content, CodeCompletionFileName);
+  llvm::SmallVector<ASTUnit::RemappedFile, 4> RemappedFiles;
+  RemappedFiles.push_back(std::make_pair(CodeCompletionFileName, MB.release()));
+  auto Consumer = ReplCompletionConsumer(CCResults);
+  auto Diag = InterpCI->getDiagnosticsPtr();
+  // NOTE(review): neither AU nor the buffers reported through OwnedBuffers
+  // are freed here; confirm their ownership before adding cleanup.
+  ASTUnit *AU = ASTUnit::LoadFromCompilerInvocationAction(
+      InterpCI->getInvocationPtr(), std::make_shared<PCHContainerOperations>(),
+      Diag);
+  llvm::SmallVector<clang::StoredDiagnostic, 8> StoredDiags;
+  llvm::SmallVector<const llvm::MemoryBuffer *, 1> OwnedBuffers;
+  InterpCI->getFrontendOpts().Inputs[0] = FrontendInputFile(
+      CodeCompletionFileName, Language::CXX, InputKind::Source);
+  auto Act = std::make_unique<IncrementalSyntaxOnlyAction>(ParentCI);
+  AU->CodeComplete(CodeCompletionFileName, Line, Col, RemappedFiles,
+                   /*IncludeMacros=*/false, /*IncludeCodePatterns=*/false,
+                   /*IncludeBriefComments=*/false, Consumer,
+                   std::make_shared<clang::PCHContainerOperations>(), *Diag,
+                   InterpCI->getLangOpts(), InterpCI->getSourceManager(),
+                   InterpCI->getFileManager(), StoredDiags, OwnedBuffers,
+                   std::move(Act));
+}
+
+} // namespace clang
Index: clang/lib/Interpreter/CMakeLists.txt
===================================================================
--- clang/lib/Interpreter/CMakeLists.txt
+++ clang/lib/Interpreter/CMakeLists.txt
@@ -13,6 +13,7 @@
 
 add_clang_library(clangInterpreter
   DeviceOffload.cpp
+  CodeCompletion.cpp
   IncrementalExecutor.cpp
   IncrementalParser.cpp
   Interpreter.cpp
Index: clang/lib/Frontend/ASTUnit.cpp
===================================================================
--- clang/lib/Frontend/ASTUnit.cpp
+++ clang/lib/Frontend/ASTUnit.cpp
@@ -2008,7 +2008,8 @@
   case CodeCompletionContext::CCC_SymbolOrNewName:
   case CodeCompletionContext::CCC_ParenthesizedExpression:
   case CodeCompletionContext::CCC_ObjCInterfaceName:
-    break;
+  case CodeCompletionContext::CCC_TopLevelOrExpression:
+      break;
 
   case CodeCompletionContext::CCC_EnumTag:
   case CodeCompletionContext::CCC_UnionTag:
@@ -2167,7 +2168,8 @@
     std::shared_ptr<PCHContainerOperations> PCHContainerOps,
     DiagnosticsEngine &Diag, LangOptions &LangOpts, SourceManager &SourceMgr,
     FileManager &FileMgr, SmallVectorImpl<StoredDiagnostic> &StoredDiagnostics,
-    SmallVectorImpl<const llvm::MemoryBuffer *> &OwnedBuffers) {
+    SmallVectorImpl<const llvm::MemoryBuffer *> &OwnedBuffers,
+    std::unique_ptr<SyntaxOnlyAction> Act) {
   if (!Invocation)
     return;
 
@@ -2304,8 +2306,9 @@
   if (!Clang->getLangOpts().Modules)
     PreprocessorOpts.DetailedRecord = false;
 
-  std::unique_ptr<SyntaxOnlyAction> Act;
-  Act.reset(new SyntaxOnlyAction);
+  if (!Act)
+    Act.reset(new SyntaxOnlyAction);
+
   if (Act->BeginSourceFile(*Clang.get(), Clang->getFrontendOpts().Inputs[0])) {
     if (llvm::Error Err = Act->Execute()) {
       consumeError(std::move(Err)); // FIXME this drops errors on the floor.
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -13449,7 +13449,9 @@
     PCC_ParenthesizedExpression,
     /// Code completion occurs within a sequence of declaration
     /// specifiers within a function, method, or block.
-    PCC_LocalDeclarationSpecifiers
+    PCC_LocalDeclarationSpecifiers,
+    /// Code completion occurs at top-level in a REPL session
+    PCC_TopLevelOrExpression,
   };
 
   void CodeCompleteModuleImport(SourceLocation ImportLoc, ModuleIdPath Path);
Index: clang/include/clang/Sema/CodeCompleteConsumer.h
===================================================================
--- clang/include/clang/Sema/CodeCompleteConsumer.h
+++ clang/include/clang/Sema/CodeCompleteConsumer.h
@@ -336,7 +336,10 @@
     CCC_Recovery,
 
     /// Code completion in a @class forward declaration.
-    CCC_ObjCClassForwardDecl
+    CCC_ObjCClassForwardDecl,
+
+    /// Code completion at a top level in a REPL session.
+    CCC_TopLevelOrExpression,
   };
 
   using VisitedContextSet = llvm::SmallPtrSet<DeclContext *, 8>;
Index: clang/include/clang/Interpreter/CodeCompletion.h
===================================================================
--- /dev/null
+++ clang/include/clang/Interpreter/CodeCompletion.h
@@ -0,0 +1,33 @@
+//===----- CodeCompletion.h - Code Completion for ClangRepl ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the classes which performs code completion at the REPL.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_INTERPRETER_CODE_COMPLETION_H
+#define LLVM_CLANG_INTERPRETER_CODE_COMPLETION_H
+#include <string>
+#include <vector>
+
+namespace llvm {
+class StringRef;
+} // namespace llvm
+
+namespace clang {
+class CodeCompletionResult;
+class CompilerInstance;
+/// Completes \p Content in \p InterpCI, importing names from \p ParentCI.
+void codeComplete(CompilerInstance *InterpCI, llvm::StringRef Content,
+                  unsigned Line, unsigned Col, const CompilerInstance *ParentCI,
+                  std::vector<CodeCompletionResult> &CCResults);
+/// Returns the identifier or keyword spelling of each usable result.
+std::vector<std::string>
+convertToCodeCompleteStrings(const std::vector<CodeCompletionResult> &Results);
+} // namespace clang
+#endif
Index: clang/include/clang/Frontend/ASTUnit.h
===================================================================
--- clang/include/clang/Frontend/ASTUnit.h
+++ clang/include/clang/Frontend/ASTUnit.h
@@ -77,6 +77,7 @@
 class PreprocessorOptions;
 class Sema;
 class TargetInfo;
+class SyntaxOnlyAction;
 
 /// \brief Enumerates the available scopes for skipping function bodies.
 enum class SkipFunctionBodiesScope { None, Preamble, PreambleAndMainFile };
@@ -887,6 +888,10 @@
   /// \param IncludeBriefComments Whether to include brief documentation within
   /// the set of code completions returned.
   ///
+  /// \param Act If supplied, this argument is used to parse the input file,
+  /// allowing customized parsing by overriding SyntaxOnlyAction lifecycle
+  /// methods.
+  ///
   /// FIXME: The Diag, LangOpts, SourceMgr, FileMgr, StoredDiagnostics, and
   /// OwnedBuffers parameters are all disgusting hacks. They will go away.
   void CodeComplete(StringRef File, unsigned Line, unsigned Column,
@@ -897,7 +902,8 @@
                     DiagnosticsEngine &Diag, LangOptions &LangOpts,
                     SourceManager &SourceMgr, FileManager &FileMgr,
                     SmallVectorImpl<StoredDiagnostic> &StoredDiagnostics,
-                    SmallVectorImpl<const llvm::MemoryBuffer *> &OwnedBuffers);
+                    SmallVectorImpl<const llvm::MemoryBuffer *> &OwnedBuffers,
+                    std::unique_ptr<SyntaxOnlyAction> Act = nullptr);
 
   /// Save this translation unit to a file with the given name.
   ///
Index: clang-tools-extra/clangd/CodeComplete.cpp
===================================================================
--- clang-tools-extra/clangd/CodeComplete.cpp
+++ clang-tools-extra/clangd/CodeComplete.cpp
@@ -842,6 +842,7 @@
   case CodeCompletionContext::CCC_NaturalLanguage:
   case CodeCompletionContext::CCC_Recovery:
   case CodeCompletionContext::CCC_NewName:
+  case CodeCompletionContext::CCC_TopLevelOrExpression:
     return false;
   }
   llvm_unreachable("unknown code completion context");
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to