njames93 updated this revision to Diff 318315.
njames93 added a comment.

Split up the code a little more. Fix a few malformed comments.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D94942/new/

https://reviews.llvm.org/D94942

Files:
  clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
  clang-tools-extra/clangd/refactor/tweaks/ImplementAbstract.cpp
  clang-tools-extra/clangd/unittests/CMakeLists.txt
  clang-tools-extra/clangd/unittests/tweaks/ImplementAbstractTests.cpp

Index: clang-tools-extra/clangd/unittests/tweaks/ImplementAbstractTests.cpp
===================================================================
--- /dev/null
+++ clang-tools-extra/clangd/unittests/tweaks/ImplementAbstractTests.cpp
@@ -0,0 +1,349 @@
+//===-- ImplementAbstractTests.cpp ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestTU.h"
+#include "TweakTesting.h"
+#include "gmock/gmock-matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using ::testing::Not;
+
+namespace clang {
+namespace clangd {
+namespace {
+
+// Declares the ImplementAbstractTest fixture used by the TEST_Fs below;
+// presumably binds it to the tweak with id "ImplementAbstract" (macro from
+// TweakTesting.h — confirm there).
+TWEAK_TEST(ImplementAbstract);
+
+TEST_F(ImplementAbstractTest, TestUnavailable) {
+  // Each snippet marks a class with ^ on which the tweak must NOT be offered,
+  // because every pure virtual method it inherits is already overridden (or
+  // there is none at all).
+  const llvm::StringRef UnavailableCodes[] = {
+      // Not a pure virtual method.
+      R"cpp(
+      class A {
+        virtual void Foo();
+      };
+      class ^B : public A {};
+    )cpp",
+      // Pure virtual method overridden in the class.
+      R"cpp(
+      class A {
+        virtual void Foo() = 0;
+      };
+      class ^B : public A {
+        void Foo() override;
+      };
+    )cpp",
+      // Pure virtual method overridden in the class with the virtual keyword.
+      R"cpp(
+      class A {
+        virtual void Foo() = 0;
+      };
+      class ^B : public A {
+        virtual void Foo() override;
+      };
+    )cpp",
+      // Pure virtual method overridden in the class without the override
+      // keyword.
+      R"cpp(
+      class A {
+        virtual void Foo() = 0;
+      };
+      class ^B : public A {
+        void Foo();
+      };
+    )cpp",
+      // Pure virtual method overridden in a base class.
+      R"cpp(
+      class A {
+        virtual void Foo() = 0;
+      };
+      class B : public A {
+        void Foo() override;
+      };
+      class ^C : public B {
+      };
+    )cpp"};
+  for (llvm::StringRef Code : UnavailableCodes)
+    EXPECT_THAT(Code, Not(isAvailable()));
+}
+
+TEST_F(ImplementAbstractTest, NormalAvailable) {
+  // Each case supplies the base classes (header), the source with the target
+  // class marked by ^, and the exact source expected after applying the tweak.
+  struct Case {
+    llvm::StringRef TestHeader;
+    llvm::StringRef TestSource;
+    llvm::StringRef ExpectedSource;
+  };
+
+  Case Cases[]{
+      {
+          R"cpp(
+      class A {
+        virtual void Foo() = 0;
+      };)cpp",
+          R"cpp(
+      class B : public A {^};
+    )cpp",
+          R"cpp(
+      class B : public A {
+void Foo() override;
+};
+    )cpp",
+      },
+      {
+          R"cpp(
+      class A {
+        public:
+        virtual void Foo() = 0;
+      };)cpp",
+          R"cpp(
+      class ^B : public A {};
+    )cpp",
+          R"cpp(
+      class B : public A {
+public:
+
+void Foo() override;
+};
+    )cpp",
+      },
+      {
+          R"cpp(
+      class A {
+        virtual void Foo(int Param) = 0;
+      };)cpp",
+          R"cpp(
+      class ^B : public A {};
+    )cpp",
+          R"cpp(
+      class B : public A {
+void Foo(int Param) override;
+};
+    )cpp",
+      },
+      {
+          R"cpp(
+      class A {
+        virtual void Foo(int Param) = 0;
+      };)cpp",
+          R"cpp(
+      struct ^B : public A {};
+    )cpp",
+          R"cpp(
+      struct B : public A {
+private:
+
+void Foo(int Param) override;
+};
+    )cpp",
+      },
+      {
+          R"cpp(
+      class A {
+        virtual void Foo(int Param) const volatile = 0;
+        public:
+        virtual void Bar(int Param) = 0;
+      };)cpp",
+          R"cpp(
+      class ^B : public A {
+        void Foo(int Param) const volatile override;
+      };
+    )cpp",
+          R"cpp(
+      class B : public A {
+        void Foo(int Param) const volatile override;
+      
+public:
+
+void Bar(int Param) override;
+};
+    )cpp",
+      },
+      {
+          R"cpp(
+           class A {
+        virtual void Foo() = 0;
+        virtual void Bar() = 0;
+      };
+      class B : public A {
+        void Foo() override;
+      };
+        )cpp",
+          R"cpp(
+          class ^C : public B {
+            virtual void Baz();
+          };
+        )cpp",
+          R"cpp(
+          class C : public B {
+            virtual void Baz();
+void Bar() override;
+
+          };
+        )cpp",
+      },
+      {
+          R"cpp(
+      class A {
+        virtual void Foo() = 0;
+      };)cpp",
+          R"cpp(
+      class ^B : public A {
+        ~B();
+      };
+    )cpp",
+          R"cpp(
+      class B : public A {
+void Foo() override;
+
+        ~B();
+      };
+    )cpp",
+      },
+      {
+          R"cpp(
+      class A {
+        virtual void Foo() = 0;
+        public:
+        virtual void Bar() = 0;
+      };)cpp",
+          R"cpp(
+      class ^B : public A {
+      };
+    )cpp",
+          R"cpp(
+      class B : public A {
+void Foo() override;
+
+      
+public:
+
+void Bar() override;
+};
+    )cpp",
+      },
+      {
+          R"cpp(
+      class A {
+        virtual void Foo() = 0;
+      };
+      struct B : public A {
+        virtual void Bar() = 0;
+      };)cpp",
+          R"cpp(
+      class ^C : public B {
+      };
+    )cpp",
+          R"cpp(
+      class C : public B {
+void Foo() override;
+
+      
+public:
+
+void Bar() override;
+};
+    )cpp",
+      },
+      {
+          R"cpp(
+            class A {
+              virtual void Foo() = 0;
+            };
+            struct B : public A {
+              virtual void Bar() = 0;
+            };)cpp",
+          R"cpp(
+            class ^C : private B {
+            };
+          )cpp",
+          R"cpp(
+            class C : private B {
+void Foo() override;
+void Bar() override;
+
+            };
+          )cpp",
+      },
+  };
+
+  for (const auto &Case : Cases) {
+    // Without this, a mismatch doesn't say which case produced it.
+    SCOPED_TRACE(Case.TestSource.str());
+    Header = Case.TestHeader.str();
+    EXPECT_EQ(apply(Case.TestSource), Case.ExpectedSource);
+  }
+}
+
+TEST_F(ImplementAbstractTest, TemplateUnavailable) {
+  // Classes with a dependent base: the inherited pure virtual methods can't be
+  // determined before instantiation, so the tweak must not be offered.
+  const llvm::StringRef DependentBaseCodes[] = {
+      R"cpp(
+        template<typename T>
+        class A {
+          virtual void Foo() = 0;
+        };
+        template<typename T>
+        class ^B : public A<T> {};
+        )cpp",
+      R"cpp(
+        template<typename T>
+        class ^B : public T {};
+        )cpp",
+  };
+  for (llvm::StringRef Code : DependentBaseCodes)
+    EXPECT_THAT(Code, Not(isAvailable()));
+}
+
+TEST_F(ImplementAbstractTest, TemplateAvailable) {
+  // Templates are fine as long as no base is dependent: either a concrete
+  // specialization as base, or a class template with a non-dependent base.
+  struct Case {
+    llvm::StringRef TestHeader;
+    llvm::StringRef TestSource;
+    llvm::StringRef ExpectedSource;
+  };
+  Case Cases[]{
+      {
+          R"cpp(
+            template<typename T>
+            class A {
+              virtual void Foo() = 0;
+            };
+            )cpp",
+          R"cpp(
+            class ^B : public A<int> {};
+            )cpp",
+          R"cpp(
+            class B : public A<int> {
+void Foo() override;
+};
+            )cpp",
+      },
+      {
+          R"cpp(
+            class A {
+              virtual void Foo() = 0;
+            };)cpp",
+          R"cpp(
+            template<typename T>
+            class ^B : public A {};
+            )cpp",
+          R"cpp(
+            template<typename T>
+            class B : public A {
+void Foo() override;
+};
+            )cpp",
+      },
+  };
+  for (const auto &Case : Cases) {
+    // Without this, a mismatch doesn't say which case produced it.
+    SCOPED_TRACE(Case.TestSource.str());
+    Header = Case.TestHeader.str();
+    EXPECT_EQ(apply(Case.TestSource), Case.ExpectedSource);
+  }
+}
+
+} // namespace
+} // namespace clangd
+} // namespace clang
Index: clang-tools-extra/clangd/unittests/CMakeLists.txt
===================================================================
--- clang-tools-extra/clangd/unittests/CMakeLists.txt
+++ clang-tools-extra/clangd/unittests/CMakeLists.txt
@@ -118,6 +118,7 @@
   tweaks/ExpandMacroTests.cpp
   tweaks/ExtractFunctionTests.cpp
   tweaks/ExtractVariableTests.cpp
+  tweaks/ImplementAbstractTests.cpp
   tweaks/ObjCLocalizeStringLiteralTests.cpp
   tweaks/PopulateSwitchTests.cpp
   tweaks/RawStringLiteralTests.cpp
Index: clang-tools-extra/clangd/refactor/tweaks/ImplementAbstract.cpp
===================================================================
--- /dev/null
+++ clang-tools-extra/clangd/refactor/tweaks/ImplementAbstract.cpp
@@ -0,0 +1,308 @@
+//===--- ImplementAbstract.cpp -----------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "refactor/Tweak.h"
+#include "support/Logger.h"
+#include "clang/Basic/Specifiers.h"
+#include "llvm/ADT/PointerIntPair.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace clang {
+namespace clangd {
+
+namespace {
+
+/// A pure virtual method paired with the access it should be given in the
+/// class being edited, packed into one pointer-sized value. Two bits are
+/// reserved for the access (assumes every AccessSpecifier enumerator fits in
+/// 2 bits — verify against clang/Basic/Specifiers.h).
+using MethodAndAccess =
+    llvm::PointerIntPair<const CXXMethodDecl *, 2, AccessSpecifier>;
+
+/// Returns the more restrictive of an inherited access and a declared access.
+/// Relies on AccessSpecifier enumerators being ordered from least restrictive
+/// (public) to most restrictive (private), so "larger" means "stricter".
+AccessSpecifier getMostConstrained(AccessSpecifier InheritSpecifier,
+                                   AccessSpecifier DefinedAs) {
+  if (InheritSpecifier < DefinedAs)
+    return DefinedAs;
+  return InheritSpecifier;
+}
+
+/// Recursively collects the pure virtual methods reachable from \p Record
+/// that are not overridden further down the hierarchy.
+///
+/// \p Results receives each such method paired with the access it should get
+/// in the class being edited: the most restrictive of \p Access (accumulated
+/// along the inheritance path) and the method's own access.
+/// \p Overrides holds methods known to be overridden, or already added to
+/// \p Results, so a base reached through several inheritance paths (e.g. a
+/// diamond) doesn't yield duplicate declarations.
+/// \p IsRoot is true only for the class the tweak was invoked on.
+///
+/// \returns true if collection should be abandoned (the root declares pure
+/// virtual methods itself, or a base can't be resolved to a class
+/// definition), false otherwise.
+bool collectPureVirtual(const CXXRecordDecl &Record,
+                        llvm::SmallVectorImpl<MethodAndAccess> &Results,
+                        AccessSpecifier Access,
+                        llvm::SmallPtrSetImpl<const CXXMethodDecl *> &Overrides,
+                        bool IsRoot) {
+  if (Record.getNumBases() > 0) {
+    for (const CXXMethodDecl *Method : Record.methods()) {
+      if (!Method->isVirtual())
+        continue;
+      // If we have any pure virtual methods declared in the root (the class
+      // this tweak was invoked on), assume the user probably doesn't want to
+      // implement all abstract methods as the class will still be abstract.
+      if (IsRoot && Method->isPure())
+        return true;
+      for (const CXXMethodDecl *Overridden : Method->overridden_methods())
+        Overrides.insert(Overridden);
+    }
+    for (const CXXBaseSpecifier &Base : Record.bases()) {
+      const RecordType *RT = Base.getType()->getAs<RecordType>();
+      if (!RT)
+        // Probably a dependent base, just error out.
+        return true;
+      const CXXRecordDecl *BaseDecl =
+          cast<CXXRecordDecl>(RT->getDecl())->getDefinition();
+      // A base without a definition (broken code) can't be inspected.
+      if (!BaseDecl)
+        return true;
+      if (!BaseDecl->isPolymorphic())
+        continue;
+      if (collectPureVirtual(
+              *BaseDecl, Results,
+              getMostConstrained(Access, Base.getAccessSpecifier()), Overrides,
+              /*IsRoot=*/false))
+        // Propagate any error back up.
+        return true;
+    }
+  } else {
+    assert(!IsRoot && "We should have filtered out this case already");
+  }
+  // Add the pure methods from this class after traversing the bases. This
+  // means when it comes time to create implementations, methods from classes
+  // higher up the hierarchy will appear first.
+  for (const CXXMethodDecl *Method : Record.methods()) {
+    if (!Method->isPure())
+      continue;
+    // insert() fails both for overridden methods and for methods already
+    // collected through another path to the same base; record each only once.
+    if (Overrides.insert(Method).second)
+      Results.emplace_back(Method,
+                           getMostConstrained(Access, Method->getAccess()));
+  }
+  return false;
+}
+
+/// Returns the CXXRecordDecl that is the common ancestor of the selection, or
+/// null when there is no common ancestor or it isn't a C++ record.
+const CXXRecordDecl *getSelectedRecord(const Tweak::Selection &Inputs) {
+  // FIXME: This won't return the class when the caret is in the body of the
+  // class, so the only way to get the tweak offered is to be touching the
+  // marked ranges. It would be nicer if this was offered if the cursor was
+  // inside the class (but perhaps not inside the class's member decls).
+  //  [[class]]  [[Derived]]  [[:]]  [[public]]  Base  [[{]]
+  //   ^
+  // [[}]];
+  const SelectionTree::Node *Common = Inputs.ASTSelection.commonAncestor();
+  if (Common == nullptr)
+    return nullptr;
+  return Common->ASTNode.get<CXXRecordDecl>();
+}
+
+/// Cheap structural checks performed before collecting virtual methods: the
+/// record must be a class/struct definition with at least one base, no
+/// dependent bases, and be polymorphic.
+bool isClassOK(const CXXRecordDecl &RecordDecl) {
+  // Must be a definition (checked first: the polymorphism query below needs
+  // one) of a plain class or struct.
+  const bool IsCandidate = RecordDecl.isThisDeclarationADefinition() &&
+                           (RecordDecl.isClass() || RecordDecl.isStruct());
+  if (!IsCandidate)
+    return false;
+  // Nothing to inherit without bases, and dependent bases can't be inspected.
+  if (RecordDecl.getNumBases() == 0 || RecordDecl.hasAnyDependentBases())
+    return false;
+  // Checking isAbstract() would be more precise, but that prevents working on
+  // template classes that don't have any dependent bases.
+  return RecordDecl.isPolymorphic();
+}
+
+/// A candidate location for inserting the overrides of one access level.
+struct InsertionDetail {
+  /// Where to insert; default-constructed until a spot has been chosen.
+  SourceLocation Loc = {};
+  /// Access that code inserted at Loc will have (AS_none until assigned).
+  AccessSpecifier Access = AS_none;
+  /// What kind of declaration Loc follows, used to prefer better spots:
+  /// 0 = other member (e.g. a field), 1 = access specifier, 2 = method.
+  unsigned char AfterPriority = 0;
+};
+
+// Returns the location just past \p D, including its trailing semicolon when
+// there is one. This is a little hacky because the EndLoc of a decl doesn't
+// include the semi-colon.
+SourceLocation getLocAfterDecl(const Decl &D, const SourceManager &SM,
+                               const LangOptions &LO) {
+  SourceLocation End = D.getEndLoc();
+  // Decls with a body end at the closing brace, with no semicolon to consume.
+  if (!D.hasBody()) {
+    if (auto Next = Lexer::findNextToken(End, SM, LO)) {
+      if (Next->is(tok::semi))
+        return Next->getEndLoc();
+    }
+  }
+  // End points at the start of the last token, which isn't necessarily one
+  // character long, so measure the token rather than assuming a width of 1.
+  SourceLocation AfterToken = Lexer::getLocForEndOfToken(End, 0, SM, LO);
+  return AfterToken.isValid() ? AfterToken : End.getLocWithOffset(1);
+}
+
+/// Generate insertion points in \p R that don't require inserting access
+/// specifiers. The insertion points generally try to appear after the last
+/// method declared in the class with a specific access. \p ShouldIncludeAccess
+/// is indexed by AccessSpecifier (public, protected, private) and is a way to
+/// avoid generating insertion points for access specifiers we aren't going to
+/// fill in.
+SmallVector<InsertionDetail, 3>
+getInsertionPoints(const CXXRecordDecl &R, ArrayRef<bool> ShouldIncludeAccess,
+                   const SourceManager &SM, const LangOptions &LO) {
+  SmallVector<InsertionDetail, 3> Result;
+  // Finds the entry for Spec, creating one (without a location) if needed.
+  auto GetDetailForAccess = [&](AccessSpecifier Spec) -> InsertionDetail & {
+    assert(Spec != AS_none);
+    for (InsertionDetail &Item : Result) {
+      if (Item.Access == Spec)
+        return Item;
+    }
+    return Result.emplace_back(InsertionDetail{{}, Spec});
+  };
+
+  // This whole block is designed to get an insertion point after the last
+  // method has been declared with each access specifier. Doing this ensures we
+  // keep the same visibility for implemented methods without the need to add
+  // unnecessary access specifiers.
+  for (const Decl *D : R.decls()) {
+    // Ignore things like compiler generated special member functions.
+    if (D->isImplicit())
+      continue;
+    // Some members can report AS_none (e.g. static_assert declarations, which
+    // clang exempts from access tracking). They can't anchor an insertion
+    // point, and AS_none is out of range for ShouldIncludeAccess.
+    AccessSpecifier Access = D->getAccess();
+    if (Access == AS_none || !ShouldIncludeAccess[Access])
+      continue;
+    // Hack to try and leave the destructor as last method in a block.
+    if (isa<CXXDestructorDecl>(D))
+      continue;
+    InsertionDetail &Detail = GetDetailForAccess(Access);
+    if (isa<CXXMethodDecl>(D)) {
+      // Methods always win: follow the latest one.
+      Detail.Loc = getLocAfterDecl(*D, SM, LO);
+      Detail.AfterPriority = 2;
+      continue;
+    }
+    // Try to put methods after access spec but before fields.
+    unsigned char Priority = isa<AccessSpecDecl>(D) ? 1 : 0;
+    if (Detail.AfterPriority <= Priority) {
+      Detail.Loc = getLocAfterDecl(*D, SM, LO);
+      Detail.AfterPriority = Priority;
+    }
+  }
+  if (Result.empty()) {
+    AccessSpecifier DefaultAccess = R.isClass() ? AS_private : AS_public;
+    if (ShouldIncludeAccess[DefaultAccess]) {
+      // An empty class so start inserting methods that don't need an access
+      // specifier just after the open curly brace.
+      GetDetailForAccess(DefaultAccess).Loc =
+          R.getBraceRange().getBegin().getLocWithOffset(1);
+    }
+  }
+  return Result;
+}
+
+/// Prints one override declaration for \p Method, e.g.
+/// `void Foo(int Param) const override;` followed by a newline. Everything
+/// that must match for the declaration to actually override (cv- and
+/// ref-qualifiers, varargs) or to be accepted as an override (a non-throwing
+/// exception spec) is reproduced.
+void printOverrideDecl(llvm::raw_ostream &Out, const CXXMethodDecl &Method,
+                       const PrintingPolicy &Policy) {
+  // Conversion operators carry their type in the name (`operator int`), so
+  // they must not be given a separate return type.
+  if (!isa<CXXConversionDecl>(Method)) {
+    Method.getReturnType().print(Out, Policy);
+    Out << ' ';
+  }
+  Out << Method.getNameAsString() << "(";
+  bool IsFirst = true;
+  for (const ParmVarDecl *Param : Method.parameters()) {
+    if (!IsFirst)
+      Out << ", ";
+    IsFirst = false;
+    Param->print(Out, Policy);
+  }
+  // C-style varargs are part of the signature; without them this would
+  // declare a different function rather than an override.
+  if (Method.isVariadic())
+    Out << (Method.parameters().empty() ? "..." : ", ...");
+  Out << ") ";
+  if (Method.isConst())
+    Out << "const ";
+  if (Method.isVolatile())
+    Out << "volatile ";
+  // The ref-qualifier must match for the declaration to override at all.
+  switch (Method.getRefQualifier()) {
+  case RQ_None:
+    break;
+  case RQ_LValue:
+    Out << "& ";
+    break;
+  case RQ_RValue:
+    Out << "&& ";
+    break;
+  }
+  // An override may not have a laxer exception specification than the
+  // method it overrides, so keep non-throwing methods non-throwing.
+  if (const auto *FPT = Method.getType()->getAs<FunctionProtoType>()) {
+    if (FPT->isNothrow())
+      Out << "noexcept ";
+  }
+  // Always suggest `override` over `final`.
+  Out << "override;\n";
+}
+
+/// Prints override declarations for every entry of \p Items whose access is
+/// \p AccessKind, preceded by a blank line and, if \p PrintAccessSpec, by an
+/// access specifier line. Names are qualified relative to \p PrintContext.
+void printMethods(llvm::raw_ostream &Out, ArrayRef<MethodAndAccess> Items,
+                  AccessSpecifier AccessKind, const CXXRecordDecl *PrintContext,
+                  bool PrintAccessSpec) {
+  // Suppress qualifiers for scopes that enclose the class being edited.
+  class PrintCB : public PrintingCallbacks {
+  public:
+    PrintCB(const DeclContext *CurContext) : CurContext(CurContext) {}
+    ~PrintCB() override = default;
+    bool isScopeVisible(const DeclContext *DC) const override {
+      return DC->Encloses(CurContext);
+    }
+
+  private:
+    const DeclContext *CurContext;
+  };
+  PrintCB Callbacks(PrintContext);
+  auto Policy = PrintContext->getASTContext().getPrintingPolicy();
+  Policy.SuppressScope = false;
+  Policy.Callbacks = &Callbacks;
+  if (PrintAccessSpec)
+    Out << "\n" << getAccessSpelling(AccessKind) << ":\n";
+  Out << "\n";
+  for (const MethodAndAccess &Item : Items) {
+    if (Item.getInt() == AccessKind)
+      printOverrideDecl(Out, *Item.getPointer(), Policy);
+  }
+}
+
+/// Tweak that declares overrides for every pure virtual method a class
+/// inherits but does not yet override, keeping each method's effective access.
+class ImplementAbstract : public Tweak {
+public:
+  const char *id() const override;
+
+  bool prepare(const Selection &Inputs) override {
+    Selected = getSelectedRecord(Inputs);
+    if (!Selected)
+      return false;
+    if (!isClassOK(*Selected))
+      return false;
+    llvm::SmallPtrSet<const CXXMethodDecl *, 16> Overrides;
+    if (collectPureVirtual(*Selected, PureVirtualMethods, AS_public, Overrides,
+                           /*IsRoot=*/true))
+      return false;
+    return !PureVirtualMethods.empty();
+  }
+
+  Expected<Effect> apply(const Selection &Inputs) override {
+    // We should have at least one pure virtual method to add.
+    assert(!PureVirtualMethods.empty() &&
+           "Prepare returned true when no methods existed");
+    // Indexed by AccessSpecifier (public, protected, private).
+    bool AccessNeedsProcessing[3] = {false, false, false};
+    for (const MethodAndAccess &Item : PureVirtualMethods)
+      AccessNeedsProcessing[Item.getInt()] = true;
+
+    const SourceManager &SM = Inputs.AST->getSourceManager();
+    auto InsertionPoints = getInsertionPoints(*Selected, AccessNeedsProcessing,
+                                              SM, Inputs.AST->getLangOpts());
+    // Text to insert just before the closing brace. An insertion point that
+    // lands there too (e.g. `class B : A {};`, or a last member ending right
+    // at the brace) is merged into it, because tooling::Replacements rejects
+    // two different insertions at the same offset.
+    SourceLocation ClosingBrace = Selected->getBraceRange().getEnd();
+    SmallString<256> AtClosingBrace;
+    SmallString<256> Buffer;
+    llvm::raw_svector_ostream OS(Buffer);
+    tooling::Replacements Replacements;
+    for (const InsertionDetail &Item : InsertionPoints) {
+      assert(Item.Loc.isValid());
+      if (!AccessNeedsProcessing[Item.Access])
+        continue;
+      AccessNeedsProcessing[Item.Access] = false;
+      printMethods(OS, PureVirtualMethods, Item.Access, Selected,
+                   /*PrintAccessSpec=*/false);
+      if (Item.Loc == ClosingBrace) {
+        // Must precede any access-specified sections appended below.
+        AtClosingBrace += Buffer;
+      } else if (auto Err = Replacements.add(
+                     tooling::Replacement(SM, Item.Loc, 0, Buffer))) {
+        return std::move(Err);
+      }
+      Buffer.clear();
+    }
+
+    // Any access specifiers not covered can be added in one insertion.
+    for (AccessSpecifier Spec : {AS_public, AS_protected, AS_private}) {
+      if (!AccessNeedsProcessing[Spec])
+        continue;
+      printMethods(OS, PureVirtualMethods, Spec, Selected,
+                   /*PrintAccessSpec=*/true);
+    }
+    AtClosingBrace += Buffer;
+    if (!AtClosingBrace.empty()) {
+      if (auto Err = Replacements.add(
+              tooling::Replacement(SM, ClosingBrace, 0, AtClosingBrace)))
+        return std::move(Err);
+    }
+    return Effect::mainFileEdit(SM, std::move(Replacements));
+  }
+
+  std::string title() const override {
+    return "Implement pure virtual methods";
+  }
+
+  llvm::StringLiteral kind() const override {
+    return CodeAction::REFACTOR_KIND;
+  }
+
+private:
+  /// The class the tweak was invoked on; set by prepare().
+  const CXXRecordDecl *Selected = nullptr;
+  /// Pure virtual methods to implement, base-most classes first.
+  llvm::SmallVector<MethodAndAccess, 0> PureVirtualMethods;
+};
+
+// Makes the tweak discoverable by clangd (macro from refactor/Tweak.h);
+// presumably it also supplies the id() definition declared in the class —
+// confirm in Tweak.h.
+REGISTER_TWEAK(ImplementAbstract)
+
+} // namespace
+} // namespace clangd
+} // namespace clang
Index: clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
===================================================================
--- clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
+++ clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
@@ -21,6 +21,7 @@
   ExpandMacro.cpp
   ExtractFunction.cpp
   ExtractVariable.cpp
+  ImplementAbstract.cpp
   ObjCLocalizeStringLiteral.cpp
   PopulateSwitch.cpp
   RawStringLiteral.cpp
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to