KKoovalsky updated this revision to Diff 491331.
KKoovalsky added a comment.

1. Fixed a lit test failure: `ignoreImplicit()` was called on `CommonAncestor`
while it was `nullptr`, dereferencing a null pointer. `findExtractionZone()`
now returns early when `CommonAncestor` is `nullptr`.
2. Fixed a link error: the clangd tool was not linked against `clangASTMatchers`.


CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D140619/new/

https://reviews.llvm.org/D140619

Files:
  clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
  clang-tools-extra/clangd/tool/CMakeLists.txt
  clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp

Index: clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
===================================================================
--- clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
+++ clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
@@ -24,8 +24,6 @@
 
   // Root statements should have common parent.
   EXPECT_EQ(apply("for(;;) [[1+2; 1+2;]]"), "unavailable");
-  // Expressions aren't extracted.
-  EXPECT_EQ(apply("int x = 0; [[x++;]]"), "unavailable");
   // We don't support extraction from lambdas.
   EXPECT_EQ(apply("auto lam = [](){ [[int x;]] }; "), "unavailable");
   // Partial statements aren't extracted.
@@ -190,6 +188,16 @@
     }]]
   )cpp";
   EXPECT_EQ(apply(CompoundFailInput), "unavailable");
+
+  std::string CompoundWithMultipleStatementsFailInput = R"cpp(
+    void f() [[{
+      int a = 1;
+      int b = 2;
+      ++b;
+      b += a;
+    }]]
+  )cpp";
+  EXPECT_EQ(apply(CompoundWithMultipleStatementsFailInput), "unavailable");
 }
 
 TEST_F(ExtractFunctionTest, DifferentHeaderSourceTest) {
@@ -571,6 +579,795 @@
   EXPECT_EQ(apply(Before), After);
 }
 
+TEST_F(ExtractFunctionTest, Expressions) {
+  std::vector<std::pair<std::string, std::string>> InputOutputs{
+      // FULL BINARY EXPRESSIONS
+      // Full binary expression, basic maths
+      {R"cpp(
+void wrapperFun() {
+  double a{2.0}, b{3.2}, c{31.55};
+  double v{[[b * b - 4 * a * c]]};
+}
+      )cpp",
+       R"cpp(
+double extracted(double &a, double &b, double &c) {
+return b * b - 4 * a * c;
+}
+void wrapperFun() {
+  double a{2.0}, b{3.2}, c{31.55};
+  double v{extracted(a, b, c)};
+}
+      )cpp"},
+      // Full binary expression composed of '+' operator overloads ops
+      {
+          R"cpp(
+struct S {
+  S operator+(const S&) {
+    return *this;
+  }
+};
+void wrapperFun() {
+  S S1, S2, S3;
+  auto R{[[S1 + S2 + S3]]};
+}
+      )cpp",
+          R"cpp(
+struct S {
+  S operator+(const S&) {
+    return *this;
+  }
+};
+S extracted(S &S1, S &S2, S &S3) {
+return S1 + S2 + S3;
+}
+void wrapperFun() {
+  S S1, S2, S3;
+  auto R{extracted(S1, S2, S3)};
+}
+      )cpp"},
+      // Boolean predicate as expression
+      {
+          R"cpp(
+void wrapperFun() {
+  int a{1};
+  auto R{[[a > 1]]};
+}
+      )cpp",
+          R"cpp(
+bool extracted(int &a) {
+return a > 1;
+}
+void wrapperFun() {
+  int a{1};
+  auto R{extracted(a)};
+}
+      )cpp"},
+      // Expression: captures no global variable
+      {R"cpp(
+static int a{2};
+void wrapperFun() {
+  int b{3}, c{31}, d{311};
+  auto v{[[a + b + c + d]]};
+}
+      )cpp",
+       R"cpp(
+static int a{2};
+int extracted(int &b, int &c, int &d) {
+return a + b + c + d;
+}
+void wrapperFun() {
+  int b{3}, c{31}, d{311};
+  auto v{extracted(b, c, d)};
+}
+      )cpp"},
+      // Full expr: infers return type of call returning by ref
+      {
+          R"cpp(
+struct S {
+  S& operator+(const S&) {
+    return *this;
+  }
+};
+void wrapperFun() {
+  S S1, S2, S3;
+  auto R{[[S1 + S2 + S3]]};
+}
+      )cpp",
+          R"cpp(
+struct S {
+  S& operator+(const S&) {
+    return *this;
+  }
+};
+S & extracted(S &S1, S &S2, S &S3) {
+return S1 + S2 + S3;
+}
+void wrapperFun() {
+  S S1, S2, S3;
+  auto R{extracted(S1, S2, S3)};
+}
+      )cpp"},
+      // Full expr: infers return type of call returning by const-ref
+      {
+          R"cpp(
+struct S {
+  const S& operator+(const S&) const {
+    return *this;
+  }
+};
+void wrapperFun() {
+  S S1, S2, S3;
+  auto R{[[S1 + S2 + S3]]};
+}
+      )cpp",
+          R"cpp(
+struct S {
+  const S& operator+(const S&) const {
+    return *this;
+  }
+};
+const S & extracted(S &S1, S &S2, S &S3) {
+return S1 + S2 + S3;
+}
+void wrapperFun() {
+  S S1, S2, S3;
+  auto R{extracted(S1, S2, S3)};
+}
+      )cpp"},
+      // Captures deeply nested arguments
+      {
+          R"cpp(
+int fw(int a) { return a; };
+int add(int a, int b) { return a + b; }
+void wrapper() {
+    int a{0}, b{1}, c{2}, d{3}, e{4}, f{5};
+    int r{[[fw(fw(fw(a))) + fw(fw(add(b, c))) + fw(fw(fw(add(d, e)))) + fw(fw(f))]]};
+}
+      )cpp",
+          R"cpp(
+int fw(int a) { return a; };
+int add(int a, int b) { return a + b; }
+int extracted(int &a, int &b, int &c, int &d, int &e, int &f) {
+return fw(fw(fw(a))) + fw(fw(add(b, c))) + fw(fw(fw(add(d, e)))) + fw(fw(f));
+}
+void wrapper() {
+    int a{0}, b{1}, c{2}, d{3}, e{4}, f{5};
+    int r{extracted(a, b, c, d, e, f)};
+}
+      )cpp"},
+      // SUBEXPRESSIONS
+      // Left-aligned subexpression
+      {R"cpp(
+void wrapperFun() {
+  int a{2}, b{3}, c{31}, d{13};
+  auto v{[[a + b]] + c + d};
+}
+      )cpp",
+       R"cpp(
+int extracted(int &a, int &b) {
+return a + b;
+}
+void wrapperFun() {
+  int a{2}, b{3}, c{31}, d{13};
+  auto v{extracted(a, b) + c + d};
+}
+      )cpp"},
+      {R"cpp(
+void wrapperFun() {
+  int a{2}, b{3}, c{31}, d{13};
+  auto v{[[a + b + c]] + d};
+}
+      )cpp",
+       R"cpp(
+int extracted(int &a, int &b, int &c) {
+return a + b + c;
+}
+void wrapperFun() {
+  int a{2}, b{3}, c{31}, d{13};
+  auto v{extracted(a, b, c) + d};
+}
+      )cpp"},
+      // Subexpression from the middle
+      {R"cpp(
+void wrapperFun() {
+  int a{2}, b{3}, c{31}, d{15}, e{300};
+  auto v{a + [[b + c + d]] + e};
+}
+      )cpp",
+       R"cpp(
+int extracted(int &b, int &c, int &d) {
+return b + c + d;
+}
+void wrapperFun() {
+  int a{2}, b{3}, c{31}, d{15}, e{300};
+  auto v{a + extracted(b, c, d) + e};
+}
+      )cpp"},
+      // Right-aligned subexpression
+      {R"cpp(
+void wrapperFun() {
+  int a{2}, b{3}, c{31}, d{15}, e{300};
+  auto v{a + b + [[c + d + e]]};
+}
+      )cpp",
+       R"cpp(
+int extracted(int &c, int &d, int &e) {
+return c + d + e;
+}
+void wrapperFun() {
+  int a{2}, b{3}, c{31}, d{15}, e{300};
+  auto v{a + b + extracted(c, d, e)};
+}
+      )cpp"},
+      // Larger subexpression from the middle
+      {R"cpp(
+void wrapperFun() {
+  int a{2}, b{3}, c{31}, d{311};
+  auto v{a + [[a + b + c + d]] + c};
+}
+      )cpp",
+       R"cpp(
+int extracted(int &a, int &b, int &c, int &d) {
+return a + b + c + d;
+}
+void wrapperFun() {
+  int a{2}, b{3}, c{31}, d{311};
+  auto v{a + extracted(a, b, c, d) + c};
+}
+      )cpp"},
+      // Subexpression with duplicated references
+      {R"cpp(
+void wrapperFun() {
+  int a{2}, b{3}, c{31}, d{311};
+  auto v{a + b + [[c + c + c + d + d]] + c};
+}
+      )cpp",
+       R"cpp(
+int extracted(int &c, int &d) {
+return c + c + c + d + d;
+}
+void wrapperFun() {
+  int a{2}, b{3}, c{31}, d{311};
+  auto v{a + b + extracted(c, d) + c};
+}
+      )cpp"},
+      // Subexpression: captures no global variable
+      {R"cpp(
+static int a{2};
+void wrapperFun() {
+  int b{3}, c{31}, d{311};
+  auto v{[[a + b + c]] + d};
+}
+      )cpp",
+       R"cpp(
+static int a{2};
+int extracted(int &b, int &c) {
+return a + b + c;
+}
+void wrapperFun() {
+  int b{3}, c{31}, d{311};
+  auto v{extracted(b, c) + d};
+}
+      )cpp"},
+      // Subexpression: infers return type of call returning by ref, LHS
+      {
+          R"cpp(
+struct LargeStruct {
+  char LargeMember[1024];
+  LargeStruct& get() {
+    return *this;
+  }
+  LargeStruct operator+(const LargeStruct&) {
+    return *this;
+  }
+};
+void wrapperFun() {
+  LargeStruct LS1, LS2;
+  auto LS3{[[LS1.get()]] + LS2};
+}
+      )cpp",
+          R"cpp(
+struct LargeStruct {
+  char LargeMember[1024];
+  LargeStruct& get() {
+    return *this;
+  }
+  LargeStruct operator+(const LargeStruct&) {
+    return *this;
+  }
+};
+LargeStruct & extracted(LargeStruct &LS1) {
+return LS1.get();
+}
+void wrapperFun() {
+  LargeStruct LS1, LS2;
+  auto LS3{extracted(LS1) + LS2};
+}
+      )cpp"},
+      // Subexpression: infers return type of call returning by ref, most-RHS
+      {
+          R"cpp(
+struct LargeStruct {
+  char LargeMember[1024];
+  LargeStruct& get() {
+    return *this;
+  }
+  LargeStruct operator+(const LargeStruct&) {
+    return *this;
+  }
+};
+void wrapperFun() {
+  LargeStruct LS1, LS2, LS3;
+  auto LS4{LS1 + LS2 + [[LS3.get()]]};
+}
+      )cpp",
+          R"cpp(
+struct LargeStruct {
+  char LargeMember[1024];
+  LargeStruct& get() {
+    return *this;
+  }
+  LargeStruct operator+(const LargeStruct&) {
+    return *this;
+  }
+};
+LargeStruct & extracted(LargeStruct &LS3) {
+return LS3.get();
+}
+void wrapperFun() {
+  LargeStruct LS1, LS2, LS3;
+  auto LS4{LS1 + LS2 + extracted(LS3)};
+}
+      )cpp"},
+      // Subexpression: infers return type of call returning by ref, middle RHS
+      {
+          R"cpp(
+struct LargeStruct {
+  char LargeMember[1024];
+  LargeStruct& get() {
+    return *this;
+  }
+  LargeStruct getCopy() {
+    return *this;
+  }
+  LargeStruct operator+(const LargeStruct&) {
+    return *this;
+  }
+};
+void wrapperFun() {
+  LargeStruct LS1, LS2, LS3;
+  auto LS4{LS1.getCopy() + [[LS2.get()]] + LS3};
+}
+      )cpp",
+          R"cpp(
+struct LargeStruct {
+  char LargeMember[1024];
+  LargeStruct& get() {
+    return *this;
+  }
+  LargeStruct getCopy() {
+    return *this;
+  }
+  LargeStruct operator+(const LargeStruct&) {
+    return *this;
+  }
+};
+LargeStruct & extracted(LargeStruct &LS2) {
+return LS2.get();
+}
+void wrapperFun() {
+  LargeStruct LS1, LS2, LS3;
+  auto LS4{LS1.getCopy() + extracted(LS2) + LS3};
+}
+      )cpp"},
+      // Subexpr: infers return type of call returning by const-ref
+      {
+          R"cpp(
+struct LargeStruct {
+  char LargeMember[1024];
+  const LargeStruct& get() {
+    return *this;
+  }
+  LargeStruct operator+(const LargeStruct&) {
+    return *this;
+  }
+};
+void wrapperFun() {
+  LargeStruct LS1, LS2;
+  auto LS3{LS1 + [[LS2.get()]]};
+}
+      )cpp",
+          R"cpp(
+struct LargeStruct {
+  char LargeMember[1024];
+  const LargeStruct& get() {
+    return *this;
+  }
+  LargeStruct operator+(const LargeStruct&) {
+    return *this;
+  }
+};
+const LargeStruct & extracted(LargeStruct &LS2) {
+return LS2.get();
+}
+void wrapperFun() {
+  LargeStruct LS1, LS2;
+  auto LS3{LS1 + extracted(LS2)};
+}
+      )cpp"},
+      // Subexpression on operator overload, left-aligned
+      {
+          R"cpp(
+struct LargeStruct {
+  char LargeMember[1024];
+  const LargeStruct& get() {
+    return *this;
+  }
+  LargeStruct& operator+(const LargeStruct&) {
+    return *this;
+  }
+};
+void wrapperFun() {
+  LargeStruct LS1, LS2, LS3, LS4;
+  auto& LS5{[[LS1 + LS2.get()]] + LS3.get() + LS4};
+}
+      )cpp",
+          R"cpp(
+struct LargeStruct {
+  char LargeMember[1024];
+  const LargeStruct& get() {
+    return *this;
+  }
+  LargeStruct& operator+(const LargeStruct&) {
+    return *this;
+  }
+};
+LargeStruct & extracted(LargeStruct &LS1, LargeStruct &LS2) {
+return LS1 + LS2.get();
+}
+void wrapperFun() {
+  LargeStruct LS1, LS2, LS3, LS4;
+  auto& LS5{extracted(LS1, LS2) + LS3.get() + LS4};
+}
+      )cpp"},
+      {
+          R"cpp(
+struct LargeStruct {
+  char LargeMember[1024];
+  const LargeStruct& get() {
+    return *this;
+  }
+  LargeStruct& operator+(const LargeStruct&) {
+    return *this;
+  }
+};
+void wrapperFun() {
+  LargeStruct LS1, LS2, LS3, LS4;
+  auto& LS5{[[LS1 + LS2.get() + LS3.get()]] + LS4};
+}
+      )cpp",
+          R"cpp(
+struct LargeStruct {
+  char LargeMember[1024];
+  const LargeStruct& get() {
+    return *this;
+  }
+  LargeStruct& operator+(const LargeStruct&) {
+    return *this;
+  }
+};
+LargeStruct & extracted(LargeStruct &LS1, LargeStruct &LS2, LargeStruct &LS3) {
+return LS1 + LS2.get() + LS3.get();
+}
+void wrapperFun() {
+  LargeStruct LS1, LS2, LS3, LS4;
+  auto& LS5{extracted(LS1, LS2, LS3) + LS4};
+}
+      )cpp"},
+      // Subexpression on operator overload, middle-aligned
+      {
+          R"cpp(
+struct LargeStruct {
+  char LargeMember[1024];
+  const LargeStruct& get() {
+    return *this;
+  }
+  LargeStruct& operator+(const LargeStruct&) {
+    return *this;
+  }
+};
+void wrapperFun() {
+  LargeStruct LS1, LS2, LS3, LS4, LS5;
+  auto& R{LS1 + [[LS2.get() + LS3 + LS4.get()]] + LS5};
+}
+      )cpp",
+          R"cpp(
+struct LargeStruct {
+  char LargeMember[1024];
+  const LargeStruct& get() {
+    return *this;
+  }
+  LargeStruct& operator+(const LargeStruct&) {
+    return *this;
+  }
+};
+LargeStruct & extracted(LargeStruct &LS2, LargeStruct &LS3, LargeStruct &LS4) {
+return LS2.get() + LS3 + LS4.get();
+}
+void wrapperFun() {
+  LargeStruct LS1, LS2, LS3, LS4, LS5;
+  auto& R{LS1 + extracted(LS2, LS3, LS4) + LS5};
+}
+      )cpp"},
+      // Subexpression on operator overload, right-aligned
+      {
+          R"cpp(
+struct LargeStruct {
+  char LargeMember[1024];
+  const LargeStruct& get() {
+    return *this;
+  }
+  LargeStruct& operator+(const LargeStruct&) {
+    return *this;
+  }
+};
+void wrapperFun() {
+  LargeStruct LS1, LS2, LS3, LS4, LS5;
+  auto& R{LS1 + LS2.get() + [[LS3 + LS4.get() + LS5]]};
+})cpp",
+          R"cpp(
+struct LargeStruct {
+  char LargeMember[1024];
+  const LargeStruct& get() {
+    return *this;
+  }
+  LargeStruct& operator+(const LargeStruct&) {
+    return *this;
+  }
+};
+LargeStruct & extracted(LargeStruct &LS3, LargeStruct &LS4, LargeStruct &LS5) {
+return LS3 + LS4.get() + LS5;
+}
+void wrapperFun() {
+  LargeStruct LS1, LS2, LS3, LS4, LS5;
+  auto& R{LS1 + LS2.get() + extracted(LS3, LS4, LS5)};
+})cpp"},
+      // Boolean predicate as subexpression
+      {
+          R"cpp(
+void wrapperFun() {
+  int a{1}, b{2};
+  auto R{a > 1 ? [[b <= 0]] : false};
+}
+      )cpp",
+          R"cpp(
+bool extracted(int &b) {
+return b <= 0;
+}
+void wrapperFun() {
+  int a{1}, b{2};
+  auto R{a > 1 ? extracted(b) : false};
+}
+      )cpp"},
+      // Collects deeply nested arguments, left-aligned
+      {
+          R"cpp(
+int fw(int a) { return a; };
+int add(int a, int b) { return a + b; }
+void wrapper() {
+    int a{0}, b{1}, c{2}, d{3}, e{4}, f{5};
+    int r{[[fw(fw(fw(a))) + fw(fw(add(b, c))) + fw(fw(fw(add(d, e))))]] + fw(fw(f))};
+}
+      )cpp",
+          R"cpp(
+int fw(int a) { return a; };
+int add(int a, int b) { return a + b; }
+int extracted(int &a, int &b, int &c, int &d, int &e) {
+return fw(fw(fw(a))) + fw(fw(add(b, c))) + fw(fw(fw(add(d, e))));
+}
+void wrapper() {
+    int a{0}, b{1}, c{2}, d{3}, e{4}, f{5};
+    int r{extracted(a, b, c, d, e) + fw(fw(f))};
+}
+      )cpp"},
+      // Collects deeply nested arguments, middle-aligned
+      {
+          R"cpp(
+int fw(int a) { return a; };
+int add(int a, int b) { return a + b; }
+void wrapper() {
+    int a{0}, b{1}, c{2}, d{3}, e{4}, f{5};
+    int r{fw(fw(fw(a))) + [[fw(fw(add(b, c))) + fw(fw(fw(add(d, e))))]] + fw(fw(f))};
+}
+      )cpp",
+          R"cpp(
+int fw(int a) { return a; };
+int add(int a, int b) { return a + b; }
+int extracted(int &b, int &c, int &d, int &e) {
+return fw(fw(add(b, c))) + fw(fw(fw(add(d, e))));
+}
+void wrapper() {
+    int a{0}, b{1}, c{2}, d{3}, e{4}, f{5};
+    int r{fw(fw(fw(a))) + extracted(b, c, d, e) + fw(fw(f))};
+}
+      )cpp"},
+      // Collects deeply nested arguments, right-aligned
+      {
+          R"cpp(
+int fw(int a) { return a; };
+int add(int a, int b) { return a + b; }
+void wrapper() {
+    int a{0}, b{1}, c{2}, d{3}, e{4}, f{5};
+    int r{fw(fw(fw(a))) + [[fw(fw(add(b, c))) + fw(fw(fw(add(d, e)))) + fw(fw(f))]]};
+}
+      )cpp",
+          R"cpp(
+int fw(int a) { return a; };
+int add(int a, int b) { return a + b; }
+int extracted(int &b, int &c, int &d, int &e, int &f) {
+return fw(fw(add(b, c))) + fw(fw(fw(add(d, e)))) + fw(fw(f));
+}
+void wrapper() {
+    int a{0}, b{1}, c{2}, d{3}, e{4}, f{5};
+    int r{fw(fw(fw(a))) + extracted(b, c, d, e, f)};
+}
+      )cpp"},
+      // FIXME: Support macros: the unselected leftmost `1 +` is extracted too!
+      {R"cpp(
+#define ECHO(X) X
+void f() {
+    int x = 1 + [[ECHO(2 + 3) + 4]] + 5;
+})cpp",
+       R"cpp(
+#define ECHO(X) X
+int extracted() {
+return 1 + ECHO(2 + 3) + 4;
+}
+void f() {
+    int x = extracted() + 5;
+})cpp"},
+  };
+
+  for (const auto &[Input, Output] : InputOutputs) {
+    EXPECT_EQ(Output, apply(Input)) << Input;
+  }
+}
+
+TEST_F(ExtractFunctionTest, ExpressionsInMethodsSingleFile) {
+  // TODO: Add cases in which the tweak should be unavailable.
+  // TODO: Add more cases in which the tweak should be available.
+
+  std::vector<std::pair<std::string, std::string>> InputOutputs{
+      // Expression: Does not capture members as parameters
+      // FIXME: If selected area does not mutate members, make extracted() const
+      {R"cpp(
+struct S {
+void f() const {
+    int a{1}, b{2};
+    auto r{[[a + b + mem1 + mem2]]};
+}
+int mem1{0}, mem2{0};
+};
+)cpp",
+       R"cpp(
+struct S {
+int extracted(int &a, int &b) const {
+return a + b + mem1 + mem2;
+}
+void f() const {
+    int a{1}, b{2};
+    auto r{extracted(a, b)};
+}
+int mem1{0}, mem2{0};
+};
+)cpp"},
+      // Subexpression: Does not capture members as parameters
+      {R"cpp(
+struct S {
+void f() const {
+    int a{1}, b{2};
+    auto r{a + [[mem1 + mem2 + b + mem1]] + mem2};
+}
+int mem1{0}, mem2{0};
+};
+)cpp",
+       R"cpp(
+struct S {
+int extracted(int &b) const {
+return mem1 + mem2 + b + mem1;
+}
+void f() const {
+    int a{1}, b{2};
+    auto r{a + extracted(b) + mem2};
+}
+int mem1{0}, mem2{0};
+};
+)cpp"},
+  };
+
+  for (const auto &[Input, Output] : InputOutputs) {
+    EXPECT_EQ(Output, apply(Input)) << Input;
+  }
+}
+
+TEST_F(ExtractFunctionTest, ExpressionInMethodMultiFile) {
+  Header = R"cpp(
+    class SomeClass {
+      void f();
+      int mem1{0}, mem2{0};
+    };
+  )cpp";
+
+  std::string OutOfLineSource = R"cpp(
+    void SomeClass::f() {
+      int a{1}, b{2};
+      int x = [[a + mem1 + b + mem2]];
+    }
+  )cpp";
+
+  std::string OutOfLineSourceOutputCheck = R"cpp(
+    int SomeClass::extracted(int &a, int &b) {
+return a + mem1 + b + mem2;
+}
+void SomeClass::f() {
+      int a{1}, b{2};
+      int x = extracted(a, b);
+    }
+  )cpp";
+
+  std::string HeaderOutputCheck = R"cpp(
+    class SomeClass {
+      int extracted(int &a, int &b);
+void f();
+      int mem1{0}, mem2{0};
+    };
+  )cpp";
+
+  llvm::StringMap<std::string> EditedFiles;
+
+  EXPECT_EQ(apply(OutOfLineSource, &EditedFiles), OutOfLineSourceOutputCheck);
+  EXPECT_EQ(EditedFiles.begin()->second, HeaderOutputCheck);
+}
+
+TEST_F(ExtractFunctionTest, SubexpressionInMethodMultiFile) {
+  Header = R"cpp(
+    class SomeClass {
+      void f();
+      int mem1{0}, mem2{0};
+    };
+  )cpp";
+
+  std::string OutOfLineSource = R"cpp(
+    void SomeClass::f() {
+      int a{1}, b{2};
+      int x = a + [[mem1 + b + mem2]] + mem1;
+    }
+  )cpp";
+
+  std::string OutOfLineSourceOutputCheck = R"cpp(
+    int SomeClass::extracted(int &b) {
+return mem1 + b + mem2;
+}
+void SomeClass::f() {
+      int a{1}, b{2};
+      int x = a + extracted(b) + mem1;
+    }
+  )cpp";
+
+  std::string HeaderOutputCheck = R"cpp(
+    class SomeClass {
+      int extracted(int &b);
+void f();
+      int mem1{0}, mem2{0};
+    };
+  )cpp";
+
+  llvm::StringMap<std::string> EditedFiles;
+
+  EXPECT_EQ(apply(OutOfLineSource, &EditedFiles), OutOfLineSourceOutputCheck);
+  EXPECT_EQ(EditedFiles.begin()->second, HeaderOutputCheck);
+}
+
 } // namespace
 } // namespace clangd
 } // namespace clang
Index: clang-tools-extra/clangd/tool/CMakeLists.txt
===================================================================
--- clang-tools-extra/clangd/tool/CMakeLists.txt
+++ clang-tools-extra/clangd/tool/CMakeLists.txt
@@ -16,6 +16,7 @@
 clang_target_link_libraries(clangd
   PRIVATE
   clangAST
+  clangASTMatchers
   clangBasic
   clangFormat
   clangFrontend
Index: clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
===================================================================
--- clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
+++ clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
@@ -56,9 +56,12 @@
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclBase.h"
+#include "clang/AST/ExprCXX.h"
 #include "clang/AST/NestedNameSpecifier.h"
 #include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/Stmt.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/ASTMatchers/ASTMatchers.h"
 #include "clang/Basic/LangOptions.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Basic/SourceManager.h"
@@ -73,6 +76,9 @@
 #include "llvm/Support/Error.h"
 #include "llvm/Support/raw_os_ostream.h"
 
+#include <algorithm>
+#include <optional>
+
 namespace clang {
 namespace clangd {
 namespace {
@@ -95,6 +101,208 @@
   OutOfLineDefinition
 };
 
+// Helpers for handling "binary subexpressions" like a + [[b + c]] + d. This is
+// taken from ExtractVariable, and adapted a little to handle collection of
+// parameters.
+struct ExtractedBinarySubexpressionSelection;
+
+class BinarySubexpressionSelection {
+
+public:
+  static inline std::optional<BinarySubexpressionSelection>
+  tryParse(const SelectionTree::Node &N, const SourceManager *SM) {
+    if (const BinaryOperator *Op =
+            llvm::dyn_cast_or_null<BinaryOperator>(N.ASTNode.get<Expr>())) {
+      return BinarySubexpressionSelection{SM, Op->getOpcode(), Op->getExprLoc(),
+                                          N.Children};
+    }
+    if (const CXXOperatorCallExpr *Op =
+            llvm::dyn_cast_or_null<CXXOperatorCallExpr>(
+                N.ASTNode.get<Expr>())) {
+      if (!Op->isInfixBinaryOp())
+        return std::nullopt;
+
+      llvm::SmallVector<const SelectionTree::Node *> SelectedOps;
+      // Not all children are args, there's also the callee (operator).
+      for (const auto *Child : N.Children) {
+        const Expr *E = Child->ASTNode.get<Expr>();
+        assert(E && "callee and args should be Exprs!");
+        if (E == Op->getArg(0) || E == Op->getArg(1))
+          SelectedOps.push_back(Child);
+      }
+      return BinarySubexpressionSelection{
+          SM, BinaryOperator::getOverloadedOpcode(Op->getOperator()),
+          Op->getExprLoc(), std::move(SelectedOps)};
+    }
+    return std::nullopt;
+  }
+
+  bool associative() const {
+    // Must also be left-associative!
+    switch (Kind) {
+    case BO_Add:
+    case BO_Mul:
+    case BO_And:
+    case BO_Or:
+    case BO_Xor:
+    case BO_LAnd:
+    case BO_LOr:
+      return true;
+    default:
+      return false;
+    }
+  }
+
+  bool crossesMacroBoundary() const {
+    FileID F = SM->getFileID(ExprLoc);
+    for (const SelectionTree::Node *Child : SelectedOperations)
+      if (SM->getFileID(Child->ASTNode.get<Expr>()->getExprLoc()) != F)
+        return true;
+    return false;
+  }
+
+  bool isExtractable() const {
+    return associative() and not crossesMacroBoundary();
+  }
+
+  void dumpSelectedOperations(llvm::raw_ostream &Os,
+                              const ASTContext &Cont) const {
+    for (const auto *Op : SelectedOperations)
+      Op->ASTNode.dump(Os, Cont);
+  }
+
+  std::optional<ExtractedBinarySubexpressionSelection> tryExtract() const;
+
+protected:
+  struct SelectedOperands {
+    llvm::SmallVector<const SelectionTree::Node *> Operands;
+    const SelectionTree::Node *Start;
+    const SelectionTree::Node *End;
+  };
+
+private:
+  BinarySubexpressionSelection(
+      const SourceManager *SM, BinaryOperatorKind Kind, SourceLocation ExprLoc,
+      llvm::SmallVector<const SelectionTree::Node *> SelectedOps)
+      : SM{SM}, Kind(Kind), ExprLoc(ExprLoc),
+        SelectedOperations(std::move(SelectedOps)) {}
+
+  SelectedOperands getSelectedOperands() const {
+    auto [Start, End]{getClosedRangeWithSelectedOperations()};
+
+    llvm::SmallVector<const SelectionTree::Node *> Operands;
+    Operands.reserve(SelectedOperations.size());
+    const SelectionTree::Node *BinOpSelectionIt{Start->Parent};
+
+    // Edge case: the selection starts from the most-left LHS, e.g. [[a+b+c]]+d
+    if (BinOpSelectionIt->Children.size() == 2)
+      Operands.emplace_back(BinOpSelectionIt->Children.front()); // LHS
+    // In case of operator+ call, the Children will contain the callee as well.
+    else if (BinOpSelectionIt->Children.size() == 3)
+      Operands.emplace_back(BinOpSelectionIt->Children[1]); // LHS
+
+    // Go up the Binary Operation tree, up to the most-right RHS
+    for (; BinOpSelectionIt->Children.back() != End;
+         BinOpSelectionIt = BinOpSelectionIt->Parent)
+      Operands.emplace_back(BinOpSelectionIt->Children.back()); // RHS
+    // Remember to add the most-right RHS
+    Operands.emplace_back(End);
+
+    SelectedOperands Ops;
+    Ops.Start = Start;
+    Ops.End = End;
+    Ops.Operands = std::move(Operands);
+    return Ops;
+  }
+
+  std::pair<const SelectionTree::Node *, const SelectionTree::Node *>
+  getClosedRangeWithSelectedOperations() const {
+    BinaryOperatorKind OuterOp = Kind;
+    // Because the tree we're interested in contains only one operator type, and
+    // all eligible operators are left-associative, the shape of the tree is
+    // very restricted: it's a linked list along the left edges.
+    // This simplifies our implementation.
+    const SelectionTree::Node *Start = SelectedOperations.front(); // LHS
+    const SelectionTree::Node *End = SelectedOperations.back();    // RHS
+
+    // End is already correct: it can't be an OuterOp (as it's
+    // left-associative). Start needs to be pushed down into the subtree to the
+    // right spot.
+    while (true) {
+      auto MaybeOp{tryParse(Start->ignoreImplicit(), SM)};
+      if (not MaybeOp)
+        break;
+      const auto &Op{*MaybeOp};
+      if (Op.Kind != OuterOp or Op.crossesMacroBoundary())
+        break;
+      assert(!Op.SelectedOperations.empty() &&
+             "got only operator on one side!");
+      if (Op.SelectedOperations.size() == 1) { // Only Op.RHS selected
+        Start = Op.SelectedOperations.back();
+        break;
+      }
+      // Op.LHS is (at least partially) selected, so descend into it.
+      Start = Op.SelectedOperations.front();
+    }
+    return {Start, End};
+  }
+
+protected:
+  const SourceManager *SM;
+  BinaryOperatorKind Kind;
+  SourceLocation ExprLoc;
+  // May also contain partially selected operations,
+  // e.g. a + [[b + c]], will keep (a + b) BinaryOperator.
+  llvm::SmallVector<const SelectionTree::Node *> SelectedOperations;
+};
+
+struct ExtractedBinarySubexpressionSelection : BinarySubexpressionSelection {
+  ExtractedBinarySubexpressionSelection(BinarySubexpressionSelection BinSubexpr,
+                                        SelectedOperands SelectedOps)
+      : BinarySubexpressionSelection::BinarySubexpressionSelection(
+            std::move(BinSubexpr)),
+        Operands{std::move(SelectedOps)} {}
+
+  SourceRange getRange(const LangOptions &LangOpts) const {
+    auto MakeHalfOpenFileRange{[&](const SelectionTree::Node *N) {
+      return toHalfOpenFileRange(*SM, LangOpts, N->ASTNode.getSourceRange());
+    }};
+
+    return SourceRange(MakeHalfOpenFileRange(Operands.Start)->getBegin(),
+                       MakeHalfOpenFileRange(Operands.End)->getEnd());
+  }
+
+  void dumpSelectedOperands(llvm::raw_ostream &Os,
+                            const ASTContext &Cont) const {
+    for (const auto *Op : Operands.Operands)
+      Op->ASTNode.dump(Os, Cont);
+  }
+
+  llvm::SmallVector<const DeclRefExpr *>
+  collectReferences(ASTContext &Cont) const {
+    llvm::SmallVector<const DeclRefExpr *> Refs;
+    auto Matcher{
+        ast_matchers::findAll(ast_matchers::declRefExpr().bind("ref"))};
+    for (const auto *SelNode : Operands.Operands) {
+      auto Matches{ast_matchers::match(Matcher, SelNode->ASTNode, Cont)};
+      for (const auto &Match : Matches)
+        if (const DeclRefExpr * Ref{Match.getNodeAs<DeclRefExpr>("ref")}; Ref)
+          Refs.push_back(Ref);
+    }
+    return Refs;
+  }
+
+private:
+  SelectedOperands Operands;
+};
+
+std::optional<ExtractedBinarySubexpressionSelection>
+BinarySubexpressionSelection::tryExtract() const {
+  if (not isExtractable())
+    return std::nullopt;
+  return ExtractedBinarySubexpressionSelection{*this, getSelectedOperands()};
+}
+
 // A RootStmt is a statement that's fully selected including all it's children
 // and it's parent is unselected.
 // Check if a node is a root statement.
@@ -122,11 +330,14 @@
 // begins in selection range, ends in selection range and any scope that begins
 // outside the selection range, ends outside as well.
 const Node *getParentOfRootStmts(const Node *CommonAnc) {
-  if (!CommonAnc)
-    return nullptr;
   const Node *Parent = nullptr;
   switch (CommonAnc->Selected) {
   case SelectionTree::Selection::Unselected:
+    // Workaround for an operator call: a BinaryOperator would be selected
+    // completely, but the operator call would be unselected, thus we treat it
+    // as if it were completely selected.
+    if (CommonAnc->ASTNode.get<CXXOperatorCallExpr>() != nullptr)
+      return CommonAnc->Parent;
     // Typically a block, with the { and } unselected, could also be ForStmt etc
     // Ensure all Children are RootStmts.
     Parent = CommonAnc;
@@ -152,6 +363,7 @@
 
 // The ExtractionZone class forms a view of the code wrt Zone.
 struct ExtractionZone {
+  const Node *CommonAncestor;
   // Parent of RootStatements being extracted.
   const Node *Parent = nullptr;
   // The half-open file range of the code being extracted.
@@ -162,6 +374,8 @@
   SourceRange EnclosingFuncRange;
   // Set of statements that form the ExtractionZone.
   llvm::DenseSet<const Stmt *> RootStmts;
+  // If the extraction zone is a "binary subexpression", then this will be set.
+  std::optional<BinarySubexpressionSelection> MaybeBinarySubexpr;
 
   SourceLocation getInsertionPoint() const {
     return EnclosingFuncRange.getBegin();
@@ -292,20 +506,12 @@
   return toHalfOpenFileRange(SM, LangOpts, EnclosingFunction->getSourceRange());
 }
 
-// returns true if Child can be a single RootStmt being extracted from
-// EnclosingFunc.
-bool validSingleChild(const Node *Child, const FunctionDecl *EnclosingFunc) {
-  // Don't extract expressions.
-  // FIXME: We should extract expressions that are "statements" i.e. not
-  // subexpressions
-  if (Child->ASTNode.get<Expr>())
-    return false;
-  // Extracting the body of EnclosingFunc would remove it's definition.
-  assert(EnclosingFunc->hasBody() &&
+bool isEntireFunctionBodySelected(const ExtractionZone &ExtZone) {
+  assert(ExtZone.EnclosingFunction->hasBody() &&
          "We should always be extracting from a function body.");
-  if (Child->ASTNode.get<Stmt>() == EnclosingFunc->getBody())
-    return false;
-  return true;
+  return ExtZone.Parent->Children.size() == 1 &&
+         ExtZone.getLastRootStmt()->ASTNode.get<Stmt>() ==
+             ExtZone.EnclosingFunction->getBody();
 }
 
 // FIXME: Check we're not extracting from the initializer/condition of a control
@@ -313,17 +519,30 @@
 llvm::Optional<ExtractionZone> findExtractionZone(const Node *CommonAnc,
                                                   const SourceManager &SM,
                                                   const LangOptions &LangOpts) {
+  if (CommonAnc == nullptr)
+    return std::nullopt;
   ExtractionZone ExtZone;
+  ExtZone.CommonAncestor = CommonAnc;
+  auto MaybeBinarySubexpr{
+      BinarySubexpressionSelection::tryParse(CommonAnc->ignoreImplicit(), &SM)};
+  if (MaybeBinarySubexpr) {
+    // FIXME: We should not allow the user to extract expressions which we
+    // don't support, or which are oddly selected (e.g. a [[+ b + c]]). If the
+    // selected subexpression is an entire expression (not only a part of an
+    // expression), then we don't need the BinarySubexpressionSelection.
+    if (const auto &BinarySubexpr{*MaybeBinarySubexpr};
+        BinarySubexpr.isExtractable()) {
+      ExtZone.MaybeBinarySubexpr = std::move(MaybeBinarySubexpr);
+    }
+  }
   ExtZone.Parent = getParentOfRootStmts(CommonAnc);
   if (!ExtZone.Parent || ExtZone.Parent->Children.empty())
     return std::nullopt;
   ExtZone.EnclosingFunction = findEnclosingFunction(ExtZone.Parent);
   if (!ExtZone.EnclosingFunction)
     return std::nullopt;
-  // When there is a single RootStmt, we must check if it's valid for
-  // extraction.
-  if (ExtZone.Parent->Children.size() == 1 &&
-      !validSingleChild(ExtZone.getLastRootStmt(), ExtZone.EnclosingFunction))
+  // Extracting the body of EnclosingFunc would remove its definition.
+  if (isEntireFunctionBodySelected(ExtZone))
     return std::nullopt;
   if (auto FuncRange =
           computeEnclosingFuncRange(ExtZone.EnclosingFunction, SM, LangOpts))
@@ -367,6 +586,7 @@
   bool Static = false;
   ConstexprSpecKind Constexpr = ConstexprSpecKind::Unspecified;
   bool Const = false;
+  bool Expression = false;
 
   // Decides whether the extracted function body and the function call need a
   // semicolon after extraction.
@@ -495,8 +715,11 @@
   // - hoist decls
   // - add return statement
   // - Add semicolon
-  return toSourceCode(SM, BodyRange).str() +
-         (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : "");
+  auto NewBody{toSourceCode(SM, BodyRange).str() +
+               (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : "")};
+  if (Expression)
+    return "return " + NewBody;
+  return NewBody;
 }
 
 std::string NewFunction::Parameter::render(const DeclContext *Context) const {
@@ -530,6 +753,7 @@
   // FIXME: Capture type information as well.
   DeclInformation *createDeclInfo(const Decl *D, ZoneRelative RelativeLoc);
   DeclInformation *getDeclInfoFor(const Decl *D);
+  const DeclInformation *getDeclInfoFor(const Decl *D) const;
 };
 
 CapturedZoneInfo::DeclInformation *
@@ -543,7 +767,14 @@
 
 CapturedZoneInfo::DeclInformation *
 CapturedZoneInfo::getDeclInfoFor(const Decl *D) {
-  // If the Decl doesn't exist, we
+  auto Iter = DeclInfoMap.find(D);
+  if (Iter == DeclInfoMap.end())
+    return nullptr;
+  return &Iter->second;
+}
+
+const CapturedZoneInfo::DeclInformation *
+CapturedZoneInfo::getDeclInfoFor(const Decl *D) const {
   auto Iter = DeclInfoMap.find(D);
   if (Iter == DeclInfoMap.end())
     return nullptr;
@@ -664,12 +895,29 @@
   return Result;
 }
 
-// Adds parameters to ExtractedFunc.
-// Returns true if able to find the parameters successfully and no hoisting
-// needed.
+static const ValueDecl *unpackDeclForParameter(const Decl *D) {
+  const ValueDecl *VD = dyn_cast_or_null<ValueDecl>(D);
+  // Can't parameterise if the Decl isn't a ValueDecl or is a FunctionDecl
+  // (this includes the case of recursive call to EnclosingFunc in Zone).
+  if (!VD || isa<FunctionDecl>(D))
+    return nullptr;
+  return VD;
+}
+
+static QualType getParameterTypeInfo(const ValueDecl *VD) {
+  // Parameter qualifiers are same as the Decl's qualifiers.
+  return VD->getType().getNonReferenceType();
+}
+
+using Parameters = std::vector<NewFunction::Parameter>;
+using MaybeParameters = std::optional<Parameters>;
+
 // FIXME: Check if the declaration has a local/anonymous type
-bool createParameters(NewFunction &ExtractedFunc,
-                      const CapturedZoneInfo &CapturedInfo) {
+// Returns actual parameters if able to find the parameters successfully and no
+// hoisting needed.
+static MaybeParameters
+createParamsForNoSubexpr(const CapturedZoneInfo &CapturedInfo) {
+  std::vector<NewFunction::Parameter> Params;
   for (const auto &KeyVal : CapturedInfo.DeclInfoMap) {
     const auto &DeclInfo = KeyVal.second;
     // If a Decl was Declared in zone and referenced in post zone, it
@@ -677,20 +925,16 @@
     // FIXME: Support Decl Hoisting.
     if (DeclInfo.DeclaredIn == ZoneRelative::Inside &&
         DeclInfo.IsReferencedInPostZone)
-      return false;
+      return std::nullopt;
     if (!DeclInfo.IsReferencedInZone)
       continue; // no need to pass as parameter, not referenced
     if (DeclInfo.DeclaredIn == ZoneRelative::Inside ||
         DeclInfo.DeclaredIn == ZoneRelative::OutsideFunc)
       continue; // no need to pass as parameter, still accessible.
-    // Parameter specific checks.
-    const ValueDecl *VD = dyn_cast_or_null<ValueDecl>(DeclInfo.TheDecl);
-    // Can't parameterise if the Decl isn't a ValueDecl or is a FunctionDecl
-    // (this includes the case of recursive call to EnclosingFunc in Zone).
-    if (!VD || isa<FunctionDecl>(DeclInfo.TheDecl))
-      return false;
-    // Parameter qualifiers are same as the Decl's qualifiers.
-    QualType TypeInfo = VD->getType().getNonReferenceType();
+    const auto *VD{unpackDeclForParameter(DeclInfo.TheDecl)};
+    if (VD == nullptr)
+      return std::nullopt;
+    QualType TypeInfo{getParameterTypeInfo(VD)};
     // FIXME: Need better qualifier checks: check mutated status for
     // Decl(e.g. was it assigned, passed as nonconst argument, etc)
     // FIXME: check if parameter will be a non l-value reference.
@@ -698,12 +942,61 @@
     // pointers, etc by reference.
     bool IsPassedByReference = true;
     // We use the index of declaration as the ordering priority for parameters.
-    ExtractedFunc.Parameters.push_back({std::string(VD->getName()), TypeInfo,
-                                        IsPassedByReference,
-                                        DeclInfo.DeclIndex});
+    Params.push_back({std::string(VD->getName()), TypeInfo, IsPassedByReference,
+                      DeclInfo.DeclIndex});
   }
-  llvm::sort(ExtractedFunc.Parameters);
-  return true;
+  llvm::sort(Params);
+  return Params;
+}
+
+static MaybeParameters
+createParamsForSubexpr(const CapturedZoneInfo &CapturedInfo,
+                       const ExtractedBinarySubexpressionSelection &Subexpr,
+                       ASTContext &ASTCont) {
+  // The set is used to detect duplicates, but since it does not preserve
+  // insertion order, a vector additionally collects the unique references in
+  // the order in which they are referenced.
+  llvm::SmallVector<const ValueDecl *> RefsAsDecls;
+  llvm::DenseSet<const ValueDecl *> UniqueRefsAsDecls;
+
+  for (const auto *Ref : Subexpr.collectReferences(ASTCont)) {
+    const auto *D{Ref->getDecl()};
+    const auto *VD{unpackDeclForParameter(D)};
+    // Only collect the ValueDecl-s.
+    if (VD == nullptr)
+      continue;
+    const auto *DeclInfo{CapturedInfo.getDeclInfoFor(D)};
+    if (DeclInfo == nullptr or DeclInfo->DeclaredIn != ZoneRelative::Before)
+      continue;
+    auto [It, IsNew]{UniqueRefsAsDecls.insert(VD)};
+    if (IsNew)
+      RefsAsDecls.emplace_back(VD);
+  }
+
+  std::vector<NewFunction::Parameter> Params;
+  std::transform(std::begin(RefsAsDecls), std::end(RefsAsDecls),
+                 std::back_inserter(Params), [](const ValueDecl *VD) {
+                   QualType TypeInfo{getParameterTypeInfo(VD)};
+                   // FIXME: Need better qualifier checks: check mutated status
+                   // for Decl(e.g. was it assigned, passed as nonconst
+                   // argument, etc)
+                   // FIXME: check if parameter will be a non l-value reference.
+                   // FIXME: We don't want to always pass variables of types
+                   // like int, pointers, etc by reference.
+                   bool IsPassedByRef = true;
+                   return NewFunction::Parameter{std::string(VD->getName()),
+                                                 TypeInfo, IsPassedByRef, 0};
+                 });
+  return Params;
+}
+
+// Creates the parameters for the extracted function, or nullopt on failure.
+MaybeParameters createParams(
+    const std::optional<ExtractedBinarySubexpressionSelection> &MaybeSubexpr,
+    const CapturedZoneInfo &CapturedInfo, ASTContext &ASTCont) {
+  if (MaybeSubexpr)
+    return createParamsForSubexpr(CapturedInfo, *MaybeSubexpr, ASTCont);
+  return createParamsForNoSubexpr(CapturedInfo);
 }
 
 // Clangd uses open ranges while ExtractionSemicolonPolicy (in Clang Tooling)
@@ -723,29 +1016,47 @@
   return SemicolonPolicy;
 }
 
+// Returns true if the selected code is an expression, false otherwise.
+bool isExpression(const ExtractionZone &ExtZone) {
+  const auto &Node{*ExtZone.Parent};
+  return Node.Children.size() == 1 and
+         ExtZone.getLastRootStmt()->ASTNode.get<Expr>() != nullptr;
+}
+
 // Generate return type for ExtractedFunc. Return false if unable to do so.
-bool generateReturnProperties(NewFunction &ExtractedFunc,
-                              const FunctionDecl &EnclosingFunc,
-                              const CapturedZoneInfo &CapturedInfo) {
+std::optional<QualType>
+generateReturnProperties(const ExtractionZone &ExtZone,
+                         const CapturedZoneInfo &CapturedInfo) {
   // If the selected code always returns, we preserve those return statements.
   // The return type should be the same as the enclosing function.
   // (Others are possible if there are conversions, but this seems clearest).
+  const auto &EnclosingFunc{*ExtZone.EnclosingFunction};
   if (CapturedInfo.HasReturnStmt) {
     // If the return is conditional, neither replacing the code with
     // `extracted()` nor `return extracted()` is correct.
     if (!CapturedInfo.AlwaysReturns)
-      return false;
+      return std::nullopt;
     QualType Ret = EnclosingFunc.getReturnType();
-    // Once we support members, it'd be nice to support e.g. extracting a method
-    // of Foo<T> that returns T. But it's not clear when that's safe.
+    // Once we support members, it'd be nice to support e.g. extracting a
+    // method of Foo<T> that returns T. But it's not clear when that's safe.
     if (Ret->isDependentType())
-      return false;
-    ExtractedFunc.ReturnType = Ret;
-    return true;
+      return std::nullopt;
+    return Ret;
+  }
+  // If the selected code is an expression, use its type as the return type.
+  if (const auto &Node{*ExtZone.Parent}; Node.Children.size() == 1) {
+    if (const Expr * Expression{ExtZone.getLastRootStmt()->ASTNode.get<Expr>()};
+        Expression) {
+      if (const auto *Call{llvm::dyn_cast_or_null<CallExpr>(Expression)};
+          Call) {
+        const auto &ASTCont{ExtZone.EnclosingFunction->getParentASTContext()};
+        return Call->getCallReturnType(ASTCont);
+      }
+      return Expression->getType();
+    }
   }
   // FIXME: Generate new return statement if needed.
-  ExtractedFunc.ReturnType = EnclosingFunc.getParentASTContext().VoidTy;
-  return true;
+  return EnclosingFunc.getParentASTContext().VoidTy;
 }
 
 void captureMethodInfo(NewFunction &ExtractedFunc,
@@ -791,14 +1102,25 @@
     ExtractedFunc.ForwardDeclarationSyntacticDC = ExtractedFunc.SemanticDC;
   }
 
-  ExtractedFunc.BodyRange = ExtZone.ZoneRange;
-  ExtractedFunc.DefinitionPoint = ExtZone.getInsertionPoint();
+  auto &ASTCont{ExtZone.EnclosingFunction->getASTContext()};
+  ExtractedFunc.Expression = isExpression(ExtZone);
+  std::optional<ExtractedBinarySubexpressionSelection> MaybeExtractedSubexpr;
+  if (ExtZone.MaybeBinarySubexpr) {
+    MaybeExtractedSubexpr = ExtZone.MaybeBinarySubexpr->tryExtract();
+    ExtractedFunc.BodyRange = MaybeExtractedSubexpr->getRange(LangOpts);
+  } else {
+    ExtractedFunc.BodyRange = ExtZone.ZoneRange;
+  }
 
+  ExtractedFunc.DefinitionPoint = ExtZone.getInsertionPoint();
   ExtractedFunc.CallerReturnsValue = CapturedInfo.AlwaysReturns;
-  if (!createParameters(ExtractedFunc, CapturedInfo) ||
-      !generateReturnProperties(ExtractedFunc, *ExtZone.EnclosingFunction,
-                                CapturedInfo))
+
+  auto MaybeRetType{generateReturnProperties(ExtZone, CapturedInfo)};
+  auto MaybeParams{createParams(MaybeExtractedSubexpr, CapturedInfo, ASTCont)};
+  if (not MaybeRetType || not MaybeParams)
     return error("Too complex to extract.");
+  ExtractedFunc.ReturnType = std::move(*MaybeRetType);
+  ExtractedFunc.Parameters = std::move(*MaybeParams);
   return ExtractedFunc;
 }
 
@@ -913,8 +1235,8 @@
 
       tooling::Replacements OtherEdit(
           createForwardDeclaration(*ExtractedFunc, SM));
-      if (auto PathAndEdit = Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc),
-                                                 OtherEdit))
+      if (auto PathAndEdit =
+              Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc), OtherEdit))
         MultiFileEffect->ApplyEdits.try_emplace(PathAndEdit->first,
                                                 PathAndEdit->second);
       else
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to