sammccall updated this revision to Diff 397075.
sammccall marked an inline comment as done.
sammccall added a comment.

address comments


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D116326/new/

https://reviews.llvm.org/D116326

Files:
  clang-tools-extra/clangd/CodeComplete.cpp
  clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
  clang/include/clang/Sema/CodeCompleteConsumer.h
  clang/lib/Sema/CodeCompleteConsumer.cpp
  clang/lib/Sema/SemaCodeComplete.cpp
  clang/test/CodeCompletion/ctor-signature.cpp

Index: clang/test/CodeCompletion/ctor-signature.cpp
===================================================================
--- clang/test/CodeCompletion/ctor-signature.cpp
+++ clang/test/CodeCompletion/ctor-signature.cpp
@@ -42,13 +42,29 @@
 // RUN: %clang_cc1 -fsyntax-only -code-completion-at=%s:41:22 %s | FileCheck -check-prefix=CHECK-BRACED %s
 
 struct Aggregate {
-  // FIXME: no support for aggregates yet.
-  // CHECK-AGGREGATE-NOT: OVERLOAD: Aggregate{<#const Aggregate &#>}
-  // CHECK-AGGREGATE-NOT: OVERLOAD: {{.*}}first
   int first;
   int second;
+  int third;
 };
 
-Aggregate a{};
-// RUN: %clang_cc1 -fsyntax-only -code-completion-at=%s:52:13 %s | FileCheck -check-prefix=CHECK-AGGREGATE %s
+Aggregate a{1, 2, 3};
+// RUN: %clang_cc1 -fsyntax-only -code-completion-at=%s:50:13 %s | FileCheck -check-prefix=CHECK-AGGREGATE-1 %s
+// CHECK-AGGREGATE-1: OVERLOAD: Aggregate{<#int first#>, int second, int third}
+// CHECK-AGGREGATE-1: OVERLOAD: Aggregate{<#const Aggregate &#>}
+// RUN: %clang_cc1 -fsyntax-only -code-completion-at=%s:50:16 %s | FileCheck -check-prefix=CHECK-AGGREGATE-2 %s
+// CHECK-AGGREGATE-2: OVERLOAD: Aggregate{int first, <#int second#>, int third}
+// RUN: %clang_cc1 -fsyntax-only -code-completion-at=%s:50:18 %s | FileCheck -check-prefix=CHECK-AGGREGATE-3 %s
+// CHECK-AGGREGATE-3: OVERLOAD: Aggregate{int first, int second, <#int third#>}
 
+Aggregate d{.second=1, .first=2, 3, 4, };
+// RUN: %clang_cc1 -fsyntax-only -code-completion-at=%s:59:13 %s | FileCheck -check-prefix=CHECK-DESIG-1 %s
+// CHECK-DESIG-1: OVERLOAD: Aggregate{<#int first#>, int second, int third}
+// CHECK-DESIG-1: OVERLOAD: Aggregate{<#const Aggregate &#>}
+// RUN: %clang_cc1 -fsyntax-only -code-completion-at=%s:59:24 %s | FileCheck -check-prefix=CHECK-DESIG-2 %s
+// CHECK-DESIG-2: OVERLOAD: Aggregate{int first, int second, <#int third#>}
+// RUN: %clang_cc1 -fsyntax-only -code-completion-at=%s:59:34 %s | FileCheck -check-prefix=CHECK-DESIG-3 %s
+// CHECK-DESIG-3: OVERLOAD: Aggregate{int first, <#int second#>, int third}
+// RUN: %clang_cc1 -fsyntax-only -code-completion-at=%s:59:37 %s | FileCheck -check-prefix=CHECK-DESIG-4 %s
+// CHECK-DESIG-4: OVERLOAD: Aggregate{int first, int second, <#int third#>}
+// RUN: %clang_cc1 -fsyntax-only -code-completion-at=%s:59:38 %s | FileCheck -check-prefix=CHECK-DESIG-5 %s --allow-empty
+// CHECK-DESIG-5-NOT: OVERLOAD
Index: clang/lib/Sema/SemaCodeComplete.cpp
===================================================================
--- clang/lib/Sema/SemaCodeComplete.cpp
+++ clang/lib/Sema/SemaCodeComplete.cpp
@@ -2817,14 +2817,18 @@
                        Optional<ArrayRef<QualType>> ObjCSubsts = None);
 
 static std::string
-FormatFunctionParameter(const PrintingPolicy &Policy, const ParmVarDecl *Param,
-                        bool SuppressName = false, bool SuppressBlock = false,
+FormatFunctionParameter(const PrintingPolicy &Policy,
+                        const DeclaratorDecl *Param, bool SuppressName = false,
+                        bool SuppressBlock = false,
                         Optional<ArrayRef<QualType>> ObjCSubsts = None) {
   // Params are unavailable in FunctionTypeLoc if the FunctionType is invalid.
   // It would be better to pass in the param Type, which is usually available.
   // But this case is rare, so just pretend we fell back to int as elsewhere.
   if (!Param)
     return "int";
+  Decl::ObjCDeclQualifier ObjCQual = Decl::OBJC_TQ_None;
+  if (const auto *PVD = dyn_cast<ParmVarDecl>(Param))
+    ObjCQual = PVD->getObjCDeclQualifier();
   bool ObjCMethodParam = isa<ObjCMethodDecl>(Param->getDeclContext());
   if (Param->getType()->isDependentType() ||
       !Param->getType()->isBlockPointerType()) {
@@ -2840,8 +2844,7 @@
       Type = Type.substObjCTypeArgs(Param->getASTContext(), *ObjCSubsts,
                                     ObjCSubstitutionContext::Parameter);
     if (ObjCMethodParam) {
-      Result =
-          "(" + formatObjCParamQualifiers(Param->getObjCDeclQualifier(), Type);
+      Result = "(" + formatObjCParamQualifiers(ObjCQual, Type);
       Result += Type.getAsString(Policy) + ")";
       if (Param->getIdentifier() && !SuppressName)
         Result += Param->getIdentifier()->getName();
@@ -2878,8 +2881,7 @@
 
     if (ObjCMethodParam) {
       Result = Type.getAsString(Policy);
-      std::string Quals =
-          formatObjCParamQualifiers(Param->getObjCDeclQualifier(), Type);
+      std::string Quals = formatObjCParamQualifiers(ObjCQual, Type);
       if (!Quals.empty())
         Result = "(" + Quals + " " + Result + ")";
       if (Result.back() != ')')
@@ -3689,6 +3691,31 @@
   return nullptr;
 }
 
+static void AddOverloadAggregateChunks(const RecordDecl *RD,
+                                       const PrintingPolicy &Policy,
+                                       CodeCompletionBuilder &Result,
+                                       unsigned CurrentArg) {
+  unsigned ChunkIndex = 0;
+  auto AddChunk = [&](std::string Placeholder) {
+    if (ChunkIndex > 0)
+      Result.AddChunk(CodeCompletionString::CK_Comma);
+    const char *Copy = Result.getAllocator().CopyString(Placeholder);
+    if (ChunkIndex == CurrentArg)
+      Result.AddCurrentParameterChunk(Copy);
+    else
+      Result.AddPlaceholderChunk(Copy);
+    ++ChunkIndex;
+  };
+  // Aggregate initialization has all bases followed by all fields.
+  // (Bases are not allowed before C++17, but in that case we never get here).
+  if (auto *CRD = llvm::dyn_cast<CXXRecordDecl>(RD)) {
+    for (const auto &Base : CRD->bases())
+      AddChunk(Base.getType().getAsString(Policy));
+  }
+  for (const auto &Field : RD->fields())
+    AddChunk(FormatFunctionParameter(Policy, Field));
+}
+
 /// Add function overload parameter chunks to the given code completion
 /// string.
 static void AddOverloadParameterChunks(ASTContext &Context,
@@ -3698,6 +3725,11 @@
                                        CodeCompletionBuilder &Result,
                                        unsigned CurrentArg, unsigned Start = 0,
                                        bool InOptional = false) {
+  if (!Function && !Prototype) {
+    Result.AddChunk(CodeCompletionString::CK_CurrentParameter, "...");
+    return;
+  }
+
   bool FirstParameter = true;
   unsigned NumParams =
       Function ? Function->getNumParams() : Prototype->getNumParams();
@@ -3774,22 +3806,13 @@
                                CXAvailability_Available);
   FunctionDecl *FDecl = getFunction();
   const FunctionProtoType *Proto =
-      dyn_cast<FunctionProtoType>(getFunctionType());
-  if (!FDecl && !Proto) {
-    // Function without a prototype. Just give the return type and a
-    // highlighted ellipsis.
-    const FunctionType *FT = getFunctionType();
-    Result.AddResultTypeChunk(Result.getAllocator().CopyString(
-        FT->getReturnType().getAsString(Policy)));
-    Result.AddChunk(Braced ? CodeCompletionString::CK_LeftBrace
-                           : CodeCompletionString::CK_LeftParen);
-    Result.AddChunk(CodeCompletionString::CK_CurrentParameter, "...");
-    Result.AddChunk(Braced ? CodeCompletionString::CK_RightBrace
-                           : CodeCompletionString::CK_RightParen);
-    return Result.TakeString();
-  }
+      dyn_cast_or_null<FunctionProtoType>(getFunctionType());
 
-  if (FDecl) {
+  // First, the name/type of the callee.
+  if (getKind() == CK_Aggregate) {
+    Result.AddTextChunk(
+        Result.getAllocator().CopyString(getAggregate()->getName()));
+  } else if (FDecl) {
     if (IncludeBriefComments) {
       if (auto RC = getParameterComment(S.getASTContext(), *this, CurrentArg))
         Result.addBriefComment(RC->getBriefText(S.getASTContext()));
@@ -3800,15 +3823,23 @@
     llvm::raw_string_ostream OS(Name);
     FDecl->getDeclName().print(OS, Policy);
     Result.AddTextChunk(Result.getAllocator().CopyString(OS.str()));
-  } else {
+  } else if (Proto) {
     Result.AddResultTypeChunk(Result.getAllocator().CopyString(
         Proto->getReturnType().getAsString(Policy)));
+  } else {
+    // Function without a prototype. Just give the return type.
+    Result.AddResultTypeChunk(Result.getAllocator().CopyString(
+        getFunctionType()->getReturnType().getAsString(Policy)));
   }
 
+  // Next, the brackets and parameters.
   Result.AddChunk(Braced ? CodeCompletionString::CK_LeftBrace
                          : CodeCompletionString::CK_LeftParen);
-  AddOverloadParameterChunks(S.getASTContext(), Policy, FDecl, Proto, Result,
-                             CurrentArg);
+  if (getKind() == CK_Aggregate)
+    AddOverloadAggregateChunks(getAggregate(), Policy, Result, CurrentArg);
+  else
+    AddOverloadParameterChunks(S.getASTContext(), Policy, FDecl, Proto, Result,
+                               CurrentArg);
   Result.AddChunk(Braced ? CodeCompletionString::CK_RightBrace
                          : CodeCompletionString::CK_RightParen);
 
@@ -5849,17 +5880,16 @@
   // overload candidates.
   QualType ParamType;
   for (auto &Candidate : Candidates) {
-    if (const auto *FType = Candidate.getFunctionType())
-      if (const auto *Proto = dyn_cast<FunctionProtoType>(FType))
-        if (N < Proto->getNumParams()) {
-          if (ParamType.isNull())
-            ParamType = Proto->getParamType(N);
-          else if (!SemaRef.Context.hasSameUnqualifiedType(
-                       ParamType.getNonReferenceType(),
-                       Proto->getParamType(N).getNonReferenceType()))
-            // Otherwise return a default-constructed QualType.
-            return QualType();
-        }
+    QualType CandidateParamType = Candidate.getParamType(N);
+    if (!CandidateParamType.isNull()) {
+      if (ParamType.isNull())
+        ParamType = CandidateParamType;
+      else if (!SemaRef.Context.hasSameUnqualifiedType(
+                   ParamType.getNonReferenceType(),
+                   CandidateParamType.getNonReferenceType()))
+        // Otherwise return a default-constructed QualType.
+        return QualType();
+    }
   }
 
   return ParamType;
@@ -5980,6 +6010,71 @@
   return !CandidateSet.empty() ? ParamType : QualType();
 }
 
+// Determine which param to continue aggregate initialization from after
+// a designated initializer.
+//
+// Given struct S { int a,b,c,d,e; }:
+//   after `S{.b=1,`      we want to suggest c to continue
+//   after `S{.b=1, 2,`   we continue with d (legal in C, an extension in C++)
+//
+// Possible outcomes:
+//   - we saw a designator for a field, and continue from the returned index.
+//     Only aggregate initialization is allowed.
+//   - we saw a designator, but it was complex or we couldn't find the field.
+//     Only aggregate initialization is possible, but we can't assist with it.
+//     Returns an out-of-range index.
+//   - we saw no designators, just positional arguments.
+//     Returns None.
+static llvm::Optional<unsigned>
+getNextAggregateIndexAfterDesignatedInit(const ResultCandidate &Aggregate,
+                                         ArrayRef<Expr *> Args) {
+  static constexpr unsigned Invalid = std::numeric_limits<unsigned>::max();
+
+  // Look for designated initializers.
+  // They're in their syntactic form, not yet resolved to fields.
+  IdentifierInfo *DesignatedFieldName = nullptr;
+  unsigned ArgsAfterDesignator = 0;
+  for (const Expr *Arg : Args) {
+    if (const auto *DIE = dyn_cast<DesignatedInitExpr>(Arg)) {
+      if (DIE->size() == 1 && DIE->getDesignator(0)->isFieldDesignator()) {
+        DesignatedFieldName = DIE->getDesignator(0)->getFieldName();
+        ArgsAfterDesignator = 0;
+      } else {
+        return Invalid; // Complicated designator.
+      }
+    } else if (isa<DesignatedInitUpdateExpr>(Arg)) {
+      return Invalid; // Unsupported.
+    } else {
+      ++ArgsAfterDesignator;
+    }
+  }
+  if (!DesignatedFieldName)
+    return llvm::None;
+
+  // Find the index within the class's fields.
+  // (Probing getParamDecl() directly would be quadratic in number of fields).
+  unsigned AggregateSize = Aggregate.getNumParams();
+  unsigned DesignatedIndex = 0;
+  const FieldDecl *DesignatedField = nullptr;
+  for (const auto *Field : Aggregate.getAggregate()->fields()) {
+    if (Field->getIdentifier() == DesignatedFieldName) {
+      DesignatedField = Field;
+      break;
+    }
+    ++DesignatedIndex;
+  }
+  if (!DesignatedField)
+    return Invalid; // Designator referred to a missing field, give up.
+
+  // Find the index within the aggregate (which may have leading bases).
+  while (DesignatedIndex < AggregateSize &&
+         Aggregate.getParamDecl(DesignatedIndex) != DesignatedField)
+    ++DesignatedIndex;
+
+  // Continue from the index after the last named field.
+  return DesignatedIndex + ArgsAfterDesignator + 1;
+}
+
 QualType Sema::ProduceConstructorSignatureHelp(Scope *S, QualType Type,
                                                SourceLocation Loc,
                                                ArrayRef<Expr *> Args,
@@ -5987,48 +6082,72 @@
                                                bool Braced) {
   if (!CodeCompleter)
     return QualType();
+  SmallVector<ResultCandidate, 8> Results;
 
   // A complete type is needed to lookup for constructors.
-  CXXRecordDecl *RD =
-      isCompleteType(Loc, Type) ? Type->getAsCXXRecordDecl() : nullptr;
+  RecordDecl *RD =
+      isCompleteType(Loc, Type) ? Type->getAsRecordDecl() : nullptr;
   if (!RD)
     return Type;
-  // FIXME: we don't support signature help for aggregate initialization, so
-  //        don't offer a confusing partial list (e.g. the copy constructor).
-  if (Braced && RD->isAggregate())
-    return Type;
+  CXXRecordDecl *CRD = dyn_cast<CXXRecordDecl>(RD);
+
+  // Consider aggregate initialization.
+  // We don't check that the types of the arguments so far are correct.
+  // We also don't handle C99/C++17 brace-elision, we assume init-list elements
+  // are 1:1 with fields.
+  // FIXME: it would be nice to support "unwrapping" aggregates that contain
+  // a single subaggregate, like std::array<T, N> -> T __elements[N].
+  if (Braced && !RD->isUnion() &&
+      (!LangOpts.CPlusPlus || (CRD && CRD->isAggregate()))) {
+    ResultCandidate AggregateSig(RD);
+    unsigned AggregateSize = AggregateSig.getNumParams();
+
+    if (auto NextIndex =
+            getNextAggregateIndexAfterDesignatedInit(AggregateSig, Args)) {
+      // A designator was used, only aggregate init is possible.
+      if (*NextIndex >= AggregateSize)
+        return Type;
+      Results.push_back(AggregateSig);
+      return ProduceSignatureHelp(*this, S, Results, *NextIndex, OpenParLoc,
+                                  Braced);
+    }
+
+    // Describe aggregate initialization, but also constructors below.
+    if (Args.size() < AggregateSize)
+      Results.push_back(AggregateSig);
+  }
 
   // FIXME: Provide support for member initializers.
   // FIXME: Provide support for variadic template constructors.
 
-  OverloadCandidateSet CandidateSet(Loc, OverloadCandidateSet::CSK_Normal);
-
-  for (NamedDecl *C : LookupConstructors(RD)) {
-    if (auto *FD = dyn_cast<FunctionDecl>(C)) {
-      // FIXME: we can't yet provide correct signature help for initializer
-      //        list constructors, so skip them entirely.
-      if (Braced && LangOpts.CPlusPlus && isInitListConstructor(FD))
-        continue;
-      AddOverloadCandidate(FD, DeclAccessPair::make(FD, C->getAccess()), Args,
-                           CandidateSet,
-                           /*SuppressUserConversions=*/false,
-                           /*PartialOverloading=*/true,
-                           /*AllowExplicit*/ true);
-    } else if (auto *FTD = dyn_cast<FunctionTemplateDecl>(C)) {
-      if (Braced && LangOpts.CPlusPlus &&
-          isInitListConstructor(FTD->getTemplatedDecl()))
-        continue;
+  if (CRD) {
+    OverloadCandidateSet CandidateSet(Loc, OverloadCandidateSet::CSK_Normal);
+    for (NamedDecl *C : LookupConstructors(CRD)) {
+      if (auto *FD = dyn_cast<FunctionDecl>(C)) {
+        // FIXME: we can't yet provide correct signature help for initializer
+        //        list constructors, so skip them entirely.
+        if (Braced && LangOpts.CPlusPlus && isInitListConstructor(FD))
+          continue;
+        AddOverloadCandidate(FD, DeclAccessPair::make(FD, C->getAccess()), Args,
+                             CandidateSet,
+                             /*SuppressUserConversions=*/false,
+                             /*PartialOverloading=*/true,
+                             /*AllowExplicit*/ true);
+      } else if (auto *FTD = dyn_cast<FunctionTemplateDecl>(C)) {
+        if (Braced && LangOpts.CPlusPlus &&
+            isInitListConstructor(FTD->getTemplatedDecl()))
+          continue;
 
-      AddTemplateOverloadCandidate(
-          FTD, DeclAccessPair::make(FTD, C->getAccess()),
-          /*ExplicitTemplateArgs=*/nullptr, Args, CandidateSet,
-          /*SuppressUserConversions=*/false,
-          /*PartialOverloading=*/true);
+        AddTemplateOverloadCandidate(
+            FTD, DeclAccessPair::make(FTD, C->getAccess()),
+            /*ExplicitTemplateArgs=*/nullptr, Args, CandidateSet,
+            /*SuppressUserConversions=*/false,
+            /*PartialOverloading=*/true);
+      }
     }
+    mergeCandidatesWithResults(*this, Results, CandidateSet, Loc, Args.size());
   }
 
-  SmallVector<ResultCandidate, 8> Results;
-  mergeCandidatesWithResults(*this, Results, CandidateSet, Loc, Args.size());
   return ProduceSignatureHelp(*this, S, Results, Args.size(), OpenParLoc,
                               Braced);
 }
Index: clang/lib/Sema/CodeCompleteConsumer.cpp
===================================================================
--- clang/lib/Sema/CodeCompleteConsumer.cpp
+++ clang/lib/Sema/CodeCompleteConsumer.cpp
@@ -506,11 +506,72 @@
 
   case CK_FunctionType:
     return Type;
+
+  case CK_Aggregate:
+    return nullptr;
   }
 
   llvm_unreachable("Invalid CandidateKind!");
 }
 
+unsigned CodeCompleteConsumer::OverloadCandidate::getNumParams() const {
+  if (Kind == CK_Aggregate) {
+    unsigned Count =
+        std::distance(AggregateType->field_begin(), AggregateType->field_end());
+    if (const auto *CRD = dyn_cast<CXXRecordDecl>(AggregateType))
+      Count += CRD->getNumBases();
+    return Count;
+  }
+
+  if (const auto *FT = getFunctionType())
+    if (const auto *FPT = dyn_cast<FunctionProtoType>(FT))
+      return FPT->getNumParams();
+
+  return 0;
+}
+
+QualType
+CodeCompleteConsumer::OverloadCandidate::getParamType(unsigned N) const {
+  if (Kind == CK_Aggregate) {
+    if (const auto *CRD = dyn_cast<CXXRecordDecl>(AggregateType)) {
+      if (N < CRD->getNumBases())
+        return std::next(CRD->bases_begin(), N)->getType();
+      N -= CRD->getNumBases();
+    }
+    for (const auto *Field : AggregateType->fields())
+      if (N-- == 0)
+        return Field->getType();
+    return QualType();
+  }
+
+  if (const auto *FT = getFunctionType())
+    if (const auto *FPT = dyn_cast<FunctionProtoType>(FT))
+      if (N < FPT->getNumParams())
+        return FPT->getParamType(N);
+  return QualType();
+}
+
+const NamedDecl *
+CodeCompleteConsumer::OverloadCandidate::getParamDecl(unsigned N) const {
+  if (Kind == CK_Aggregate) {
+    if (const auto *CRD = dyn_cast<CXXRecordDecl>(AggregateType)) {
+      if (N < CRD->getNumBases())
+        return std::next(CRD->bases_begin(), N)->getType()->getAsTagDecl();
+      N -= CRD->getNumBases();
+    }
+    for (const auto *Field : AggregateType->fields())
+      if (N-- == 0)
+        return Field;
+    return nullptr;
+  }
+
+  if (const auto *FD = getFunction()) {
+    if (N < FD->param_size())
+      return FD->getParamDecl(N);
+  }
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // Code completion consumer implementation
 //===----------------------------------------------------------------------===//
Index: clang/include/clang/Sema/CodeCompleteConsumer.h
===================================================================
--- clang/include/clang/Sema/CodeCompleteConsumer.h
+++ clang/include/clang/Sema/CodeCompleteConsumer.h
@@ -1014,7 +1014,10 @@
 
       /// The "candidate" is actually a variable, expression, or block
       /// for which we only have a function prototype.
-      CK_FunctionType
+      CK_FunctionType,
+
+      /// The candidate is aggregate initialization of a record type.
+      CK_Aggregate,
     };
 
   private:
@@ -1033,17 +1036,32 @@
       /// The function type that describes the entity being called,
       /// when Kind == CK_FunctionType.
       const FunctionType *Type;
+
+      /// The class being aggregate-initialized,
+      /// when Kind == CK_Aggregate.
+      const RecordDecl *AggregateType;
     };
 
   public:
     OverloadCandidate(FunctionDecl *Function)
-        : Kind(CK_Function), Function(Function) {}
+        : Kind(CK_Function), Function(Function) {
+      assert(Function != nullptr);
+    }
 
     OverloadCandidate(FunctionTemplateDecl *FunctionTemplateDecl)
-        : Kind(CK_FunctionTemplate), FunctionTemplate(FunctionTemplateDecl) {}
+        : Kind(CK_FunctionTemplate), FunctionTemplate(FunctionTemplateDecl) {
+      assert(FunctionTemplateDecl != nullptr);
+    }
 
     OverloadCandidate(const FunctionType *Type)
-        : Kind(CK_FunctionType), Type(Type) {}
+        : Kind(CK_FunctionType), Type(Type) {
+      assert(Type != nullptr);
+    }
+
+    OverloadCandidate(const RecordDecl *Aggregate)
+        : Kind(CK_Aggregate), AggregateType(Aggregate) {
+      assert(Aggregate != nullptr);
+    }
 
     /// Determine the kind of overload candidate.
     CandidateKind getKind() const { return Kind; }
@@ -1062,6 +1080,22 @@
     /// function is stored.
     const FunctionType *getFunctionType() const;
 
+    /// Retrieve the aggregate type being initialized.
+    const RecordDecl *getAggregate() const {
+      return getKind() == CK_Aggregate ? AggregateType : nullptr;
+    }
+
+    /// Get the number of parameters in this signature.
+    unsigned getNumParams() const;
+
+    /// Get the type of the Nth parameter.
+    /// Returns null if the type is unknown or N is out of range.
+    QualType getParamType(unsigned N) const;
+
+    /// Get the declaration of the Nth parameter.
+    /// Returns null if the decl is unknown or N is out of range.
+    const NamedDecl *getParamDecl(unsigned N) const;
+
     /// Create a new code-completion string that describes the function
     /// signature of this overload candidate.
     CodeCompletionString *
Index: clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
===================================================================
--- clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
+++ clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
@@ -1294,6 +1294,25 @@
   CheckBracedInit("int x(S); int i = x({^});");
 }
 
+TEST(SignatureHelpTest, Aggregates) {
+  std::string Top = R"cpp(
+    struct S {
+      int a, b, c, d;
+    };
+  )cpp";
+  auto AggregateSig = Sig("S{[[int a]], [[int b]], [[int c]], [[int d]]}");
+  EXPECT_THAT(signatures(Top + "S s{^}").signatures,
+              UnorderedElementsAre(AggregateSig, Sig("S{}"),
+                                   Sig("S{[[const S &]]}"),
+                                   Sig("S{[[S &&]]}")));
+  EXPECT_THAT(signatures(Top + "S s{1,^}").signatures,
+              ElementsAre(AggregateSig));
+  EXPECT_EQ(signatures(Top + "S s{1,^}").activeParameter, 1);
+  EXPECT_THAT(signatures(Top + "S s{.c=3,^}").signatures,
+              ElementsAre(AggregateSig));
+  EXPECT_EQ(signatures(Top + "S s{.c=3,^}").activeParameter, 3);
+}
+
 TEST(SignatureHelpTest, OverloadInitListRegression) {
   auto Results = signatures(R"cpp(
     struct A {int x;};
Index: clang-tools-extra/clangd/CodeComplete.cpp
===================================================================
--- clang-tools-extra/clangd/CodeComplete.cpp
+++ clang-tools-extra/clangd/CodeComplete.cpp
@@ -895,14 +895,9 @@
 // part of it.
 int paramIndexForArg(const CodeCompleteConsumer::OverloadCandidate &Candidate,
                      int Arg) {
-  int NumParams = 0;
-  if (const auto *F = Candidate.getFunction()) {
-    NumParams = F->getNumParams();
-    if (F->isVariadic())
-      ++NumParams;
-  } else if (auto *T = Candidate.getFunctionType()) {
+  int NumParams = Candidate.getNumParams();
+  if (auto *T = Candidate.getFunctionType()) {
     if (auto *Proto = T->getAs<FunctionProtoType>()) {
-      NumParams = Proto->getNumParams();
       if (Proto->isVariadic())
         ++NumParams;
     }
@@ -998,8 +993,7 @@
                                     const ScoredSignature &R) {
       // Ordering follows:
       // - Less number of parameters is better.
-      // - Function is better than FunctionType which is better than
-      // Function Template.
+      // - Aggregate > Function > FunctionType > FunctionTemplate
       // - High score is better.
       // - Shorter signature is better.
       // - Alphabetically smaller is better.
@@ -1011,15 +1005,20 @@
                R.Quality.NumberOfOptionalParameters;
       if (L.Quality.Kind != R.Quality.Kind) {
         using OC = CodeCompleteConsumer::OverloadCandidate;
-        switch (L.Quality.Kind) {
-        case OC::CK_Function:
-          return true;
-        case OC::CK_FunctionType:
-          return R.Quality.Kind != OC::CK_Function;
-        case OC::CK_FunctionTemplate:
-          return false;
-        }
-        llvm_unreachable("Unknown overload candidate type.");
+        auto KindPriority = [&](OC::CandidateKind K) {
+          switch (K) {
+          case OC::CK_Aggregate:
+            return 1;
+          case OC::CK_Function:
+            return 2;
+          case OC::CK_FunctionType:
+            return 3;
+          case OC::CK_FunctionTemplate:
+            return 4;
+          }
+          llvm_unreachable("Unknown overload candidate type.");
+        };
+        return KindPriority(L.Quality.Kind) < KindPriority(R.Quality.Kind);
       }
       if (L.Signature.label.size() != R.Signature.label.size())
         return L.Signature.label.size() < R.Signature.label.size();
@@ -1170,19 +1169,9 @@
            "too many arguments");
 
     for (unsigned I = 0; I < NumCandidates; ++I) {
-      OverloadCandidate Candidate = Candidates[I];
-      auto *Func = Candidate.getFunction();
-      if (!Func || Func->getNumParams() <= CurrentArg)
-        continue;
-      auto *PVD = Func->getParamDecl(CurrentArg);
-      if (!PVD)
-        continue;
-      auto *Ident = PVD->getIdentifier();
-      if (!Ident)
-        continue;
-      auto Name = Ident->getName();
-      if (!Name.empty())
-        ParamNames.insert(Name.str());
+      if (const NamedDecl *ND = Candidates[I].getParamDecl(CurrentArg))
+        if (const auto *II = ND->getIdentifier())
+          ParamNames.emplace(II->getName());
     }
   }
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to