avogelsgesang created this revision.
Herald added subscribers: usaxena95, kadircet, arphaman, mgorny.
Herald added a project: All.
avogelsgesang requested review of this revision.
Herald added subscribers: cfe-commits, MaskRay, ilya-biryukov.
Herald added a project: clang-tools-extra.

This commit adds a new "add subclass" tweak which facilitates quick
scaffolding of inheritance hierarchies. The tweak can be triggered
on any class with virtual methods. It then inserts a new subclass
which overrides all virtual methods with default implementations that
delegate to the base class.

There are two variations of this tweak:

1. A variant which overrides all virtual functions
2. A variant which overrides only pure virtual functions

This tweak also supports deeper inheritance hierarchies: it collects
the methods to be overridden not only from the immediate base class but
from the complete inheritance tree.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D122102

Files:
  clang-tools-extra/clangd/AST.cpp
  clang-tools-extra/clangd/AST.h
  clang-tools-extra/clangd/refactor/InsertionPoint.cpp
  clang-tools-extra/clangd/refactor/tweaks/AddSubclass.cpp
  clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
  clang-tools-extra/clangd/unittests/CMakeLists.txt
  clang-tools-extra/clangd/unittests/tweaks/AddSubclassTests.cpp

Index: clang-tools-extra/clangd/unittests/tweaks/AddSubclassTests.cpp
===================================================================
--- /dev/null
+++ clang-tools-extra/clangd/unittests/tweaks/AddSubclassTests.cpp
@@ -0,0 +1,526 @@
+//===-- AddSubclassTests.cpp ------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Annotations.h"
+#include "TweakTesting.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace clang {
+namespace clangd {
+namespace {
+
+TWEAK_TEST(AddSubclassAllVirtuals);
+
+// Checks on which declarations the "all virtuals" variant is offered.
+TEST_F(AddSubclassAllVirtualsTest, Prepare) {
+  // Not available if there are no virtual functions
+  EXPECT_UNAVAILABLE("^struct ^Base { void foo(); };");
+  // Available on virtual functions
+  EXPECT_AVAILABLE("^struct ^Base { virtual void foo(); };");
+  // Available on pure virtual functions
+  EXPECT_AVAILABLE("^struct ^Base { virtual void foo() = 0; };");
+  // Available for inherited virtual functions
+  EXPECT_AVAILABLE(R"cpp(
+struct Base { virtual void foo() = 0; };
+^struct ^Intermediate : public Base {};
+)cpp");
+  // Available for inherited virtual functions, even if already overridden
+  EXPECT_AVAILABLE(R"cpp(
+struct Base { virtual void foo() = 0; };
+^struct ^Intermediate : public Base { void foo() override; };
+)cpp");
+}
+
+// Checks where the generated subclass is inserted, for different enclosing
+// scopes and neighbouring declarations. Each case lists the input (with `^`
+// marking selection points) and the complete expected file afterwards.
+TEST_F(AddSubclassAllVirtualsTest, ApplyInDifferenctScopes) {
+  struct {
+    llvm::StringRef TestSource;
+    llvm::StringRef ExpectedSource;
+  } Cases[]{
+      // Basic case, outside any namespace
+      {
+          R"cpp(
+struct ^Base { virtual int foo() = 0; };
+)cpp",
+          R"cpp(
+struct Base { virtual int foo() = 0; };
+
+struct BaseSub : public Base {
+  using Base::Base;
+  int foo() override { return Base::foo(); }
+};
+)cpp",
+      },
+      // Inserted between two classes
+      {
+          R"cpp(
+struct ^Base { virtual int foo() = 0; };
+struct OtherStruct {};
+)cpp",
+          R"cpp(
+struct Base { virtual int foo() = 0; };
+
+struct BaseSub : public Base {
+  using Base::Base;
+  int foo() override { return Base::foo(); }
+};
+struct OtherStruct {};
+)cpp",
+      },
+      // Inside a namespace
+      {
+          R"cpp(
+namespace NS {
+struct ^Base { virtual int foo() = 0; };
+})cpp",
+          R"cpp(
+namespace NS {
+struct Base { virtual int foo() = 0; };
+
+struct BaseSub : public Base {
+  using Base::Base;
+  int foo() override { return Base::foo(); }
+};
+})cpp",
+      },
+      // Inside an outer class
+      {
+          R"cpp(
+struct Outer {
+struct ^Base { virtual int foo() = 0; };
+};)cpp",
+          R"cpp(
+struct Outer {
+struct Base { virtual int foo() = 0; };
+
+struct BaseSub : public Base {
+  using Base::Base;
+  int foo() override { return Base::foo(); }
+};
+};)cpp",
+      },
+      // Chooses a fresh unused name
+      {
+          R"cpp(
+struct ^Base { virtual int foo() = 0; };
+struct BaseSub;
+struct BaseSub1;
+struct BaseSub2;
+struct BaseSub4;
+)cpp",
+          R"cpp(
+struct Base { virtual int foo() = 0; };
+
+struct BaseSub3 : public Base {
+  using Base::Base;
+  int foo() override { return Base::foo(); }
+};
+struct BaseSub;
+struct BaseSub1;
+struct BaseSub2;
+struct BaseSub4;
+)cpp",
+      },
+      // Does not accidentally collide with a comment on the last line
+      // of the file without a newline.
+      {
+          R"cpp(
+struct ^Base { virtual int foo() = 0; };
+// Some comment...)cpp",
+          R"cpp(
+struct Base { virtual int foo() = 0; };
+// Some comment...
+struct BaseSub : public Base {
+  using Base::Base;
+  int foo() override { return Base::foo(); }
+};
+)cpp",
+      },
+      // Is not confused by forward declarations (which might leak in,
+      // e.g., through `#include`s). Inserts the subclass directly after the
+      // class definition on which the refactoring was triggered.
+      {
+          R"cpp(
+struct Base;
+struct ^Base { virtual int foo() = 0; };
+struct Base;
+)cpp",
+          R"cpp(
+struct Base;
+struct Base { virtual int foo() = 0; };
+
+struct BaseSub : public Base {
+  using Base::Base;
+  int foo() override { return Base::foo(); }
+};
+struct Base;
+)cpp",
+      },
+  };
+  llvm::StringMap<std::string> EditedFiles;
+  for (const auto &Case : Cases) {
+    // Every selection point marked with `^` must yield the same result.
+    for (const auto &SubCase : expandCases(Case.TestSource)) {
+      EXPECT_EQ(apply(SubCase, &EditedFiles), Case.ExpectedSource);
+    }
+  }
+}
+
+// Checks the generated subclass code. The expected output of each case is the
+// (annotation-free) input followed directly by `GeneratedSubclass`.
+TEST_F(AddSubclassAllVirtualsTest, GeneratesCorrectSubclass) {
+  struct {
+    llvm::StringRef BaseClass;
+    llvm::StringRef GeneratedSubclass;
+    // Additional compiler flags, e.g. to select a language standard.
+    std::vector<std::string> ExtraArgs = {};
+  } Cases[]{
+      // Basic case; generating a `struct` inheriting from the base class
+      {
+          R"cpp(
+struct ^Base { virtual int foo() = 0; };
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+  using Base::Base;
+  int foo() override { return Base::foo(); }
+};
+)cpp"},
+      // Only overrides virtual functions; leaves other functions alone
+      {
+          R"cpp(
+struct ^Base {
+  virtual int foo() = 0;
+  int bar();
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+  using Base::Base;
+  int foo() override { return Base::foo(); }
+};
+)cpp"},
+      // Also supports overriding the virtual destructor, overloaded operators
+      // and conversion functions.
+      {
+          R"cpp(
+struct ^Base {
+  virtual ~Base() = 0;
+  virtual operator double() = 0;
+  virtual int operator[](int X) = 0;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+  using Base::Base;
+  ~BaseSub() override = default;
+  operator double() override { return Base::operator double(); }
+  int operator[](int X) override { return Base::operator[](X); }
+};
+)cpp"},
+      // Function attributes like `const`, `noexcept` etc. are copied
+      {
+          R"cpp(
+struct ^Base {
+  consteval virtual int foo() noexcept;
+  constexpr virtual operator double() const;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+  using Base::Base;
+  consteval int foo() noexcept override { return Base::foo(); }
+  constexpr operator double() const override { return Base::operator double(); }
+};
+)cpp",
+          {"-std=c++20"}},
+      // Uses `class` instead of struct if the base class was also a `class`
+      {
+          R"cpp(
+class ^Base {
+public:
+  virtual int foo() = 0;
+};
+)cpp",
+          R"cpp(
+class BaseSub : public Base {
+  using Base::Base;
+public:
+  int foo() override { return Base::foo(); }
+};
+)cpp"},
+      // Default implementation does not contain a `return` for void functions
+      {
+          R"cpp(
+struct ^Base {
+  virtual void foo() = 0;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+  using Base::Base;
+  void foo() override { Base::foo(); }
+};
+)cpp"},
+      // No default implementation for private functions.
+      // We can't call the private implementation of the base class.
+      {
+          R"cpp(
+struct ^Base {
+private:
+  virtual void foo() = 0;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+  using Base::Base;
+private:
+  void foo() override;
+};
+)cpp"},
+      // Default implementation forwards parameters
+      {
+          R"cpp(
+struct Moveable {
+  Moveable() = default;
+  Moveable(Moveable&&) = default;
+};
+
+struct ^Base {
+  virtual void foo(int a) = 0;
+  virtual void bar(int a, double b) = 0;
+  virtual void baz(int, double b, char) = 0;
+  virtual void foobar(int&& x) = 0;
+  virtual void foobaz(Moveable x) = 0;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+  using Base::Base;
+  void foo(int a) override { Base::foo(a); }
+  void bar(int a, double b) override { Base::bar(a, b); }
+  void baz(int _1, double b, char _3) override { Base::baz(_1, b, _3); }
+  void foobar(int &&x) override { Base::foobar(std::move(x)); }
+  void foobaz(Moveable x) override { Base::foobaz(std::move(x)); }
+};
+)cpp"},
+      // Can expand multiple overloaded functions
+      {
+          R"cpp(
+struct ^Base {
+  virtual void foo(int a) = 0;
+  virtual void foo(double b) = 0;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+  using Base::Base;
+  void foo(int a) override { Base::foo(a); }
+  void foo(double b) override { Base::foo(b); }
+};
+)cpp"},
+      // Collects virtual functions from *all* base classes
+      {
+          R"cpp(
+struct Base1 {
+  virtual void foo() = 0;
+};
+struct Intermediate : public Base1 {
+  virtual void bar() = 0;
+};
+struct Base2 {
+  virtual void baz() = 0;
+};
+struct ^Base : public Intermediate, Base2 {
+  virtual void foobar() = 0;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+  using Base::Base;
+  void foo() override { Base::foo(); }
+  void bar() override { Base::bar(); }
+  void baz() override { Base::baz(); }
+  void foobar() override { Base::foobar(); }
+};
+)cpp",
+      },
+      // Correctly propagates visibility
+      {
+          R"cpp(
+struct ^Base {
+  virtual int publicFoo() = 0;
+private:
+  virtual int privateFoo() = 0;
+protected:
+  virtual int protectedFoo() = 0;
+public:
+  virtual int publicBar() = 0;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+  using Base::Base;
+  int publicFoo() override { return Base::publicFoo(); }
+private:
+  int privateFoo() override;
+protected:
+  int protectedFoo() override { return Base::protectedFoo(); }
+public:
+  int publicBar() override { return Base::publicBar(); }
+};
+)cpp",
+      },
+      // Keeps the structuring into multiple `public`/`protected`/`private`
+      // blocks from the base class
+      {
+          R"cpp(
+class ^Base {
+public:
+  virtual int publicFoo() = 0;
+  virtual int publicBar() = 0;
+protected:
+  virtual int protectedFoo() = 0;
+public:
+  virtual int publicBaz() = 0;
+};
+)cpp",
+          R"cpp(
+class BaseSub : public Base {
+  using Base::Base;
+public:
+  int publicFoo() override { return Base::publicFoo(); }
+  int publicBar() override { return Base::publicBar(); }
+protected:
+  int protectedFoo() override { return Base::protectedFoo(); }
+public:
+  int publicBaz() override { return Base::publicBaz(); }
+};
+)cpp",
+      },
+      // Correctly propagates visibility also for non-public inheritance
+      {
+          R"cpp(
+struct Base1 {
+  virtual void foo() = 0;
+private:
+  virtual void privateFoo() = 0;
+};
+struct Intermediate : protected Base1 {
+  virtual void bar() = 0;
+};
+struct Base2 {
+  virtual void baz() = 0;
+};
+struct ^Base : public Intermediate, private Base2 {
+  virtual void foobar() = 0;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+  using Base::Base;
+protected:
+  void foo() override { Base::foo(); }
+private:
+  void privateFoo() override;
+public:
+  void bar() override { Base::bar(); }
+private:
+  void baz() override;
+public:
+  void foobar() override { Base::foobar(); }
+};
+)cpp",
+      },
+      // Copies comments
+      {
+          R"cpp(
+struct ^Base {
+  // Some comment
+  virtual void foo() = 0;
+
+  // A method with a brief description
+  //
+  // And a longer description
+  virtual void bar() = 0;
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+  using Base::Base;
+  // Some comment
+  void foo() override { Base::foo(); }
+  // A method with a brief description
+  void bar() override { Base::bar(); }
+};
+)cpp",
+      },
+  };
+  llvm::StringMap<std::string> EditedFiles;
+  for (const auto &Case : Cases) {
+    ExtraArgs = Case.ExtraArgs;
+    // `Code.code()` is the input with the `^` markers stripped.
+    Annotations Code(Case.BaseClass);
+    for (const auto &SubCase : expandCases(Case.BaseClass)) {
+      EXPECT_EQ(apply(SubCase, &EditedFiles),
+                (Code.code() + Case.GeneratedSubclass).str());
+    }
+  }
+}
+
+// We do not test everything again, but only test the difference in behavior
+TWEAK_TEST(AddSubclassPureVirtualOnly);
+
+TEST_F(AddSubclassPureVirtualOnlyTest,
+       OnlyOverridesNonImplementedVirtualFunctions) {
+  struct {
+    llvm::StringRef BaseClass;
+    llvm::StringRef GeneratedSubclass;
+  } Cases[]{
+      // Only overrides pure virtual functions; leaves other virtual functions
+      // alone
+      {
+          R"cpp(
+struct ^Base {
+  virtual int foo() = 0;
+  virtual int bar();
+};
+)cpp",
+          R"cpp(
+struct BaseSub : public Base {
+  using Base::Base;
+  int foo() override { return Base::foo(); }
+};
+)cpp"},
+      // Does not override pure virtual functions if they were already
+      // implemented by an intermediate class.
+      {
+          R"cpp(
+struct Base {
+  virtual int foo() = 0;
+  virtual int bar() = 0;
+};
+struct ^Intermediate : public Base {
+  int foo() override;
+};
+)cpp",
+          R"cpp(
+struct IntermediateSub : public Intermediate {
+  using Intermediate::Intermediate;
+  int bar() override { return Intermediate::bar(); }
+};
+)cpp"},
+  };
+  llvm::StringMap<std::string> EditedFiles;
+  for (const auto &Case : Cases) {
+    // Expected output: the input without `^` markers, followed by the subclass.
+    Annotations Code(Case.BaseClass);
+    for (const auto &SubCase : expandCases(Case.BaseClass)) {
+      EXPECT_EQ(apply(SubCase, &EditedFiles),
+                (Code.code() + Case.GeneratedSubclass).str());
+    }
+  }
+}
+
+} // namespace
+} // namespace clangd
+} // namespace clang
Index: clang-tools-extra/clangd/unittests/CMakeLists.txt
===================================================================
--- clang-tools-extra/clangd/unittests/CMakeLists.txt
+++ clang-tools-extra/clangd/unittests/CMakeLists.txt
@@ -106,6 +106,7 @@
   support/TestTracer.cpp
   support/TraceTests.cpp
 
+  tweaks/AddSubclassTests.cpp
   tweaks/AddUsingTests.cpp
   tweaks/AnnotateHighlightingsTests.cpp
   tweaks/DefineInlineTests.cpp
Index: clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
===================================================================
--- clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
+++ clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt
@@ -12,6 +12,7 @@
 # $<TARGET_OBJECTS:obj.clangDaemonTweaks> to a list of sources, see
 # clangd/tool/CMakeLists.txt for an example.
 add_clang_library(clangDaemonTweaks OBJECT
+  AddSubclass.cpp
   AddUsing.cpp
   AnnotateHighlightings.cpp
   DumpAST.cpp
Index: clang-tools-extra/clangd/refactor/tweaks/AddSubclass.cpp
===================================================================
--- /dev/null
+++ clang-tools-extra/clangd/refactor/tweaks/AddSubclass.cpp
@@ -0,0 +1,422 @@
+//===--- AddSubclass.cpp -----------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "refactor/Tweak.h"
+
+#include "AST.h"
+#include "XRefs.h"
+#include "refactor/InsertionPoint.h"
+#include "support/Logger.h"
+#include "clang/AST/Type.h"
+#include "clang/AST/TypeLoc.h"
+#include "clang/Basic/LLVM.h"
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <climits>
+#include <memory>
+#include <string>
+
+namespace clang {
+namespace clangd {
+namespace {
+
+/// Virtual methods to be overridden by the generated subclass, each paired
+/// with the access specifier under which its override is declared.
+using MethodVector =
+    std::vector<std::pair<AccessSpecifier, const CXXMethodDecl *>>;
+
+/// Adds a subclass implementing all virtual functions
+/// Given:
+///    struct Base {
+///      Base(int x, std::string y);
+///      /// Do something
+///      virtual int foo();
+///      /// Do something else
+///      virtual double bar(double x) = 0;
+///    };
+/// Adds a new class:
+///    struct BaseSub : public Base {
+///      using Base::Base;
+///      // Do something
+///      int foo() override { return Base::foo(); }
+///      // Do something else
+///      double bar(double x) override { return Base::bar(x); }
+///    };
+/// The new class is added directly below its base class in the same file. In
+/// the future, a separate tweak could be used to move this class somewhere
+/// else. The subclass is always called `BaseSub`, potentially disambiguated
+/// with a numeric suffix to avoid name clashes. The "Rename" functionality can
+/// be used to rename the new class.
+///
+/// The generated methods delegate to the respective implementations of the
+/// parent class by default (except for private methods, whose base
+/// implementation can't be called; those are only declared). If not
+/// appropriate, the corresponding code can be deleted by the user explicitly.
+/// We provide this default implementation to guard the user against
+/// accidentally forgetting to call the base class' method. It's easier to spot
+/// incorrect code than incorrectly missing code. The method bodies are defined
+/// inside the class definition; the "Define Outline" tweak can be used to move
+/// them into a `.cpp` file.
+class AddSubclass : public Tweak {
+public:
+  /// \p UnimplementedPureVirtualOnly selects whether only pure virtual methods
+  /// (which are not implemented further down the hierarchy) are overridden,
+  /// instead of all virtual methods.
+  explicit AddSubclass(bool UnimplementedPureVirtualOnly)
+      : UnimplementedPureVirtualOnly(UnimplementedPureVirtualOnly) {}
+
+  llvm::StringLiteral kind() const override {
+    return CodeAction::REFACTOR_KIND;
+  }
+  bool prepare(const Selection &Inputs) override;
+  Expected<Effect> apply(const Selection &Inputs) override;
+
+private:
+  /// Whether only not-yet-implemented pure virtual methods are overridden.
+  bool UnimplementedPureVirtualOnly;
+  /// Cache the CXXMethodDecls, so that we do not need to search twice.
+  MethodVector CachedMethods;
+  /// Cache the CXXRecordDecl, so that we do not need to search twice.
+  llvm::Optional<const clang::CXXRecordDecl *> CachedLocation;
+};
+
+/// Variation of the `AddSubclass` tweak which overrides all
+/// virtual methods, including those already overridden by an intermediate
+/// class.
+class AddSubclassAllVirtuals : public AddSubclass {
+public:
+  AddSubclassAllVirtuals() : AddSubclass(false) {}
+  const char *id() const final;
+  std::string title() const override {
+    return "Add subclass, overriding all virtual methods";
+  }
+};
+
+/// Variation of the `AddSubclass` tweak which overrides only
+/// pure virtual methods that are not yet implemented by any class along the
+/// inheritance chain.
+class AddSubclassPureVirtualOnly : public AddSubclass {
+public:
+  AddSubclassPureVirtualOnly() : AddSubclass(true) {}
+  const char *id() const final;
+  std::string title() const override {
+    return "Add subclass, overriding unimplemented pure virtual methods";
+  }
+};
+
+// Register both variations as separate tweaks.
+// NOTE(review): `id()` is only declared above; presumably REGISTER_TWEAK
+// provides its definition — confirm against refactor/Tweak.h.
+REGISTER_TWEAK(AddSubclassAllVirtuals)
+REGISTER_TWEAK(AddSubclassPureVirtualOnly)
+
+// FIXME: copied from `clang-doc`. Can I somehow deduplicate this?
+/// Combines two access specifiers into the most restrictive one of them, e.g.
+/// the access of a member (\p SecondAS) reached through an inheritance path
+/// with access \p FirstAS. Precedence, from strongest to weakest:
+/// `AS_none`, `private`, `protected`, `public`.
+static AccessSpecifier getFinalAccessSpecifier(AccessSpecifier FirstAS,
+                                               AccessSpecifier SecondAS) {
+  for (AccessSpecifier Dominant :
+       {AccessSpecifier::AS_none, AccessSpecifier::AS_private,
+        AccessSpecifier::AS_protected}) {
+    if (FirstAS == Dominant || SecondAS == Dominant)
+      return Dominant;
+  }
+  return AccessSpecifier::AS_public;
+}
+
+/// Recursively collects the virtual methods of \p Decl and of all its
+/// (transitive) base classes into \p Target. Base classes are visited first,
+/// so methods appear in the order in which they are declared along the
+/// inheritance hierarchy.
+///
+/// Each method is paired with the most restrictive of its own access and the
+/// access of the inheritance path it is reached through (\p InheritanceAS).
+///
+/// Without \p UnimplementedPureVirtualOnly, the initial declaration of every
+/// virtual method is collected. With it, only pure virtual methods are
+/// collected, and methods which are overridden (i.e. implemented) further down
+/// the hierarchy are removed again.
+static void collectRelevantMethodDecls(const CXXRecordDecl &Decl,
+                                       MethodVector &Target,
+                                       AccessSpecifier InheritanceAS,
+                                       bool UnimplementedPureVirtualOnly) {
+  // Adds `M` unless it was already collected. The same method can be reached
+  // through several paths (e.g. in a diamond-shaped hierarchy); overriding it
+  // twice in the subclass would be a redeclaration error.
+  auto AddMethod = [&Target, InheritanceAS](const CXXMethodDecl *M) {
+    if (llvm::any_of(Target, [&](const auto &E) { return E.second == M; }))
+      return;
+    Target.emplace_back(getFinalAccessSpecifier(InheritanceAS, M->getAccess()),
+                        M);
+  };
+
+  // Collect all virtual methods of all base classes
+  for (const CXXBaseSpecifier &Base : Decl.bases()) {
+    auto BaseType = Base.getType().getCanonicalType();
+    auto BaseAS = Base.getAccessSpecifier();
+    if (BaseAS == AccessSpecifier::AS_none)
+      BaseAS = AccessSpecifier::AS_private;
+    BaseAS = getFinalAccessSpecifier(BaseAS, InheritanceAS);
+    // We simply ignore all base classes which are not (defined) C++ classes.
+    // Base classes might, e.g., also be template type arguments instead.
+    // Recursing requires a definition, as `bases()` queries definition data.
+    auto *BaseRecordDecl = BaseType.getTypePtr()->getAsCXXRecordDecl();
+    if (BaseRecordDecl && BaseRecordDecl->hasDefinition())
+      collectRelevantMethodDecls(*BaseRecordDecl->getDefinition(), Target,
+                                 BaseAS, UnimplementedPureVirtualOnly);
+  }
+
+  // Collect the relevant methods of this class
+  for (const CXXMethodDecl *M : Decl.methods()) {
+    if (UnimplementedPureVirtualOnly) {
+      // A method overriding a previously collected pure virtual method
+      // implements it, so the new subclass does not need to override it...
+      for (const CXXMethodDecl *Overridden : M->overridden_methods())
+        llvm::erase_if(Target, [&](const auto &E) {
+          return E.second == Overridden;
+        });
+      // ...unless the overrider is itself (re-)declared pure, as in
+      // `void foo() override = 0;`. Such methods still need an override.
+      if (M->isPure())
+        AddMethod(M);
+    } else if (M->isVirtual() && M->size_overridden_methods() == 0) {
+      // Only collect actual base functions. Functions which override other
+      // functions were already collected when visiting that base class.
+      AddMethod(M);
+    }
+  }
+}
+
+/// Offers the tweak on the definition of a class with at least one method to
+/// override, and caches that class and its methods for `apply()`.
+bool AddSubclass::prepare(const Selection &Inputs) {
+  // This tweak assumes move semantics.
+  if (!Inputs.AST->getLangOpts().CPlusPlus11)
+    return false;
+
+  CachedLocation = llvm::None;
+  CachedMethods.clear();
+  const auto *Node = Inputs.ASTSelection.commonAncestor();
+  if (!Node)
+    return false;
+  const auto *Class = Node->ASTNode.get<CXXRecordDecl>();
+  if (!Class)
+    return false;
+  // Only offer the tweak on the definition of a class which can be derived
+  // from:
+  //  * a forward declaration has no definition data to inspect, and a subclass
+  //    inserted below it would derive from an incomplete type;
+  //  * anonymous classes (including lambdas) can't be named as a base class;
+  //  * `final` classes (or classes with a `final` destructor) can't be
+  //    derived from at all.
+  if (!Class->isThisDeclarationADefinition() || !Class->getIdentifier() ||
+      Class->isEffectivelyFinal())
+    return false;
+
+  collectRelevantMethodDecls(*Class, CachedMethods, AccessSpecifier::AS_public,
+                             UnimplementedPureVirtualOnly);
+  if (CachedMethods.empty())
+    return false;
+  CachedLocation = Class;
+  return true;
+}
+
+/// Returns \p Name itself, or \p Name followed by the smallest positive
+/// number, whichever comes first and is not yet declared directly in \p DC.
+std::string getNewIdentifier(std::string Name, const ASTContext &AC,
+                             const DeclContext &DC) {
+  auto &Idents = AC.Idents;
+  for (unsigned Suffix = 0;; ++Suffix) {
+    std::string Candidate =
+        Suffix == 0 ? Name : Name + std::to_string(Suffix);
+    auto Found = Idents.find(Candidate);
+    // A spelling that was never seen as an identifier can't be declared.
+    bool IsFree = Found == Idents.end() ||
+                  DC.lookup(DeclarationName{Found->getValue()}).empty();
+    if (IsFree)
+      return Candidate;
+  }
+}
+
+/// Describes how one parameter of an overridden method is declared in the
+/// override and passed on to the base class implementation.
+struct ForwardingParamInfo {
+  /// Declared type of the parameter (its original, unadjusted type).
+  QualType Type;
+  /// Name of the parameter; synthesized for unnamed parameters.
+  std::string Name;
+  /// Whether the argument is wrapped in `std::move` when forwarded.
+  bool Move;
+};
+
+/// Conservatively decides whether objects of class \p C can be moved from.
+/// Without doing Sema work we can't always tell whether \p C is movable, so
+/// moving is assumed possible unless it is provably not.
+static bool canMoveRecordDecl(const CXXRecordDecl &C) {
+  if (!C.hasUserDeclaredMoveConstructor())
+    return C.needsOverloadResolutionForMoveConstructor() ||
+           !C.defaultedMoveConstructorIsDeleted();
+  // A user-declared move constructor only prevents moving if it is deleted.
+  for (const CXXConstructorDecl *Ctor : C.ctors())
+    if (Ctor->isMoveConstructor() && Ctor->isDeleted())
+      return false;
+  return true;
+}
+
+/// Whether a forwarded by-value argument of type \p T gets wrapped in
+/// `std::move`: only for movable class types.
+static bool shouldMoveType(const Type *T) {
+  const auto *Record = T->getAsCXXRecordDecl();
+  return Record != nullptr && canMoveRecordDecl(*Record);
+}
+
+/// Determines, for each parameter of \p Func, the type to declare it with in
+/// the override, the name to refer to it by and whether it must be forwarded
+/// via `std::move`. Unnamed parameters get the name `_<N>`, where N is the
+/// 1-based position of the parameter.
+static std::vector<ForwardingParamInfo>
+prepareForwardedArgs(const FunctionDecl &Func) {
+  std::vector<ForwardingParamInfo> Result;
+  Result.reserve(Func.param_size());
+  for (unsigned Idx = 0, End = Func.param_size(); Idx != End; ++Idx) {
+    const ParmVarDecl *P = Func.getParamDecl(Idx);
+    std::string ParamName = P->getNameAsString();
+    if (ParamName.empty())
+      ParamName = llvm::formatv("_{0}", Idx + 1);
+    QualType T = P->getOriginalType();
+    // Rvalue references have to be moved to be forwarded at all; movable
+    // by-value class parameters are moved to avoid a copy.
+    bool NeedsMove = T.getLocalUnqualifiedType()->isRValueReferenceType() ||
+                     shouldMoveType(T.getCanonicalType().getTypePtr());
+    Result.push_back({T, ParamName, NeedsMove});
+  }
+  return Result;
+}
+
+/// Prints the parts of an override's declarator which follow its parameter
+/// list and have to match \p Method: cv-qualifiers, ref-qualifier and a
+/// non-throwing exception specification.
+static void printTrailingQualifiers(llvm::raw_ostream &OS,
+                                    const CXXMethodDecl &Method) {
+  if (Method.isConst())
+    OS << " const";
+  if (Method.isVolatile())
+    OS << " volatile";
+  // Without the matching ref-qualifier, the new method would not override the
+  // base class method, and its `override` specifier would be rejected.
+  switch (Method.getRefQualifier()) {
+  case RQ_None:
+    break;
+  case RQ_LValue:
+    OS << " &";
+    break;
+  case RQ_RValue:
+    OS << " &&";
+    break;
+  }
+  // An overrider may not have a laxer exception specification than the method
+  // it overrides, so every non-throwing spelling is preserved. Implicit (e.g.
+  // not yet evaluated) exception specifications are not spelled out.
+  switch (Method.getExceptionSpecType()) {
+  case EST_BasicNoexcept: // noexcept
+  case EST_NoexceptTrue:  // noexcept(<constant true>)
+  case EST_DynamicNone:   // throw()
+    OS << " noexcept";
+    break;
+  default:
+    break;
+  }
+}
+
+/// Generates the code of a new subclass of \p BaseClass which overrides all
+/// \p Methods, each declared with the access specifier it is paired with.
+///
+/// The subclass gets a fresh name (`<Base>Sub`, possibly with a numeric
+/// suffix), uses the same class-key as \p BaseClass and inherits its
+/// constructors. Destructors are defaulted. Methods whose base implementation
+/// can't be called (private ones) or whose arguments can't be forwarded
+/// (C-style variadic ones) are only declared. All other methods get a body
+/// delegating to the base class implementation.
+/// Returns an error for kinds of methods which can't be overridden.
+static llvm::Expected<std::string>
+formatSubclassCode(const CXXRecordDecl &BaseClass,
+                   const MethodVector &Methods) {
+  auto &DC = *BaseClass.getParent();
+  auto &AC = BaseClass.getASTContext();
+
+  std::string S;
+  llvm::raw_string_ostream OS(S);
+  OS << "\n";
+
+  // Use the same keyword (struct or class) as the base class
+  AccessSpecifier CurrentAS;
+  if (BaseClass.isClass()) {
+    OS << "class ";
+    // We want a `private:` section header even if the first functions are
+    // private. Hence, don't set `CurrentAS` to `private` but to `none`.
+    CurrentAS = AccessSpecifier::AS_none;
+  } else {
+    OS << "struct ";
+    CurrentAS = AccessSpecifier::AS_public;
+  }
+
+  // Find a class name which does not conflict with existing names
+  std::string SubclassName =
+      getNewIdentifier(BaseClass.getNameAsString() + "Sub", AC, DC);
+  OS << SubclassName;
+
+  // Inherit from the base class
+  OS << " : public " << BaseClass.getName() << " {\n";
+
+  // We always inherit the constructors
+  OS << "  using " << BaseClass.getName() << "::" << BaseClass.getName()
+     << ";\n";
+
+  // Add the methods
+  for (const auto &M : Methods) {
+    AccessSpecifier MethodAS = M.first;
+    const CXXMethodDecl *Method = M.second;
+
+    if (MethodAS != CurrentAS) {
+      OS << getAccessSpelling(MethodAS) << ":\n";
+      CurrentAS = MethodAS;
+    }
+
+    // Copy over the brief comment from the base class. Comments without a
+    // brief text are skipped rather than producing an empty `//` line.
+    const auto *Comment = AC.getRawCommentForAnyRedecl(Method);
+    if (Comment && !Comment->isTrailingComment()) {
+      llvm::StringRef Brief = Comment->getBriefText(AC);
+      if (!Brief.empty())
+        OS << "  // " << Brief << "\n";
+    }
+
+    OS << "  ";
+    if (Method->isConstexprSpecified())
+      OS << "constexpr ";
+    if (Method->isConsteval())
+      OS << "consteval ";
+
+    auto DeclName = Method->getDeclName();
+    bool PrintReturnType = true;
+    std::string MethodName;
+    switch (DeclName.getNameKind()) {
+    case DeclarationName::Identifier:
+      MethodName = DeclName.getAsString();
+      break;
+    case DeclarationName::CXXDestructorName:
+      PrintReturnType = false;
+      MethodName = std::string{"~"} + SubclassName;
+      break;
+    case DeclarationName::CXXConversionFunctionName:
+      PrintReturnType = false;
+      MethodName = std::string{"operator "} +
+                   printType(Method->getReturnType(), BaseClass);
+      break;
+    case DeclarationName::CXXOperatorName:
+      MethodName = std::string{"operator"} +
+                   getOperatorSpelling(DeclName.getCXXOverloadedOperator());
+      break;
+    case DeclarationName::CXXConstructorName:
+    case DeclarationName::ObjCZeroArgSelector:
+    case DeclarationName::ObjCOneArgSelector:
+    case DeclarationName::CXXDeductionGuideName:
+    case DeclarationName::CXXLiteralOperatorName:
+    case DeclarationName::CXXUsingDirective:
+    case DeclarationName::ObjCMultiArgSelector:
+      // Print the name kind as a plain number.
+      return error("Unsupported method type `{0}`",
+                   static_cast<unsigned>(DeclName.getNameKind()));
+    }
+    if (PrintReturnType) {
+      OS << printType(Method->getReturnType(), BaseClass);
+      OS << " ";
+    }
+
+    OS << MethodName;
+
+    // Print the argument list
+    auto ForwardedArgs = prepareForwardedArgs(*Method);
+    OS << "(";
+    const char *Sep = "";
+    for (const auto &Arg : ForwardedArgs) {
+      OS << Sep << printType(Arg.Type, BaseClass, /*Placeholder=*/Arg.Name);
+      Sep = ", ";
+    }
+    // Keep C-style varargs; without them the signature would not match.
+    if (Method->isVariadic())
+      OS << Sep << "...";
+    OS << ")";
+
+    printTrailingQualifiers(OS, *Method);
+    OS << " override";
+
+    if (MethodAS == AccessSpecifier::AS_private || Method->isVariadic()) {
+      // No default implementation: a private base method can't be called,
+      // and C-style variadic arguments can't be forwarded.
+      OS << ";\n";
+    } else if (DeclName.getNameKind() == DeclarationName::CXXDestructorName) {
+      OS << " = default;\n";
+    } else {
+      // The default implementation simply delegates to the base class
+      OS << " { ";
+      if (!Method->getReturnType()->isVoidType())
+        OS << "return ";
+      // Within a `&&`-qualified method `*this` is an lvalue, so it has to be
+      // cast to an xvalue to be able to call the `&&`-qualified base method.
+      if (Method->getRefQualifier() == RQ_RValue)
+        OS << "std::move(*this).";
+      OS << BaseClass.getName() << "::" << MethodName << "(";
+      Sep = "";
+      for (const auto &Arg : ForwardedArgs) {
+        OS << Sep;
+        if (Arg.Move)
+          OS << "std::move(" << Arg.Name << ")";
+        else
+          OS << Arg.Name;
+        Sep = ", ";
+      }
+      OS << "); }\n";
+    }
+  }
+
+  OS << "};\n";
+  OS.flush();
+  return S;
+}
+
+/// Inserts the subclass generated for the class cached by `prepare()`
+/// directly below that class's definition.
+Expected<Tweak::Effect> AddSubclass::apply(const Selection &Inputs) {
+  auto *Class = *CachedLocation;
+  auto &SM = Inputs.AST->getSourceManager();
+
+  auto SubclassCode = formatSubclassCode(*Class, CachedMethods);
+  if (!SubclassCode)
+    return SubclassCode.takeError();
+
+  tooling::Replacements Replacements;
+  auto Insertion = insertDecl(
+      *SubclassCode, *Class->getLexicalParent(),
+      {Anchor{[&](const Decl *D) { return D == Class; }, Anchor::Below}});
+  if (!Insertion)
+    return Insertion.takeError();
+  // `llvm::Error` is move-only and `Expected` takes it by value, so the error
+  // has to be moved explicitly; the implicit move on `return` is not
+  // guaranteed for a converting constructor before C++20.
+  if (auto AddError = Replacements.add(std::move(*Insertion)))
+    return std::move(AddError);
+  return Effect::mainFileEdit(SM, std::move(Replacements));
+}
+
+} // namespace
+} // namespace clangd
+} // namespace clang
Index: clang-tools-extra/clangd/refactor/InsertionPoint.cpp
===================================================================
--- clang-tools-extra/clangd/refactor/InsertionPoint.cpp
+++ clang-tools-extra/clangd/refactor/InsertionPoint.cpp
@@ -119,8 +119,11 @@
   // Fallback: insert at the end.
   if (Loc.isInvalid())
     Loc = endLoc(DC);
+  if (Loc.isInvalid())
+    return error("Couldn't find a valid location for insertion");
   const auto &SM = DC.getParentASTContext().getSourceManager();
-  if (!SM.isWrittenInSameFile(Loc, cast<Decl>(DC).getLocation()))
+  auto DeclLoc = cast<Decl>(DC).getLocation();
+  if (DeclLoc.isValid() && !SM.isWrittenInSameFile(Loc, DeclLoc))
     return error("{0} body in wrong file: {1}", DC.getDeclKindName(),
                  Loc.printToString(SM));
   return tooling::Replacement(SM, Loc, 0, Code);
Index: clang-tools-extra/clangd/AST.h
===================================================================
--- clang-tools-extra/clangd/AST.h
+++ clang-tools-extra/clangd/AST.h
@@ -94,7 +94,8 @@
 
 /// Returns a QualType as string. The result doesn't contain unwritten scopes
 /// like anonymous/inline namespace.
-std::string printType(const QualType QT, const DeclContext &CurContext);
+std::string printType(const QualType QT, const DeclContext &CurContext,
+                      llvm::StringRef Placeholder = "");
 
 /// Indicates if \p D is a template instantiation implicitly generated by the
 /// compiler, e.g.
Index: clang-tools-extra/clangd/AST.cpp
===================================================================
--- clang-tools-extra/clangd/AST.cpp
+++ clang-tools-extra/clangd/AST.cpp
@@ -349,7 +349,8 @@
   return SymbolID(USR);
 }
 
-std::string printType(const QualType QT, const DeclContext &CurContext) {
+std::string printType(const QualType QT, const DeclContext &CurContext,
+                      const llvm::StringRef Placeholder) {
   std::string Result;
   llvm::raw_string_ostream OS(Result);
   PrintingPolicy PP(CurContext.getParentASTContext().getPrintingPolicy());
@@ -370,7 +371,7 @@
   PrintCB PCB(&CurContext);
   PP.Callbacks = &PCB;
 
-  QT.print(OS, PP);
+  QT.print(OS, PP, Placeholder);
   return OS.str();
 }
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to