ilya-biryukov created this revision.
ilya-biryukov added reviewers: kadircet, rsmith.

Preferred types are used by code completion for ranking. This commit
considerably increases the number of places in the parser where those
types are propagated.

To avoid complicating the signatures of the Parser's methods, the
preferred type is kept as a member variable of the parser and is updated
during parsing.


Repository:
  rC Clang

https://reviews.llvm.org/D56723

Files:
  include/clang/Parse/Parser.h
  include/clang/Sema/CodeCompleteConsumer.h
  include/clang/Sema/Sema.h
  lib/Parse/ParseDecl.cpp
  lib/Parse/ParseDeclCXX.cpp
  lib/Parse/ParseExpr.cpp
  lib/Parse/ParseExprCXX.cpp
  lib/Parse/ParseInit.cpp
  lib/Parse/ParseStmt.cpp
  lib/Parse/ParseTemplate.cpp
  lib/Sema/SemaCodeComplete.cpp
  unittests/Sema/CodeCompleteTest.cpp

Index: unittests/Sema/CodeCompleteTest.cpp
===================================================================
--- unittests/Sema/CodeCompleteTest.cpp
+++ unittests/Sema/CodeCompleteTest.cpp
@@ -340,4 +340,96 @@
   EXPECT_THAT(collectPreferredTypes(Code), Each("NULL TYPE"));
 }
 
+TEST(PreferredTypeTest, Members) {
+  StringRef Code = R"cpp(
+    struct vector {
+      int *begin();
+      vector clone();
+    };
+
+    void test(int *a) {
+      a = ^vector().^clone().^begin();
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("int *"));
+}
+
+TEST(PreferredTypeTest, Conditions) {
+  StringRef Code = R"cpp(
+    struct vector {
+      bool empty();
+    };
+
+    void test() {
+      if (^vector().^empty()) {}
+      while (^vector().^empty()) {}
+      for (; ^vector().^empty();) {}
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("_Bool"));
+}
+
+TEST(PreferredTypeTest, InitAndAssignment) {
+  StringRef Code = R"cpp(
+    struct vector {
+      int* begin();
+    };
+
+    void test() {
+      const int* x = ^vector().^begin();
+      x = ^vector().^begin();
+
+      if (const int* y = ^vector().^begin()) {}
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("const int *"));
+}
+
+TEST(PreferredTypeTest, UnaryExprs) {
+  StringRef Code = R"cpp(
+    void test(long long a) {
+      a = +^a;
+      a = -^a;
+      a = ++^a;
+      a = --^a;
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("long long"));
+  // Unary '!' always expects a boolean operand.
+  Code = R"cpp(
+    void test(int a, int *ptr) {
+      !^a;
+      !^ptr;
+      !!!^a;
+
+      a = !^a;
+      a = !^ptr;
+      a = !!!^a;
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("_Bool"));
+  // '&x' yields a pointer, so x is expected to have the pointee type.
+  Code = R"cpp(
+    void test(int a) {
+      const int* x = &^a;
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("const int"));
+  // '*x' dereferences, so x is expected to point to the target type.
+  Code = R"cpp(
+    void test(int *a) {
+      int x = *^a;
+      int &r = *^a;
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("int *"));
+  // Without an expected type, '*' and '&' cannot infer one either.
+  Code = R"cpp(
+    void test(int a) {
+      *^a;
+      &^a;
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("NULL TYPE"));
+}
 } // namespace
Index: lib/Sema/SemaCodeComplete.cpp
===================================================================
--- lib/Sema/SemaCodeComplete.cpp
+++ lib/Sema/SemaCodeComplete.cpp
@@ -348,6 +348,180 @@
 };
 } // namespace
 
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::update(llvm::function_ref<void()> Updater) {
+  RestoreRAII R(*this);
+  Updater();
+  return R;
+}
+
+PreferredTypeBuilder::RestoreRAII PreferredTypeBuilder::enterReturn(Sema &S) {
+  return update([&]() {
+    if (isa<BlockDecl>(S.CurContext)) {
+      sema::BlockScopeInfo *BSI = S.getCurBlock();
+      Type = BSI ? BSI->ReturnType : QualType();
+      return;
+    }
+    if (const auto *Function = dyn_cast<FunctionDecl>(S.CurContext)) {
+      Type = Function->getReturnType();
+      return;
+    }
+    if (const auto *Method = dyn_cast<ObjCMethodDecl>(S.CurContext)) {
+      Type = Method->getReturnType();
+      return;
+    }
+    Type = QualType();
+  });
+}
+
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterVariableInit(Decl *D) {
+  return update([&]() {
+    auto *VD = llvm::dyn_cast_or_null<ValueDecl>(D);
+    Type = VD ? VD->getType() : QualType();
+  });
+}
+
+static QualType getPreferredTypeOfBinaryRHS(Sema &S, Expr *LHS,
+                                            tok::TokenKind Op) {
+  if (!LHS)
+    return QualType();
+
+  QualType LHSType = LHS->getType();
+  if (LHSType->isPointerType()) {
+    if (Op == tok::plus || Op == tok::plusequal || Op == tok::minusequal)
+      return S.getASTContext().getPointerDiffType();
+    // Pointer difference is more common than subtracting an int from a pointer.
+    if (Op == tok::minus)
+      return LHSType;
+  }
+
+  switch (Op) {
+  // No way to infer the type of RHS from LHS.
+  case tok::comma:
+    return QualType();
+  // Prefer the type of the left operand for all of these.
+  // Arithmetic operations.
+  case tok::plus:
+  case tok::plusequal:
+  case tok::minus:
+  case tok::minusequal:
+  case tok::percent:
+  case tok::percentequal:
+  case tok::slash:
+  case tok::slashequal:
+  case tok::star:
+  case tok::starequal:
+  // Assignment.
+  case tok::equal:
+  // Comparison operators.
+  case tok::equalequal:
+  case tok::exclaimequal:
+  case tok::less:
+  case tok::lessequal:
+  case tok::greater:
+  case tok::greaterequal:
+  case tok::spaceship:
+    return LHSType;
+  // Shifts are often overloaded (e.g. for streams); only guess for integers.
+  case tok::greatergreater:
+  case tok::greatergreaterequal:
+  case tok::lessless:
+  case tok::lesslessequal:
+    if (LHSType->isIntegralOrEnumerationType())
+      return S.getASTContext().IntTy;
+    return QualType();
+  // Logical operators, assume we want bool.
+  case tok::ampamp:
+  case tok::pipepipe:
+  case tok::caretcaret:
+    return S.getASTContext().BoolTy;
+  // Operators often used for bit manipulation are typically used with the type
+  // of the left argument.
+  case tok::pipe:
+  case tok::pipeequal:
+  case tok::caret:
+  case tok::caretequal:
+  case tok::amp:
+  case tok::ampequal:
+    if (LHSType->isIntegralOrEnumerationType())
+      return LHSType;
+    return QualType();
+  // RHS should be a pointer to a member of the 'LHS' type, but we can't give
+  // any particular type here.
+  case tok::periodstar:
+  case tok::arrowstar:
+    return QualType();
+  default:
+    // FIXME(ibiryukov): handle the missing op, re-add the assertion.
+    // assert(false && "unhandled binary op");
+    return QualType();
+  }
+}
+
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterBinary(Sema &S, Expr *LHS, tok::TokenKind Op) {
+  return update([&] { Type = getPreferredTypeOfBinaryRHS(S, LHS, Op); });
+}
+
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterUnary(Sema &S, tok::TokenKind Op) {
+  return update([&] {
+    switch (Op) {
+    case tok::exclaim: // The operand is used as a boolean.
+      Type = S.getASTContext().BoolTy;
+      break;
+    case tok::amp: // '&x' has type T*, so 'x' is expected to be a T.
+      if (!Type.isNull() && Type->isPointerType())
+        Type = Type->getPointeeType();
+      else
+        Type = QualType();
+      break;
+    case tok::star: // '*x' has type T, so 'x' is expected to be a T*.
+      if (Type.isNull())
+        break;
+      Type = S.getASTContext().getPointerType(Type.getNonReferenceType());
+      break;
+    case tok::plus:
+    case tok::minus:
+    case tok::tilde:
+    case tok::minusminus:
+    case tok::plusplus:
+      if (Type.isNull())
+        Type = S.getASTContext().IntTy;
+      // Otherwise keep it: these operators typically return the same type.
+      break;
+    case tok::kw___real:
+    case tok::kw___imag:
+      Type = QualType();
+      break;
+    default:
+      assert(false && "unhandled unary op");
+      Type = QualType();
+      break;
+    }
+  });
+}
+
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterSubscript(Sema &S, Expr *LHS) {
+  return update([&]() { Type = S.getASTContext().IntTy; });
+}
+
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterTypeCast(QualType CastType) {
+  return update([&] { Type = CastType; });
+}
+
+PreferredTypeBuilder::RestoreRAII PreferredTypeBuilder::enterUnknown() {
+  return update([&] { Type = QualType(); });
+}
+
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterCondition(Sema &S) {
+  return update([&] { Type = S.getASTContext().BoolTy; });
+}
+
 class ResultBuilder::ShadowMapEntry::iterator {
   llvm::PointerUnion<const NamedDecl *, const DeclIndexPair *> DeclOrIterator;
   unsigned SingleDeclIndex;
@@ -3856,13 +4030,15 @@
 }
 
 struct Sema::CodeCompleteExpressionData {
-  CodeCompleteExpressionData(QualType PreferredType = QualType())
+  CodeCompleteExpressionData(QualType PreferredType = QualType(),
+                             bool IsParenthesized = false)
       : PreferredType(PreferredType), IntegralConstantExpression(false),
-        ObjCCollection(false) {}
+        ObjCCollection(false), IsParenthesized(IsParenthesized) {}
 
   QualType PreferredType;
   bool IntegralConstantExpression;
   bool ObjCCollection;
+  bool IsParenthesized;
   SmallVector<Decl *, 4> IgnoreDecls;
 };
 
@@ -3873,13 +4049,18 @@
   ResultBuilder Results(
       *this, CodeCompleter->getAllocator(),
       CodeCompleter->getCodeCompletionTUInfo(),
-      CodeCompletionContext(CodeCompletionContext::CCC_Expression,
-                            Data.PreferredType));
+      CodeCompletionContext(
+          Data.IsParenthesized
+              ? CodeCompletionContext::CCC_ParenthesizedExpression
+              : CodeCompletionContext::CCC_Expression,
+          Data.PreferredType));
+  auto PCC =
+      Data.IsParenthesized ? PCC_ParenthesizedExpression : PCC_Expression;
   if (Data.ObjCCollection)
     Results.setFilter(&ResultBuilder::IsObjCCollection);
   else if (Data.IntegralConstantExpression)
     Results.setFilter(&ResultBuilder::IsIntegralConstantValue);
-  else if (WantTypesInContext(PCC_Expression, getLangOpts()))
+  else if (WantTypesInContext(PCC, getLangOpts()))
     Results.setFilter(&ResultBuilder::IsOrdinaryName);
   else
     Results.setFilter(&ResultBuilder::IsOrdinaryNonTypeName);
@@ -3897,7 +4078,7 @@
                      CodeCompleter->loadExternal());
 
   Results.EnterNewScope();
-  AddOrdinaryNameResults(PCC_Expression, S, *this, Results);
+  AddOrdinaryNameResults(PCC, S, *this, Results);
   Results.ExitScope();
 
   bool PreferredTypeIsPointer = false;
@@ -3917,13 +4098,16 @@
                             Results.data(), Results.size());
 }
 
-void Sema::CodeCompleteExpression(Scope *S, QualType PreferredType) {
-  return CodeCompleteExpression(S, CodeCompleteExpressionData(PreferredType));
+void Sema::CodeCompleteExpression(Scope *S, QualType PreferredType,
+                                  bool IsParenthesized) {
+  return CodeCompleteExpression(
+      S, CodeCompleteExpressionData(PreferredType, IsParenthesized));
 }
 
-void Sema::CodeCompletePostfixExpression(Scope *S, ExprResult E) {
+void Sema::CodeCompletePostfixExpression(Scope *S, ExprResult E,
+                                         QualType PreferredType) {
   if (E.isInvalid())
-    CodeCompleteOrdinaryName(S, PCC_RecoveryInFunction);
+    CodeCompleteExpression(S, PreferredType);
   else if (getLangOpts().ObjC)
     CodeCompleteObjCInstanceMessage(S, E.get(), None, false);
 }
@@ -4211,7 +4395,8 @@
 void Sema::CodeCompleteMemberReferenceExpr(Scope *S, Expr *Base,
                                            Expr *OtherOpBase,
                                            SourceLocation OpLoc, bool IsArrow,
-                                           bool IsBaseExprStatement) {
+                                           bool IsBaseExprStatement,
+                                           QualType PreferredType) {
   if (!Base || !CodeCompleter)
     return;
 
@@ -4239,6 +4424,7 @@
   }
 
   CodeCompletionContext CCContext(contextKind, ConvertedBaseType);
+  CCContext.setPreferredType(PreferredType);
   ResultBuilder Results(*this, CodeCompleter->getAllocator(),
                         CodeCompleter->getCodeCompletionTUInfo(), CCContext,
                         &ResultBuilder::IsMember);
@@ -4800,22 +4986,6 @@
   CodeCompleteExpression(S, Data);
 }
 
-void Sema::CodeCompleteReturn(Scope *S) {
-  QualType ResultType;
-  if (isa<BlockDecl>(CurContext)) {
-    if (BlockScopeInfo *BSI = getCurBlock())
-      ResultType = BSI->ReturnType;
-  } else if (const auto *Function = dyn_cast<FunctionDecl>(CurContext))
-    ResultType = Function->getReturnType();
-  else if (const auto *Method = dyn_cast<ObjCMethodDecl>(CurContext))
-    ResultType = Method->getReturnType();
-
-  if (ResultType.isNull())
-    CodeCompleteOrdinaryName(S, PCC_Expression);
-  else
-    CodeCompleteExpression(S, ResultType);
-}
-
 void Sema::CodeCompleteAfterIf(Scope *S) {
   ResultBuilder Results(*this, CodeCompleter->getAllocator(),
                         CodeCompleter->getCodeCompletionTUInfo(),
@@ -4877,91 +5047,6 @@
                             Results.data(), Results.size());
 }
 
-static QualType getPreferredTypeOfBinaryRHS(Sema &S, Expr *LHS,
-                                            tok::TokenKind Op) {
-  if (!LHS)
-    return QualType();
-
-  QualType LHSType = LHS->getType();
-  if (LHSType->isPointerType()) {
-    if (Op == tok::plus || Op == tok::plusequal || Op == tok::minusequal)
-      return S.getASTContext().getPointerDiffType();
-    // Pointer difference is more common than subtracting an int from a pointer.
-    if (Op == tok::minus)
-      return LHSType;
-  }
-
-  switch (Op) {
-  // No way to infer the type of RHS from LHS.
-  case tok::comma:
-    return QualType();
-  // Prefer the type of the left operand for all of these.
-  // Arithmetic operations.
-  case tok::plus:
-  case tok::plusequal:
-  case tok::minus:
-  case tok::minusequal:
-  case tok::percent:
-  case tok::percentequal:
-  case tok::slash:
-  case tok::slashequal:
-  case tok::star:
-  case tok::starequal:
-  // Assignment.
-  case tok::equal:
-  // Comparison operators.
-  case tok::equalequal:
-  case tok::exclaimequal:
-  case tok::less:
-  case tok::lessequal:
-  case tok::greater:
-  case tok::greaterequal:
-  case tok::spaceship:
-    return LHS->getType();
-  // Binary shifts are often overloaded, so don't try to guess those.
-  case tok::greatergreater:
-  case tok::greatergreaterequal:
-  case tok::lessless:
-  case tok::lesslessequal:
-    if (LHSType->isIntegralOrEnumerationType())
-      return S.getASTContext().IntTy;
-    return QualType();
-  // Logical operators, assume we want bool.
-  case tok::ampamp:
-  case tok::pipepipe:
-  case tok::caretcaret:
-    return S.getASTContext().BoolTy;
-  // Operators often used for bit manipulation are typically used with the type
-  // of the left argument.
-  case tok::pipe:
-  case tok::pipeequal:
-  case tok::caret:
-  case tok::caretequal:
-  case tok::amp:
-  case tok::ampequal:
-    if (LHSType->isIntegralOrEnumerationType())
-      return LHSType;
-    return QualType();
-  // RHS should be a pointer to a member of the 'LHS' type, but we can't give
-  // any particular type here.
-  case tok::periodstar:
-  case tok::arrowstar:
-    return QualType();
-  default:
-    // FIXME(ibiryukov): handle the missing op, re-add the assertion.
-    // assert(false && "unhandled binary op");
-    return QualType();
-  }
-}
-
-void Sema::CodeCompleteBinaryRHS(Scope *S, Expr *LHS, tok::TokenKind Op) {
-  auto PreferredType = getPreferredTypeOfBinaryRHS(*this, LHS, Op);
-  if (!PreferredType.isNull())
-    CodeCompleteExpression(S, PreferredType);
-  else
-    CodeCompleteOrdinaryName(S, PCC_Expression);
-}
-
 void Sema::CodeCompleteQualifiedId(Scope *S, CXXScopeSpec &SS,
                                    bool EnteringContext, QualType BaseType) {
   if (SS.isEmpty() || !CodeCompleter)
Index: lib/Parse/ParseTemplate.cpp
===================================================================
--- lib/Parse/ParseTemplate.cpp
+++ lib/Parse/ParseTemplate.cpp
@@ -1304,6 +1304,7 @@
 ///         template-argument-list ',' template-argument
 bool
 Parser::ParseTemplateArgumentList(TemplateArgList &TemplateArgs) {
+  auto TypeRAII = PreferredType.enterUnknown();
 
   ColonProtectionRAIIObject ColonProtection(*this, false);
 
Index: lib/Parse/ParseStmt.cpp
===================================================================
--- lib/Parse/ParseStmt.cpp
+++ lib/Parse/ParseStmt.cpp
@@ -1971,9 +1971,11 @@
 
   ExprResult R;
   if (Tok.isNot(tok::semi)) {
+    auto TypeRAII = IsCoreturn ? PreferredType.enterUnknown()
+                               : PreferredType.enterReturn(Actions);
     // FIXME: Code completion for co_return.
     if (Tok.is(tok::code_completion) && !IsCoreturn) {
-      Actions.CodeCompleteReturn(getCurScope());
+      Actions.CodeCompleteExpression(getCurScope(), PreferredType.get());
       cutOffParsing();
       return StmtError();
     }
Index: lib/Parse/ParseInit.cpp
===================================================================
--- lib/Parse/ParseInit.cpp
+++ lib/Parse/ParseInit.cpp
@@ -386,6 +386,8 @@
 ///         initializer-list ',' designation[opt] initializer ...[opt]
 ///
 ExprResult Parser::ParseBraceInitializer() {
+  auto TypeRAII = PreferredType.enterUnknown();
+
   InMessageExpressionRAIIObject InMessage(*this, false);
 
   BalancedDelimiterTracker T(*this, tok::l_brace);
Index: lib/Parse/ParseExprCXX.cpp
===================================================================
--- lib/Parse/ParseExprCXX.cpp
+++ lib/Parse/ParseExprCXX.cpp
@@ -675,6 +675,7 @@
 ///           trailing-return-type[opt]
 ///
 ExprResult Parser::ParseLambdaExpression() {
+  auto TypeRAII = PreferredType.enterUnknown();
   // Parse lambda-introducer.
   LambdaIntroducer Intro;
   Optional<unsigned> DiagID = ParseLambdaIntroducer(Intro);
@@ -1384,6 +1385,8 @@
 ExprResult Parser::ParseCXXTypeid() {
   assert(Tok.is(tok::kw_typeid) && "Not 'typeid'!");
 
+  auto TypeRAII = PreferredType.enterUnknown();
+
   SourceLocation OpLoc = ConsumeToken();
   SourceLocation LParenLoc, RParenLoc;
   BalancedDelimiterTracker T(*this, tok::l_paren);
@@ -1451,6 +1454,8 @@
 ExprResult Parser::ParseCXXUuidof() {
   assert(Tok.is(tok::kw___uuidof) && "Not '__uuidof'!");
 
+  auto TypeRAII = PreferredType.enterUnknown();
+
   SourceLocation OpLoc = ConsumeToken();
   BalancedDelimiterTracker T(*this, tok::l_paren);
 
@@ -1606,11 +1611,15 @@
   case tok::comma:
     return Actions.ActOnCXXThrow(getCurScope(), ThrowLoc, nullptr);
 
-  default:
+  default: {
+    auto TypeRAII = PreferredType.enterUnknown();
+
     ExprResult Expr(ParseAssignmentExpression());
-    if (Expr.isInvalid()) return Expr;
+    if (Expr.isInvalid())
+      return Expr;
     return Actions.ActOnCXXThrow(getCurScope(), ThrowLoc, Expr.get());
   }
+  }
 }
 
 /// Parse the C++ Coroutines co_yield expression.
@@ -1620,6 +1629,8 @@
 ExprResult Parser::ParseCoyieldExpression() {
   assert(Tok.is(tok::kw_co_yield) && "Not co_yield!");
 
+  auto TypeRAII = PreferredType.enterUnknown();
+
   SourceLocation Loc = ConsumeToken();
   ExprResult Expr = Tok.is(tok::l_brace) ? ParseBraceInitializer()
                                          : ParseAssignmentExpression();
@@ -1657,6 +1668,8 @@
   Declarator DeclaratorInfo(DS, DeclaratorContext::FunctionalCastContext);
   ParsedType TypeRep = Actions.ActOnTypeName(getCurScope(), DeclaratorInfo).get();
 
+  auto TypeRAII = PreferredType.enterTypeCast(TypeRep.get());
+
   assert((Tok.is(tok::l_paren) ||
           (getLangOpts().CPlusPlus11 && Tok.is(tok::l_brace)))
          && "Expected '(' or '{'!");
@@ -1740,6 +1753,7 @@
                                                 Sema::ConditionKind CK,
                                                 ForRangeInfo *FRI) {
   ParenBraceBracketBalancer BalancerRAIIObj(*this);
+  auto TypeRAII = PreferredType.enterCondition(Actions);
 
   if (Tok.is(tok::code_completion)) {
     Actions.CodeCompleteOrdinaryName(getCurScope(), Sema::PCC_Condition);
@@ -1859,6 +1873,7 @@
          diag::warn_cxx98_compat_generalized_initializer_lists);
     InitExpr = ParseBraceInitializer();
   } else if (CopyInitialization) {
+    auto TypeRAII = PreferredType.enterVariableInit(DeclOut);
     InitExpr = ParseAssignmentExpression();
   } else if (Tok.is(tok::l_paren)) {
     // This was probably an attempt to initialize the variable.
@@ -2966,6 +2981,7 @@
   assert(Tok.is(tok::kw_delete) && "Expected 'delete' keyword");
   ConsumeToken(); // Consume 'delete'
 
+  auto TypeRAII = PreferredType.enterUnknown();
   // Array delete?
   bool ArrayDelete = false;
   if (Tok.is(tok::l_square) && NextToken().is(tok::r_square)) {
@@ -3043,6 +3059,8 @@
 ///          type-id ...[opt] type-id-seq[opt]
 ///
 ExprResult Parser::ParseTypeTrait() {
+  auto TypeRAII = PreferredType.enterUnknown();
+
   tok::TokenKind Kind = Tok.getKind();
   unsigned Arity = TypeTraitArity(Kind);
 
@@ -3102,6 +3120,8 @@
 /// [Embarcadero]     '__array_extent' '(' type-id ',' expression ')'
 ///
 ExprResult Parser::ParseArrayTypeTrait() {
+  auto TypeRAII = PreferredType.enterUnknown();
+
   ArrayTypeTrait ATT = ArrayTypeTraitFromTokKind(Tok.getKind());
   SourceLocation Loc = ConsumeToken();
 
@@ -3145,6 +3165,8 @@
 /// [Embarcadero]     expression-trait '(' expression ')'
 ///
 ExprResult Parser::ParseExpressionTrait() {
+  auto TypeRAII = PreferredType.enterUnknown();
+
   ExpressionTrait ET = ExpressionTraitFromTokKind(Tok.getKind());
   SourceLocation Loc = ConsumeToken();
 
Index: lib/Parse/ParseExpr.cpp
===================================================================
--- lib/Parse/ParseExpr.cpp
+++ lib/Parse/ParseExpr.cpp
@@ -159,7 +159,7 @@
 /// Parse an expr that doesn't include (top-level) commas.
 ExprResult Parser::ParseAssignmentExpression(TypeCastState isTypeCast) {
   if (Tok.is(tok::code_completion)) {
-    Actions.CodeCompleteOrdinaryName(getCurScope(), Sema::PCC_Expression);
+    Actions.CodeCompleteExpression(getCurScope(), PreferredType.get());
     cutOffParsing();
     return ExprError();
   }
@@ -393,15 +393,8 @@
       }
     }
 
-    // Code completion for the right-hand side of a binary expression goes
-    // through a special hook that takes the left-hand side into account.
-    if (Tok.is(tok::code_completion)) {
-      Actions.CodeCompleteBinaryRHS(getCurScope(), LHS.get(),
-                                    OpToken.getKind());
-      cutOffParsing();
-      return ExprError();
-    }
-
+    auto TypeRAII =
+        PreferredType.enterBinary(Actions, LHS.get(), OpToken.getKind());
     // Parse another leaf here for the RHS of the operator.
     // ParseCastExpression works here because all RHS expressions in C have it
     // as a prefix, at least. However, in C++, an assignment-expression could
@@ -1115,6 +1108,8 @@
     //     -- cast-expression
     Token SavedTok = Tok;
     ConsumeToken();
+
+    auto TypeRAII = PreferredType.enterUnary(Actions, SavedTok.getKind());
     // One special case is implicitly handled here: if the preceding tokens are
     // an ambiguous cast expression, such as "(T())++", then we recurse to
     // determine whether the '++' is prefix or postfix.
@@ -1134,6 +1129,7 @@
     return Res;
   }
   case tok::amp: {         // unary-expression: '&' cast-expression
+    auto TypeRAII = PreferredType.enterUnary(Actions, tok::amp);
     // Special treatment because of member pointers
     SourceLocation SavedLoc = ConsumeToken();
     Res = ParseCastExpression(false, true);
@@ -1149,6 +1145,8 @@
   case tok::exclaim:       // unary-expression: '!' cast-expression
   case tok::kw___real:     // unary-expression: '__real' cast-expression [GNU]
   case tok::kw___imag: {   // unary-expression: '__imag' cast-expression [GNU]
+    auto TypeRAII = PreferredType.enterUnary(Actions, Tok.getKind());
+
     SourceLocation SavedLoc = ConsumeToken();
     Res = ParseCastExpression(false);
     if (!Res.isInvalid())
@@ -1184,9 +1182,13 @@
                            // unary-expression: 'sizeof' '(' type-name ')'
   case tok::kw_vec_step:   // unary-expression: OpenCL 'vec_step' expression
   // unary-expression: '__builtin_omp_required_simd_align' '(' type-name ')'
-  case tok::kw___builtin_omp_required_simd_align:
+  case tok::kw___builtin_omp_required_simd_align: {
+    auto TypeRAII = PreferredType.enterUnknown();
     return ParseUnaryExprOrTypeTraitExpression();
+  }
   case tok::ampamp: {      // unary-expression: '&&' identifier
+    auto TypeRAII = PreferredType.enterUnknown();
+
     SourceLocation AmpAmpLoc = ConsumeToken();
     if (Tok.isNot(tok::identifier))
       return ExprError(Diag(Tok, diag::err_expected) << tok::identifier);
@@ -1386,6 +1388,7 @@
     SourceLocation KeyLoc = ConsumeToken();
     BalancedDelimiterTracker T(*this, tok::l_paren);
 
+    auto TypeRAII = PreferredType.enterUnknown();
     if (T.expectAndConsume(diag::err_expected_lparen_after, "noexcept"))
       return ExprError();
     // C++11 [expr.unary.noexcept]p1:
@@ -1424,7 +1427,7 @@
     Res = ParseBlockLiteralExpression();
     break;
   case tok::code_completion: {
-    Actions.CodeCompleteOrdinaryName(getCurScope(), Sema::PCC_Expression);
+    Actions.CodeCompleteExpression(getCurScope(), PreferredType.get());
     cutOffParsing();
     return ExprError();
   }
@@ -1504,7 +1507,8 @@
       if (InMessageExpression)
         return LHS;
 
-      Actions.CodeCompletePostfixExpression(getCurScope(), LHS);
+      Actions.CodeCompletePostfixExpression(getCurScope(), LHS,
+                                            PreferredType.get());
       cutOffParsing();
       return ExprError();
 
@@ -1541,6 +1545,8 @@
         return ExprError();
       }
 
+      auto TypeRAII = PreferredType.enterSubscript(Actions, LHS.get());
+
       BalancedDelimiterTracker T(*this, tok::l_square);
       T.consumeOpen();
       Loc = T.getOpenLocation();
@@ -1773,7 +1779,8 @@
         // Code completion for a member access expression.
         Actions.CodeCompleteMemberReferenceExpr(
             getCurScope(), Base, CorrectedBase, OpLoc, OpKind == tok::arrow,
-            Base && ExprStatementTokLoc == Base->getBeginLoc());
+            Base && ExprStatementTokLoc == Base->getBeginLoc(),
+            PreferredType.get());
 
         cutOffParsing();
         return ExprError();
@@ -2332,9 +2339,9 @@
   CastTy = nullptr;
 
   if (Tok.is(tok::code_completion)) {
-    Actions.CodeCompleteOrdinaryName(getCurScope(),
-                 ExprType >= CompoundLiteral? Sema::PCC_ParenthesizedExpression
-                                            : Sema::PCC_Expression);
+    Actions.CodeCompleteExpression(getCurScope(), PreferredType.get(),
+                                   /*IsParenthesized=*/ExprType >=
+                                       CompoundLiteral);
     cutOffParsing();
     return ExprError();
   }
@@ -2414,6 +2421,9 @@
     TypeResult Ty = ParseTypeName();
     T.consumeClose();
     ColonProtection.restore();
+
+    auto TypeRAII = PreferredType.enterTypeCast(Ty.get().get());
+
     RParenLoc = T.getCloseLocation();
     ExprResult SubExpr = ParseCastExpression(/*isUnaryExpression=*/false);
 
@@ -2545,6 +2555,7 @@
           return ExprError();
         }
 
+        auto TypeRAII = PreferredType.enterTypeCast(CastTy.get());
         // Parse the cast-expression that follows it next.
         // TODO: For cast expression with CastTy.
         Result = ParseCastExpression(/*isUnaryExpression=*/false,
@@ -2840,13 +2851,15 @@
 bool Parser::ParseExpressionList(SmallVectorImpl<Expr *> &Exprs,
                                  SmallVectorImpl<SourceLocation> &CommaLocs,
                                  llvm::function_ref<void()> Completer) {
+  auto TypeRAII = PreferredType.enterUnknown();
+
   bool SawError = false;
   while (1) {
     if (Tok.is(tok::code_completion)) {
       if (Completer)
         Completer();
       else
-        Actions.CodeCompleteOrdinaryName(getCurScope(), Sema::PCC_Expression);
+        Actions.CodeCompleteExpression(getCurScope(), PreferredType.get());
       cutOffParsing();
       return true;
     }
Index: lib/Parse/ParseDeclCXX.cpp
===================================================================
--- lib/Parse/ParseDeclCXX.cpp
+++ lib/Parse/ParseDeclCXX.cpp
@@ -922,6 +922,8 @@
   assert(Tok.isOneOf(tok::kw_decltype, tok::annot_decltype)
            && "Not a decltype specifier");
 
+  auto TypeRAII = PreferredType.enterUnknown();
+
   ExprResult Result;
   SourceLocation StartLoc = Tok.getLocation();
   SourceLocation EndLoc;
Index: lib/Parse/ParseDecl.cpp
===================================================================
--- lib/Parse/ParseDecl.cpp
+++ lib/Parse/ParseDecl.cpp
@@ -45,6 +45,8 @@
                                  AccessSpecifier AS,
                                  Decl **OwnedType,
                                  ParsedAttributes *Attrs) {
+  auto TypeRAII = PreferredType.enterUnknown();
+
   DeclSpecContext DSC = getDeclSpecContextFromDeclaratorContext(Context);
   if (DSC == DeclSpecContext::DSC_normal)
     DSC = DeclSpecContext::DSC_type_specifier;
@@ -2275,7 +2277,11 @@
         return nullptr;
       }
 
-      ExprResult Init(ParseInitializer());
+      ExprResult Init;
+      {
+        auto TypeRAII = PreferredType.enterVariableInit(ThisDecl);
+        Init = ParseInitializer();
+      }
 
       // If this is the only decl in (possibly) range based for statement,
       // our best guess is that the user meant ':' instead of '='.
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -274,6 +274,59 @@
   }
 };
 
+/// Keeps track of the expected type during expression parsing.
+class PreferredTypeBuilder {
+public:
+  class RestoreRAII;
+
+  PreferredTypeBuilder() = default;
+  explicit PreferredTypeBuilder(QualType Type) : Type(Type) {}
+  /// All enter* methods return a guard that restores the previous type.
+  LLVM_NODISCARD RestoreRAII enterUnknown();
+  LLVM_NODISCARD RestoreRAII enterCondition(Sema &S);
+  LLVM_NODISCARD RestoreRAII enterReturn(Sema &S);
+  LLVM_NODISCARD RestoreRAII enterVariableInit(Decl *D);
+
+  LLVM_NODISCARD RestoreRAII enterUnary(Sema &S, tok::TokenKind Op);
+  LLVM_NODISCARD RestoreRAII enterBinary(Sema &S, Expr *LHS, tok::TokenKind Op);
+  LLVM_NODISCARD RestoreRAII enterSubscript(Sema &S, Expr *LHS);
+  /// Handles all type casts, including C-style cast, C++ casts, etc.
+  LLVM_NODISCARD RestoreRAII enterTypeCast(QualType CastType);
+
+  QualType get() const { return Type; }
+
+private:
+  LLVM_NODISCARD RestoreRAII update(llvm::function_ref<void()> Updater);
+
+  QualType Type;
+};
+
+class PreferredTypeBuilder::RestoreRAII {
+public:
+  RestoreRAII(const RestoreRAII &) = delete;
+  RestoreRAII &operator=(const RestoreRAII &) = delete;
+
+  explicit RestoreRAII(PreferredTypeBuilder &Builder)
+      : Old(Builder.Type), Builder(&Builder) {}
+
+  RestoreRAII(RestoreRAII &&Other)
+      : Old(Other.Old), Builder(Other.Builder) {
+    // Disarm the moved-from guard so that only this one restores the
+    // builder's previous type on destruction.
+    Other.Builder = nullptr;
+  }
+
+  ~RestoreRAII() {
+    if (!Builder)
+      return;
+    Builder->Type = Old;
+  }
+
+private:
+  QualType Old;
+  PreferredTypeBuilder *Builder;
+};
+
 /// Sema - This implements semantic analysis and AST building for C.
 class Sema {
   Sema(const Sema &) = delete;
@@ -10342,11 +10395,14 @@
   struct CodeCompleteExpressionData;
   void CodeCompleteExpression(Scope *S,
                               const CodeCompleteExpressionData &Data);
-  void CodeCompleteExpression(Scope *S, QualType PreferredType);
+  void CodeCompleteExpression(Scope *S, QualType PreferredType,
+                              bool IsParenthesized = false);
   void CodeCompleteMemberReferenceExpr(Scope *S, Expr *Base, Expr *OtherOpBase,
                                        SourceLocation OpLoc, bool IsArrow,
-                                       bool IsBaseExprStatement);
-  void CodeCompletePostfixExpression(Scope *S, ExprResult LHS);
+                                       bool IsBaseExprStatement,
+                                       QualType PreferredType);
+  void CodeCompletePostfixExpression(Scope *S, ExprResult LHS,
+                                     QualType PreferredType);
   void CodeCompleteTag(Scope *S, unsigned TagSpec);
   void CodeCompleteTypeQualifiers(DeclSpec &DS);
   void CodeCompleteFunctionQualifiers(DeclSpec &DS, Declarator &D,
@@ -10368,9 +10424,7 @@
                                               IdentifierInfo *II,
                                               SourceLocation OpenParLoc);
   void CodeCompleteInitializer(Scope *S, Decl *D);
-  void CodeCompleteReturn(Scope *S);
   void CodeCompleteAfterIf(Scope *S);
-  void CodeCompleteBinaryRHS(Scope *S, Expr *LHS, tok::TokenKind Op);
 
   void CodeCompleteQualifiedId(Scope *S, CXXScopeSpec &SS,
                                bool EnteringContext, QualType BaseType);
Index: include/clang/Sema/CodeCompleteConsumer.h
===================================================================
--- include/clang/Sema/CodeCompleteConsumer.h
+++ include/clang/Sema/CodeCompleteConsumer.h
@@ -381,6 +381,7 @@
   /// if the expression is a variable initializer or a function argument, the
   /// type of the corresponding variable or function parameter.
   QualType getPreferredType() const { return PreferredType; }
+  void setPreferredType(QualType T) { PreferredType = T; }
 
   /// Retrieve the type of the base object in a member-access
   /// expression.
Index: include/clang/Parse/Parser.h
===================================================================
--- include/clang/Parse/Parser.h
+++ include/clang/Parse/Parser.h
@@ -220,6 +220,10 @@
   /// function call.
   bool CalledSignatureHelp = false;
 
+  /// Tracks expected type of the expression currently being parsed.
+  /// Used by code completion for ranking.
+  PreferredTypeBuilder PreferredType;
+
   /// The "depth" of the template parameters currently being parsed.
   unsigned TemplateParameterDepth;
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to