ioeric created this revision.
ioeric added a reviewer: sammccall.
Herald added subscribers: cfe-commits, kadircet, arphaman, mgrang, jkorous, 
MaskRay, javed.absar, ilya-biryukov, mgorny.

This enables clangd to intercept compiler diagnostics and attach fixes (e.g. by
querying the index). This patch adds missing-include fixes for incomplete types,
e.g. member access into a class that has only a forward declaration. In the
future, the same mechanism would allow adding missing includes for user-typed
symbol names that are missing declarations (e.g. typos).


Repository:
  rCTE Clang Tools Extra

https://reviews.llvm.org/D56903

Files:
  clangd/CMakeLists.txt
  clangd/ClangdServer.cpp
  clangd/ClangdUnit.cpp
  clangd/ClangdUnit.h
  clangd/CodeComplete.cpp
  clangd/Diagnostics.cpp
  clangd/Diagnostics.h
  clangd/Headers.cpp
  clangd/Headers.h
  clangd/IncludeFixer.cpp
  clangd/IncludeFixer.h
  clangd/SourceCode.cpp
  clangd/SourceCode.h
  unittests/clangd/CMakeLists.txt
  unittests/clangd/FileIndexTests.cpp
  unittests/clangd/IncludeFixerTests.cpp
  unittests/clangd/TUSchedulerTests.cpp
  unittests/clangd/TestTU.cpp
  unittests/clangd/TestTU.h

Index: unittests/clangd/TestTU.h
===================================================================
--- unittests/clangd/TestTU.h
+++ unittests/clangd/TestTU.h
@@ -49,6 +49,9 @@
   // Extra arguments for the compiler invocation.
   std::vector<const char *> ExtraArgs;
 
+  // Index to use when building AST.
+  const SymbolIndex *ExternalIndex = nullptr;
+
   ParsedAST build() const;
   SymbolSlab headerSymbols() const;
   std::unique_ptr<SymbolIndex> index() const;
Index: unittests/clangd/TestTU.cpp
===================================================================
--- unittests/clangd/TestTU.cpp
+++ unittests/clangd/TestTU.cpp
@@ -36,6 +36,7 @@
   Inputs.CompileCommand.Directory = testRoot();
   Inputs.Contents = Code;
   Inputs.FS = buildTestFS({{FullFilename, Code}, {FullHeaderName, HeaderCode}});
+  Inputs.Index = ExternalIndex;
   auto PCHs = std::make_shared<PCHContainerOperations>();
   auto CI = buildCompilerInvocation(Inputs);
   assert(CI && "Failed to build compilation invocation.");
Index: unittests/clangd/TUSchedulerTests.cpp
===================================================================
--- unittests/clangd/TUSchedulerTests.cpp
+++ unittests/clangd/TUSchedulerTests.cpp
@@ -38,8 +38,9 @@
 class TUSchedulerTests : public ::testing::Test {
 protected:
   ParseInputs getInputs(PathRef File, std::string Contents) {
-    return ParseInputs{*CDB.getCompileCommand(File),
-                       buildTestFS(Files, Timestamps), std::move(Contents)};
+    return ParseInputs(*CDB.getCompileCommand(File),
+                       buildTestFS(Files, Timestamps), std::move(Contents),
+                       /*Index=*/nullptr);
   }
 
   void updateWithCallback(TUScheduler &S, PathRef File,
Index: unittests/clangd/IncludeFixerTests.cpp
===================================================================
--- /dev/null
+++ unittests/clangd/IncludeFixerTests.cpp
@@ -0,0 +1,78 @@
+//===-- IncludeFixerTests.cpp - IncludeFixer tests --------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Annotations.h"
+#include "ClangdUnit.h"
+#include "IncludeFixer.h"
+#include "TestIndex.h"
+#include "TestTU.h"
+#include "index/MemIndex.h"
+#include "llvm/Support/ScopedPrinter.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace clang {
+namespace clangd {
+namespace {
+
+using testing::UnorderedElementsAre;
+
+testing::Matcher<const Diag &> WithFix(testing::Matcher<Fix> FixMatcher) {
+  return Field(&Diag::Fixes, ElementsAre(FixMatcher));
+}
+
+MATCHER_P2(Diag, Range, Message,
+           "Diag at " + llvm::to_string(Range) + " = [" + Message + "]") {
+  return arg.Range == Range && arg.Message == Message;
+}
+
+MATCHER_P3(Fix, Range, Replacement, Message,
+           "Fix " + llvm::to_string(Range) + " => " +
+               testing::PrintToString(Replacement) + " = [" + Message + "]") {
+  return arg.Message == Message && arg.Edits.size() == 1 &&
+         arg.Edits[0].range == Range && arg.Edits[0].newText == Replacement;
+}
+
+TEST(IncludeFixerTest, IncompleteType) {
+  Annotations Test(R"cpp(
+$insert[[]]namespace ns {
+  class X;
+}
+class Y : $base[[public ns::X]] {};
+int main() {
+  ns::X *x;
+  x$access[[->]]f();
+}
+  )cpp");
+  auto TU = TestTU::withCode(Test.code());
+  Symbol Sym = symbol("ns::X");
+  Sym.Flags |= Symbol::IndexedForCodeCompletion;
+  Sym.CanonicalDeclaration.FileURI = "unittest:///x.h";
+  Sym.IncludeHeaders.emplace_back("\"x.h\"", 1);
+
+  SymbolSlab::Builder Slab;
+  Slab.insert(Sym);
+  auto Index = MemIndex::build(std::move(Slab).build(), RefSlab());
+  TU.ExternalIndex = Index.get();
+
+  EXPECT_THAT(
+      TU.build().getDiagnostics(),
+      UnorderedElementsAre(
+          AllOf(Diag(Test.range("base"), "base class has incomplete type"),
+                WithFix(Fix(Test.range("insert"), "#include \"x.h\"\n",
+                            "Add include \"x.h\" for symbol ns::X"))),
+          AllOf(Diag(Test.range("access"),
+                     "member access into incomplete type 'ns::X'"),
+                WithFix(Fix(Test.range("insert"), "#include \"x.h\"\n",
+                            "Add include \"x.h\" for symbol ns::X")))));
+}
+
+} // namespace
+} // namespace clangd
+} // namespace clang
Index: unittests/clangd/FileIndexTests.cpp
===================================================================
--- unittests/clangd/FileIndexTests.cpp
+++ unittests/clangd/FileIndexTests.cpp
@@ -361,10 +361,10 @@
       /*StoreInMemory=*/true,
       [&](ASTContext &Ctx, std::shared_ptr<Preprocessor> PP) {});
   // Build AST for main file with preamble.
-  auto AST =
-      ParsedAST::build(createInvocationFromCommandLine(Cmd), PreambleData,
-                       llvm::MemoryBuffer::getMemBufferCopy(Main.code()),
-                       std::make_shared<PCHContainerOperations>(), PI.FS);
+  auto AST = ParsedAST::build(
+      createInvocationFromCommandLine(Cmd), PreambleData,
+      llvm::MemoryBuffer::getMemBufferCopy(Main.code()),
+      std::make_shared<PCHContainerOperations>(), PI.FS, /*Index=*/nullptr);
   ASSERT_TRUE(AST);
   FileIndex Index;
   Index.updateMain(MainFile, *AST);
Index: unittests/clangd/CMakeLists.txt
===================================================================
--- unittests/clangd/CMakeLists.txt
+++ unittests/clangd/CMakeLists.txt
@@ -28,6 +28,7 @@
   FuzzyMatchTests.cpp
   GlobalCompilationDatabaseTests.cpp
   HeadersTests.cpp
+  IncludeFixerTests.cpp
   IndexActionTests.cpp
   IndexTests.cpp
   JSONTransportTests.cpp
Index: clangd/SourceCode.h
===================================================================
--- clangd/SourceCode.h
+++ clangd/SourceCode.h
@@ -17,7 +17,9 @@
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Basic/SourceManager.h"
+#include "clang/Format/Format.h"
 #include "clang/Tooling/Core/Replacement.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/Support/SHA1.h"
 
 namespace clang {
@@ -92,6 +94,11 @@
                                              const SourceManager &SourceMgr);
 
 bool IsRangeConsecutive(const Range &Left, const Range &Right);
+
+format::FormatStyle getFormatStyleForFile(llvm::StringRef File,
+                                          llvm::StringRef Content,
+                                          llvm::vfs::FileSystem *FS);
+
 } // namespace clangd
 } // namespace clang
 #endif
Index: clangd/SourceCode.cpp
===================================================================
--- clangd/SourceCode.cpp
+++ clangd/SourceCode.cpp
@@ -249,5 +249,18 @@
   return digest(Content);
 }
 
+format::FormatStyle getFormatStyleForFile(llvm::StringRef File,
+                                          llvm::StringRef Content,
+                                          llvm::vfs::FileSystem *FS) {
+  auto Style = format::getStyle(format::DefaultFormatStyle, File,
+                                format::DefaultFallbackStyle, Content, FS);
+  if (Style)
+    return *Style;
+  // Style lookup failed: report why and fall back to the LLVM style.
+  log("getStyle() failed for file {0}: {1}. Fallback is LLVM style.", File,
+      Style.takeError());
+  return format::getLLVMStyle();
+}
+
 } // namespace clangd
 } // namespace clang
Index: clangd/IncludeFixer.h
===================================================================
--- /dev/null
+++ clangd/IncludeFixer.h
@@ -0,0 +1,50 @@
+//===- IncludeFixer.h - Add missing includes --------------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_TOOLS_EXTRA_CLANGD_INCLUDE_FIXER_H
+#define LLVM_CLANG_TOOLS_EXTRA_CLANGD_INCLUDE_FIXER_H
+
+#include "Diagnostics.h"
+#include "Headers.h"
+#include "index/Index.h"
+#include "clang/AST/Type.h"
+#include "clang/Basic/Diagnostic.h"
+#include "llvm/ADT/StringRef.h"
+#include <memory>
+
+namespace clang {
+namespace clangd {
+
+/// Attempts to recover from error diagnostics by suggesting include insertion
+/// fixes. For example, member access into incomplete type can be fixed by
+/// including headers with the definition.
+class IncludeFixer {
+public:
+  IncludeFixer(llvm::StringRef File, std::unique_ptr<IncludeInserter> Inserter,
+               const SymbolIndex &Index)
+      : File(File), Inserter(std::move(Inserter)), Index(Index) {}
+
+  /// Returns include insertions that can potentially recover the diagnostic.
+  std::vector<Fix> fix(DiagnosticsEngine::Level DiagLevel,
+                       const clang::Diagnostic &Info) const;
+
+private:
+  std::vector<Fix> fixInCompleteType(const Type &T) const;
+
+  std::vector<Fix> fixesForSymbol(const Symbol &Sym) const;
+
+  std::string File;
+  std::unique_ptr<IncludeInserter> Inserter;
+  const SymbolIndex &Index;
+};
+
+} // namespace clangd
+} // namespace clang
+
+#endif // LLVM_CLANG_TOOLS_EXTRA_CLANGD_INCLUDE_FIXER_H
Index: clangd/IncludeFixer.cpp
===================================================================
--- /dev/null
+++ clangd/IncludeFixer.cpp
@@ -0,0 +1,119 @@
+//===--- IncludeFixer.cpp ----------------------------------------*- C++-*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "IncludeFixer.h"
+#include "AST.h"
+#include "Diagnostics.h"
+#include "Logger.h"
+#include "SourceCode.h"
+#include "index/Index.h"
+#include "clang/AST/Decl.h"
+#include "clang/AST/Type.h"
+#include "clang/Basic/Diagnostic.h"
+#include "clang/Basic/DiagnosticSema.h"
+#include "llvm/ADT/None.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/FormatVariadic.h"
+
+namespace clang {
+namespace clangd {
+
+namespace {
+
+bool isIncompleteTypeDiag(unsigned int DiagID) {
+  return DiagID == diag::err_incomplete_type ||
+         DiagID == diag::err_incomplete_member_access ||
+         DiagID == diag::err_incomplete_base_class;
+}
+
+} // namespace
+
+std::vector<Fix> IncludeFixer::fix(DiagnosticsEngine::Level DiagLevel,
+                                   const clang::Diagnostic &Info) const {
+  if (!isIncompleteTypeDiag(Info.getID()))
+    return {};
+  // Incomplete type diagnostics should have a QualType argument for the
+  // incomplete type; fix the first such argument that is still incomplete.
+  for (unsigned I = 0, E = Info.getNumArgs(); I != E; ++I) {
+    if (Info.getArgKind(I) != DiagnosticsEngine::ak_qualtype)
+      continue;
+    auto *Raw = reinterpret_cast<void *>(Info.getRawArg(I));
+    const Type *T = QualType::getFromOpaquePtr(Raw).getTypePtrOrNull();
+    if (T && T->isIncompleteType())
+      return fixInCompleteType(*T);
+  }
+  return {};
+}
+
+std::vector<Fix> IncludeFixer::fixInCompleteType(const Type &T) const {
+  // Only handle incomplete TagDecl type.
+  const TagDecl *TD = T.getAsTagDecl();
+  if (!TD)
+    return {};
+  std::string IncompleteType = printQualifiedName(*TD);
+
+  if (IncompleteType.empty()) {
+    vlog("No incomplete type name is found in diagnostic. Ignore.");
+    return {};
+  }
+  vlog("Trying to fix include for incomplete type {0}", IncompleteType);
+  FuzzyFindRequest Req;
+  Req.AnyScope = false;
+  auto ScopeAndName = splitQualifiedName(IncompleteType);
+  Req.Scopes.push_back(ScopeAndName.first);
+  Req.Query = ScopeAndName.second;
+  // Only code completion symbols insert includes.
+  Req.RestrictForCodeCompletion = true;
+  llvm::Optional<Symbol> Matched;
+  Index.fuzzyFind(Req, [&](const Symbol &Sym) {
+    // FIXME: support multiple matched symbols.
+    if (Matched || Sym.Name != Req.Query)
+      return;
+    Matched = Sym;
+  });
+
+  if (!Matched || Matched->IncludeHeaders.empty())
+    return {};
+  return fixesForSymbol(*Matched);
+}
+
+std::vector<Fix> IncludeFixer::fixesForSymbol(const Symbol &Sym) const {
+  // Computes the include path for \p Header and whether it should be inserted.
+  auto Inserted = [&](llvm::StringRef Header)
+      -> llvm::Expected<std::pair<std::string, bool>> {
+    auto Declaring = toHeaderFile(Sym.CanonicalDeclaration.FileURI, File);
+    if (!Declaring)
+      return Declaring.takeError();
+    auto ToInsert = toHeaderFile(Header, File);
+    if (!ToInsert)
+      return ToInsert.takeError();
+    return std::make_pair(
+        Inserter->calculateIncludePath(*Declaring, *ToInsert),
+        Inserter->shouldInsertInclude(*Declaring, *ToInsert));
+  };
+  std::vector<Fix> Fixes;
+  for (const auto &Inc : getRankedIncludes(Sym)) {
+    if (auto ToInclude = Inserted(Inc)) {
+      if (ToInclude->second)
+        if (auto Edit = Inserter->insert(ToInclude->first))
+          Fixes.push_back(
+              Fix{llvm::formatv("Add include {0} for symbol {1}{2}",
+                                ToInclude->first, Sym.Scope, Sym.Name),
+                  {std::move(*Edit)}});
+    } else {
+      // Log the header first, then the file, to match the message wording.
+      vlog("Failed to calculate include insertion for {0} into {1}: {2}", Inc,
+           File, llvm::toString(ToInclude.takeError()));
+    }
+  }
+  return Fixes;
+}
+
+} // namespace clangd
+} // namespace clang
Index: clangd/Headers.h
===================================================================
--- clangd/Headers.h
+++ clangd/Headers.h
@@ -13,10 +13,12 @@
 #include "Path.h"
 #include "Protocol.h"
 #include "SourceCode.h"
+#include "index/Index.h"
 #include "clang/Format/Format.h"
 #include "clang/Lex/HeaderSearch.h"
 #include "clang/Lex/PPCallbacks.h"
 #include "clang/Tooling/Inclusions/HeaderIncludes.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/Error.h"
@@ -38,6 +40,15 @@
   bool valid() const;
 };
 
+/// Creates a `HeaderFile` from \p Header which can be either a URI or a literal
+/// include.
+llvm::Expected<HeaderFile> toHeaderFile(llvm::StringRef Header,
+                                        llvm::StringRef HintPath);
+
+// Returns include headers for \p Sym sorted by popularity. If two headers are
+// equally popular, prefer the shorter one.
+llvm::SmallVector<llvm::StringRef, 1> getRankedIncludes(const Symbol &Sym);
+
 // An #include directive that we found in the main file.
 struct Inclusion {
   Range R;             // Inclusion range.
Index: clangd/Headers.cpp
===================================================================
--- clangd/Headers.cpp
+++ clangd/Headers.cpp
@@ -74,6 +74,41 @@
          (!Verbatim && llvm::sys::path::is_absolute(File));
 }
 
+llvm::Expected<HeaderFile> toHeaderFile(llvm::StringRef Header,
+                                        llvm::StringRef HintPath) {
+  if (isLiteralInclude(Header))
+    return HeaderFile{Header.str(), /*Verbatim=*/true};
+  auto U = URI::parse(Header);
+  if (!U)
+    return U.takeError();
+
+  auto IncludePath = URI::includeSpelling(*U);
+  if (!IncludePath)
+    return IncludePath.takeError();
+  if (!IncludePath->empty())
+    return HeaderFile{std::move(*IncludePath), /*Verbatim=*/true};
+
+  auto Resolved = URI::resolve(*U, HintPath);
+  if (!Resolved)
+    return Resolved.takeError();
+  return HeaderFile{std::move(*Resolved), /*Verbatim=*/false};
+}
+
+llvm::SmallVector<llvm::StringRef, 1> getRankedIncludes(const Symbol &Sym) {
+  auto Includes = Sym.IncludeHeaders;
+  // Sort in descending order by reference count and header length.
+  llvm::sort(Includes, [](const Symbol::IncludeHeaderWithReferences &LHS,
+                          const Symbol::IncludeHeaderWithReferences &RHS) {
+    if (LHS.References == RHS.References)
+      return LHS.IncludeHeader.size() < RHS.IncludeHeader.size();
+    return LHS.References > RHS.References;
+  });
+  llvm::SmallVector<llvm::StringRef, 1> Headers;
+  for (const auto &Include : Includes)
+    Headers.push_back(Include.IncludeHeader);
+  return Headers;
+}
+
 std::unique_ptr<PPCallbacks>
 collectIncludeStructureCallback(const SourceManager &SM,
                                 IncludeStructure *Out) {
Index: clangd/Diagnostics.h
===================================================================
--- clangd/Diagnostics.h
+++ clangd/Diagnostics.h
@@ -87,6 +87,8 @@
 /// Convert from clang diagnostic level to LSP severity.
 int getSeverity(DiagnosticsEngine::Level L);
 
+class IncludeFixer;
+
 /// StoreDiags collects the diagnostics that can later be reported by
 /// clangd. It groups all notes for a diagnostic into a single Diag
 /// and filters out diagnostics that don't mention the main file (i.e. neither
@@ -100,9 +102,14 @@
   void HandleDiagnostic(DiagnosticsEngine::Level DiagLevel,
                         const clang::Diagnostic &Info) override;
 
+  /// If set, possibly adds fixes for diagnostics using \p Fixer.
+  void setIncludeFixer(const IncludeFixer &Fixer) { FixIncludes = &Fixer; }
+
 private:
   void flushLastDiag();
 
+  const IncludeFixer *FixIncludes = nullptr;
+
   std::vector<Diag> Output;
   llvm::Optional<LangOptions> LangOpts;
   llvm::Optional<Diag> LastDiag;
Index: clangd/Diagnostics.cpp
===================================================================
--- clangd/Diagnostics.cpp
+++ clangd/Diagnostics.cpp
@@ -9,6 +9,7 @@
 
 #include "Diagnostics.h"
 #include "Compiler.h"
+#include "IncludeFixer.h"
 #include "Logger.h"
 #include "SourceCode.h"
 #include "clang/Basic/SourceManager.h"
@@ -375,6 +376,11 @@
 
     if (!Info.getFixItHints().empty())
       AddFix(true /* try to invent a message instead of repeating the diag */);
+    if (FixIncludes) {
+      auto ExtraFixes = FixIncludes->fix(DiagLevel, Info);
+      LastDiag->Fixes.insert(LastDiag->Fixes.end(), ExtraFixes.begin(),
+                             ExtraFixes.end());
+    }
   } else {
     // Handle a note to an existing diagnostic.
     if (!LastDiag) {
Index: clangd/CodeComplete.cpp
===================================================================
--- clangd/CodeComplete.cpp
+++ clangd/CodeComplete.cpp
@@ -178,28 +178,6 @@
   return Result;
 }
 
-/// Creates a `HeaderFile` from \p Header which can be either a URI or a literal
-/// include.
-static llvm::Expected<HeaderFile> toHeaderFile(llvm::StringRef Header,
-                                               llvm::StringRef HintPath) {
-  if (isLiteralInclude(Header))
-    return HeaderFile{Header.str(), /*Verbatim=*/true};
-  auto U = URI::parse(Header);
-  if (!U)
-    return U.takeError();
-
-  auto IncludePath = URI::includeSpelling(*U);
-  if (!IncludePath)
-    return IncludePath.takeError();
-  if (!IncludePath->empty())
-    return HeaderFile{std::move(*IncludePath), /*Verbatim=*/true};
-
-  auto Resolved = URI::resolve(*U, HintPath);
-  if (!Resolved)
-    return Resolved.takeError();
-  return HeaderFile{std::move(*Resolved), /*Verbatim=*/false};
-}
-
 /// A code completion result, in clang-native form.
 /// It may be promoted to a CompletionItem if it's among the top-ranked results.
 struct CompletionCandidate {
@@ -1156,24 +1134,6 @@
   return CachedReq;
 }
 
-// Returns the most popular include header for \p Sym. If two headers are
-// equally popular, prefer the shorter one. Returns empty string if \p Sym has
-// no include header.
-llvm::SmallVector<llvm::StringRef, 1> getRankedIncludes(const Symbol &Sym) {
-  auto Includes = Sym.IncludeHeaders;
-  // Sort in descending order by reference count and header length.
-  llvm::sort(Includes, [](const Symbol::IncludeHeaderWithReferences &LHS,
-                          const Symbol::IncludeHeaderWithReferences &RHS) {
-    if (LHS.References == RHS.References)
-      return LHS.IncludeHeader.size() < RHS.IncludeHeader.size();
-    return LHS.References > RHS.References;
-  });
-  llvm::SmallVector<llvm::StringRef, 1> Headers;
-  for (const auto &Include : Includes)
-    Headers.push_back(Include.IncludeHeader);
-  return Headers;
-}
-
 // Runs Sema-based (AST) and Index-based completion, returns merged results.
 //
 // There are a few tricky considerations:
@@ -1254,19 +1214,12 @@
     CodeCompleteResult Output;
     auto RecorderOwner = llvm::make_unique<CompletionRecorder>(Opts, [&]() {
       assert(Recorder && "Recorder is not set");
-      auto Style =
-          format::getStyle(format::DefaultFormatStyle, SemaCCInput.FileName,
-                           format::DefaultFallbackStyle, SemaCCInput.Contents,
-                           SemaCCInput.VFS.get());
-      if (!Style) {
-        log("getStyle() failed for file {0}: {1}. Fallback is LLVM style.",
-            SemaCCInput.FileName, Style.takeError());
-        Style = format::getLLVMStyle();
-      }
+      auto Style = getFormatStyleForFile(
+          SemaCCInput.FileName, SemaCCInput.Contents, SemaCCInput.VFS.get());
       // If preprocessor was run, inclusions from preprocessor callback should
       // already be added to Includes.
       Inserter.emplace(
-          SemaCCInput.FileName, SemaCCInput.Contents, *Style,
+          SemaCCInput.FileName, SemaCCInput.Contents, Style,
           SemaCCInput.Command.Directory,
           Recorder->CCSema->getPreprocessor().getHeaderSearchInfo());
       for (const auto &Inc : Includes.MainFileIncludes)
Index: clangd/ClangdUnit.h
===================================================================
--- clangd/ClangdUnit.h
+++ clangd/ClangdUnit.h
@@ -16,6 +16,7 @@
 #include "Headers.h"
 #include "Path.h"
 #include "Protocol.h"
+#include "index/Index.h"
 #include "clang/Frontend/FrontendAction.h"
 #include "clang/Frontend/PrecompiledPreamble.h"
 #include "clang/Lex/Preprocessor.h"
@@ -62,9 +63,19 @@
 
 /// Information required to run clang, e.g. to parse AST or do code completion.
 struct ParseInputs {
+  // By-value parameters are moved into the corresponding members.
+  ParseInputs(tooling::CompileCommand CompileCommand,
+              IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS,
+              std::string Contents, const SymbolIndex *Index)
+      : CompileCommand(std::move(CompileCommand)), FS(std::move(FS)),
+        Contents(std::move(Contents)), Index(Index) {}
+  ParseInputs() = default;
+
   tooling::CompileCommand CompileCommand;
   IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS;
   std::string Contents;
+  // Used to recover from diagnostics (e.g. find missing includes for symbol).
+  const SymbolIndex *Index = nullptr;
 };
 
 /// Stores and provides access to parsed AST.
@@ -77,7 +88,8 @@
         std::shared_ptr<const PreambleData> Preamble,
         std::unique_ptr<llvm::MemoryBuffer> Buffer,
         std::shared_ptr<PCHContainerOperations> PCHs,
-        IntrusiveRefCntPtr<llvm::vfs::FileSystem> VFS);
+        IntrusiveRefCntPtr<llvm::vfs::FileSystem> VFS,
+        const SymbolIndex *Index);
 
   ParsedAST(ParsedAST &&Other);
   ParsedAST &operator=(ParsedAST &&Other);
Index: clangd/ClangdUnit.cpp
===================================================================
--- clangd/ClangdUnit.cpp
+++ clangd/ClangdUnit.cpp
@@ -12,9 +12,12 @@
 #include "../clang-tidy/ClangTidyModuleRegistry.h"
 #include "Compiler.h"
 #include "Diagnostics.h"
+#include "Headers.h"
+#include "IncludeFixer.h"
 #include "Logger.h"
 #include "SourceCode.h"
 #include "Trace.h"
+#include "index/Index.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/Basic/LangOptions.h"
 #include "clang/Frontend/CompilerInstance.h"
@@ -31,6 +34,7 @@
 #include "clang/Serialization/ASTWriter.h"
 #include "clang/Tooling/CompilationDatabase.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/raw_ostream.h"
@@ -228,7 +232,8 @@
                  std::shared_ptr<const PreambleData> Preamble,
                  std::unique_ptr<llvm::MemoryBuffer> Buffer,
                  std::shared_ptr<PCHContainerOperations> PCHs,
-                 llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> VFS) {
+                 llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> VFS,
+                 const SymbolIndex *Index) {
   assert(CI);
   // Command-line parsing sets DisableFree to true by default, but we don't want
   // to leak memory in clangd.
@@ -237,9 +242,11 @@
       Preamble ? &Preamble->Preamble : nullptr;
 
   StoreDiags ASTDiags;
+  std::string Content = Buffer->getBuffer();
+
   auto Clang =
       prepareCompilerInstance(std::move(CI), PreamblePCH, std::move(Buffer),
-                              std::move(PCHs), std::move(VFS), ASTDiags);
+                              std::move(PCHs), VFS, ASTDiags);
   if (!Clang)
     return None;
 
@@ -286,6 +293,22 @@
     }
   }
 
+  llvm::Optional<IncludeFixer> FixIncludes;
+  auto BuildDir = VFS->getCurrentWorkingDirectory();
+  // Add IncludeFixer if Index is provided.
+  if (Index && !BuildDir.getError()) {
+    auto Style = getFormatStyleForFile(MainInput.getFile(), Content, VFS.get());
+    auto Inserter = llvm::make_unique<IncludeInserter>(
+        MainInput.getFile(), Content, Style, BuildDir.get(),
+        Clang->getPreprocessor().getHeaderSearchInfo());
+    if (Preamble) {
+      for (const auto &Inc : Preamble->Includes.MainFileIncludes)
+        Inserter->addExisting(Inc);
+    }
+    FixIncludes.emplace(MainInput.getFile(), std::move(Inserter), *Index);
+    ASTDiags.setIncludeFixer(*FixIncludes);
+  }
+
   // Copy over the includes from the preamble, then combine with the
   // non-preamble includes below.
   auto Includes = Preamble ? Preamble->Includes : IncludeStructure{};
@@ -539,7 +562,7 @@
   return ParsedAST::build(llvm::make_unique<CompilerInvocation>(*Invocation),
                           Preamble,
                           llvm::MemoryBuffer::getMemBufferCopy(Inputs.Contents),
-                          PCHs, std::move(VFS));
+                          PCHs, std::move(VFS), Inputs.Index);
 }
 
 SourceLocation getBeginningOfIdentifier(ParsedAST &Unit, const Position &Pos,
Index: clangd/ClangdServer.cpp
===================================================================
--- clangd/ClangdServer.cpp
+++ clangd/ClangdServer.cpp
@@ -153,8 +153,9 @@
   // "PreparingBuild" status to inform users, it is non-trivial given the
   // current implementation.
   WorkScheduler.update(File,
-                       ParseInputs{getCompileCommand(File),
-                                   FSProvider.getFileSystem(), Contents.str()},
+                       ParseInputs(getCompileCommand(File),
+                                   FSProvider.getFileSystem(), Contents.str(),
+                                   Index),
                        WantDiags);
 }
 
Index: clangd/CMakeLists.txt
===================================================================
--- clangd/CMakeLists.txt
+++ clangd/CMakeLists.txt
@@ -40,6 +40,7 @@
   FuzzyMatch.cpp
   GlobalCompilationDatabase.cpp
   Headers.cpp
+  IncludeFixer.cpp
   JSONTransport.cpp
   Logger.cpp
   Protocol.cpp
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to