steveire created this revision.
steveire added a reviewer: njames93.
Herald added a subscriber: mgrang.
steveire requested review of this revision.
Herald added a project: clang.
Herald added a subscriber: cfe-commits.

This would be less complicated if ReturnTypeRequirement were not a class
nested inside ExprRequirement; un-nesting it should probably be done in a

Implement LLVM-style RTTI for the ConceptReference -> TypeConstraint hierarchy.
The generated NodeIntrospection code relies on being able to dyn_cast between
these classes in order to provide locations.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D101589

Files:
  clang/include/clang/AST/ASTConcept.h
  clang/include/clang/Tooling/NodeIntrospection.h
  clang/lib/Tooling/DumpTool/APIData.h
  clang/lib/Tooling/DumpTool/ASTSrcLocProcessor.cpp
  clang/lib/Tooling/DumpTool/ASTSrcLocProcessor.h
  clang/lib/Tooling/DumpTool/generate_cxx_src_locs.py
  clang/lib/Tooling/EmptyNodeIntrospection.inc.in
  clang/unittests/Introspection/IntrospectionTest.cpp

Index: clang/unittests/Introspection/IntrospectionTest.cpp
===================================================================
--- clang/unittests/Introspection/IntrospectionTest.cpp
+++ clang/unittests/Introspection/IntrospectionTest.cpp
@@ -1604,3 +1604,346 @@
   EXPECT_THAT(ExpectedRanges, UnorderedElementsAre(STRING_LOCATION_PAIR(
                                   (&NI), getSourceRange())));
 }
+
+TEST(Introspection, SourceLocations_TypeConstraint) {
+  if (!NodeIntrospection::hasIntrospectionSupport())
+    return;
+  auto AST =
+      buildASTFromCodeWithArgs(R"cpp(
+namespace ns {
+template <typename T, typename U>
+concept binary_concept = true;
+}
+template <ns::template binary_concept<int> T>
+void bar(T);
+)cpp",
+                               {"-std=c++20"}, "foo.cpp", "clang-tool",
+                               std::make_shared<PCHContainerOperations>());
+  auto &Ctx = AST->getASTContext();
+  auto &TU = *Ctx.getTranslationUnitDecl();
+
+  auto BoundNodes = ast_matchers::match(
+      decl(hasDescendant(functionTemplateDecl(
+          hasName("bar"), has(templateTypeParmDecl().bind("tparm"))))),
+      TU, Ctx);
+  // ASSERT, not EXPECT: results below are dereferenced unconditionally.
+  ASSERT_EQ(BoundNodes.size(), 1u);
+
+  const auto *Tparm = BoundNodes[0].getNodeAs<TemplateTypeParmDecl>("tparm");
+
+  auto Result = NodeIntrospection::GetLocations(Tparm);
+
+  auto ExpectedLocations =
+      FormatExpected<SourceLocation>(Result.LocationAccessors);
+
+  llvm::sort(ExpectedLocations);
+
+  // clang-format off
+  EXPECT_EQ(
+      llvm::makeArrayRef(ExpectedLocations),
+      (ArrayRef<std::pair<std::string, SourceLocation>>{
+          STRING_LOCATION_STDPAIR(Tparm, getBeginLoc()),
+          STRING_LOCATION_STDPAIR(Tparm, getEndLoc()),
+          STRING_LOCATION_STDPAIR(Tparm, getLocation()),
+          STRING_LOCATION_STDPAIR(
+              Tparm, getTypeConstraint()->getConceptNameInfo().getBeginLoc()),
+          STRING_LOCATION_STDPAIR(
+              Tparm, getTypeConstraint()->getConceptNameInfo().getEndLoc()),
+          STRING_LOCATION_STDPAIR(
+              Tparm, getTypeConstraint()->getConceptNameInfo().getLoc()),
+          STRING_LOCATION_STDPAIR(Tparm,
+                                  getTypeConstraint()->getConceptNameLoc()),
+          STRING_LOCATION_STDPAIR(
+              Tparm,
+              getTypeConstraint()->getNestedNameSpecifierLoc().getBeginLoc()),
+          STRING_LOCATION_STDPAIR(
+              Tparm,
+              getTypeConstraint()->getNestedNameSpecifierLoc().getEndLoc()),
+          STRING_LOCATION_STDPAIR(Tparm,
+            getTypeConstraint()->getNestedNameSpecifierLoc().getLocalBeginLoc()),
+          STRING_LOCATION_STDPAIR(Tparm,
+            getTypeConstraint()->getNestedNameSpecifierLoc().getLocalEndLoc())}));
+  // clang-format on
+
+  auto ExpectedRanges = FormatExpected<SourceRange>(Result.RangeAccessors);
+
+  // clang-format off
+  EXPECT_THAT(
+      ExpectedRanges,
+      UnorderedElementsAre(
+STRING_LOCATION_PAIR(Tparm,
+    getTypeConstraint()->getConceptNameInfo().getSourceRange()),
+STRING_LOCATION_PAIR(Tparm, getSourceRange()),
+STRING_LOCATION_PAIR(Tparm,
+  getTypeConstraint()->getNestedNameSpecifierLoc().getLocalSourceRange()),
+STRING_LOCATION_PAIR(Tparm,
+  getTypeConstraint()->getNestedNameSpecifierLoc().getSourceRange())
+  ));
+  // clang-format on
+}
+
+TEST(Introspection, SourceLocations_ConstrainedAuto) {
+  if (!NodeIntrospection::hasIntrospectionSupport())
+    return;
+  auto AST =
+      buildASTFromCodeWithArgs(R"cpp(
+namespace ns {
+template <typename T, typename U>
+concept binary_concept = true;
+}
+void bar()
+{
+  ns::template binary_concept<int> auto vd = 4;
+}
+)cpp",
+                               {"-std=c++20"}, "foo.cpp", "clang-tool",
+                               std::make_shared<PCHContainerOperations>());
+  auto &Ctx = AST->getASTContext();
+  auto &TU = *Ctx.getTranslationUnitDecl();
+
+  auto BoundNodes = ast_matchers::match(
+      decl(hasDescendant(varDecl(hasName("vd")).bind("vd"))), TU, Ctx);
+  // ASSERT, not EXPECT: results below are dereferenced unconditionally.
+  ASSERT_EQ(BoundNodes.size(), 1u);
+
+  const auto *VD = BoundNodes[0].getNodeAs<VarDecl>("vd");
+
+  auto Result = NodeIntrospection::GetLocations(VD);
+
+  auto ExpectedLocations =
+      FormatExpected<SourceLocation>(Result.LocationAccessors);
+
+  llvm::sort(ExpectedLocations);
+
+  // clang-format off
+  EXPECT_EQ(
+      llvm::makeArrayRef(ExpectedLocations),
+      (ArrayRef<std::pair<std::string, SourceLocation>>{
+STRING_LOCATION_STDPAIR(VD, getBeginLoc()),
+STRING_LOCATION_STDPAIR(VD, getEndLoc()),
+STRING_LOCATION_STDPAIR(VD, getInnerLocStart()),
+STRING_LOCATION_STDPAIR(VD, getLocation()),
+STRING_LOCATION_STDPAIR(VD, getOuterLocStart()),
+STRING_LOCATION_STDPAIR(VD,
+  getTypeSourceInfo()->getTypeLoc().getAs<clang::AutoTypeLoc>().getConceptNameInfo().getBeginLoc()),
+STRING_LOCATION_STDPAIR(VD,
+  getTypeSourceInfo()->getTypeLoc().getAs<clang::AutoTypeLoc>().getConceptNameInfo().getEndLoc()),
+STRING_LOCATION_STDPAIR(VD,
+  getTypeSourceInfo()->getTypeLoc().getAs<clang::AutoTypeLoc>().getConceptNameInfo().getLoc()),
+STRING_LOCATION_STDPAIR(VD,
+  getTypeSourceInfo()->getTypeLoc().getAs<clang::AutoTypeLoc>().getConceptNameLoc()),
+STRING_LOCATION_STDPAIR(VD,
+  getTypeSourceInfo()->getTypeLoc().getAs<clang::AutoTypeLoc>().getLAngleLoc()),
+STRING_LOCATION_STDPAIR(VD,
+  getTypeSourceInfo()->getTypeLoc().getAs<clang::AutoTypeLoc>().getNestedNameSpecifierLoc().getBeginLoc()),
+STRING_LOCATION_STDPAIR(VD,
+  getTypeSourceInfo()->getTypeLoc().getAs<clang::AutoTypeLoc>().getNestedNameSpecifierLoc().getEndLoc()),
+STRING_LOCATION_STDPAIR(VD,
+  getTypeSourceInfo()->getTypeLoc().getAs<clang::AutoTypeLoc>().getNestedNameSpecifierLoc().getLocalBeginLoc()),
+STRING_LOCATION_STDPAIR(VD,
+  getTypeSourceInfo()->getTypeLoc().getAs<clang::AutoTypeLoc>().getNestedNameSpecifierLoc().getLocalEndLoc()),
+STRING_LOCATION_STDPAIR(VD,
+  getTypeSourceInfo()->getTypeLoc().getAs<clang::AutoTypeLoc>().getRAngleLoc()),
+STRING_LOCATION_STDPAIR(VD,
+  getTypeSourceInfo()->getTypeLoc().getAs<clang::AutoTypeLoc>().getTemplateKWLoc()),
+STRING_LOCATION_STDPAIR(VD,
+  getTypeSourceInfo()->getTypeLoc().getAs<clang::TypeSpecTypeLoc>().getNameLoc()),
+STRING_LOCATION_STDPAIR(VD, getTypeSourceInfo()->getTypeLoc().getBeginLoc()),
+STRING_LOCATION_STDPAIR(VD, getTypeSourceInfo()->getTypeLoc().getEndLoc()),
+STRING_LOCATION_STDPAIR(VD, getTypeSpecEndLoc()),
+STRING_LOCATION_STDPAIR(VD, getTypeSpecStartLoc())
+        }));
+  // clang-format on
+
+  auto ExpectedRanges = FormatExpected<SourceRange>(Result.RangeAccessors);
+
+  // clang-format off
+  EXPECT_THAT(
+      ExpectedRanges,
+      UnorderedElementsAre(
+STRING_LOCATION_PAIR(VD,
+  getTypeSourceInfo()->getTypeLoc().getAs<clang::AutoTypeLoc>().getNestedNameSpecifierLoc().getLocalSourceRange()),
+STRING_LOCATION_PAIR(VD,
+  getTypeSourceInfo()->getTypeLoc().getAs<clang::AutoTypeLoc>().getNestedNameSpecifierLoc().getSourceRange()),
+STRING_LOCATION_PAIR(VD,
+  getTypeSourceInfo()->getTypeLoc().getLocalSourceRange()),
+STRING_LOCATION_PAIR(VD,
+  getTypeSourceInfo()->getTypeLoc().getSourceRange()),
+STRING_LOCATION_PAIR(VD, getSourceRange()),
+STRING_LOCATION_PAIR(VD,
+  getTypeSourceInfo()->getTypeLoc().getAs<clang::AutoTypeLoc>().getConceptNameInfo().getSourceRange())
+              ));
+  // clang-format on
+}
+
+TEST(Introspection, SourceLocations_TypeRequirement) {
+  if (!NodeIntrospection::hasIntrospectionSupport())
+    return;
+  auto AST =
+      buildASTFromCodeWithArgs(R"cpp(
+template<typename T> concept C =
+requires {
+    typename T::template inner<int>;
+};
+)cpp",
+                               {"-std=c++20"}, "foo.cpp", "clang-tool",
+                               std::make_shared<PCHContainerOperations>());
+  auto &Ctx = AST->getASTContext();
+  auto &TU = *Ctx.getTranslationUnitDecl();
+
+  auto BoundNodes =
+      ast_matchers::match(decl(hasDescendant(namedDecl(
+                              hasName("C"), has(expr().bind("requiresExpr"))))),
+                          TU, Ctx);
+  // ASSERT, not EXPECT: results below are dereferenced unconditionally.
+  ASSERT_EQ(BoundNodes.size(), 1u);
+
+  const auto *RE = BoundNodes[0].getNodeAs<RequiresExpr>("requiresExpr");
+  ASSERT_NE(RE, nullptr);
+  auto Req = RE->getRequirements()[0];
+
+  auto Result = NodeIntrospection::GetLocations(Req);
+
+  auto TReq = dyn_cast<concepts::TypeRequirement>(Req);
+  ASSERT_NE(TReq, nullptr);
+  auto ExpectedLocations =
+      FormatExpected<SourceLocation>(Result.LocationAccessors);
+
+  llvm::sort(ExpectedLocations);
+
+  // clang-format off
+  EXPECT_EQ(
+      llvm::makeArrayRef(ExpectedLocations),
+      (ArrayRef<std::pair<std::string, SourceLocation>>{
+STRING_LOCATION_STDPAIR(TReq,
+  getType()->getTypeLoc().getAs<clang::DependentTemplateSpecializationTypeLoc>().getElaboratedKeywordLoc()),
+STRING_LOCATION_STDPAIR(TReq,
+  getType()->getTypeLoc().getAs<clang::DependentTemplateSpecializationTypeLoc>().getLAngleLoc()),
+STRING_LOCATION_STDPAIR(TReq,
+  getType()->getTypeLoc().getAs<clang::DependentTemplateSpecializationTypeLoc>().getQualifierLoc().getBeginLoc()),
+STRING_LOCATION_STDPAIR(TReq,
+  getType()->getTypeLoc().getAs<clang::DependentTemplateSpecializationTypeLoc>().getQualifierLoc().getEndLoc()),
+STRING_LOCATION_STDPAIR(TReq,
+  getType()->getTypeLoc().getAs<clang::DependentTemplateSpecializationTypeLoc>().getQualifierLoc().getLocalBeginLoc()),
+STRING_LOCATION_STDPAIR(TReq,
+  getType()->getTypeLoc().getAs<clang::DependentTemplateSpecializationTypeLoc>().getQualifierLoc().getLocalEndLoc()),
+STRING_LOCATION_STDPAIR(TReq,
+  getType()->getTypeLoc().getAs<clang::DependentTemplateSpecializationTypeLoc>().getQualifierLoc().getTypeLoc().getAs<clang::TypeSpecTypeLoc>().getNameLoc()),
+STRING_LOCATION_STDPAIR(TReq,
+  getType()->getTypeLoc().getAs<clang::DependentTemplateSpecializationTypeLoc>().getQualifierLoc().getTypeLoc().getBeginLoc()),
+STRING_LOCATION_STDPAIR(TReq,
+  getType()->getTypeLoc().getAs<clang::DependentTemplateSpecializationTypeLoc>().getQualifierLoc().getTypeLoc().getEndLoc()),
+STRING_LOCATION_STDPAIR(TReq,
+  getType()->getTypeLoc().getAs<clang::DependentTemplateSpecializationTypeLoc>().getRAngleLoc()),
+STRING_LOCATION_STDPAIR(TReq,
+  getType()->getTypeLoc().getAs<clang::DependentTemplateSpecializationTypeLoc>().getTemplateKeywordLoc()),
+STRING_LOCATION_STDPAIR(TReq,
+  getType()->getTypeLoc().getAs<clang::DependentTemplateSpecializationTypeLoc>().getTemplateNameLoc()),
+STRING_LOCATION_STDPAIR(TReq,
+  getType()->getTypeLoc().getBeginLoc()),
+STRING_LOCATION_STDPAIR(TReq,
+  getType()->getTypeLoc().getEndLoc())
+        }));
+  // clang-format on
+
+  auto ExpectedRanges = FormatExpected<SourceRange>(Result.RangeAccessors);
+
+  // clang-format off
+  EXPECT_THAT(
+      ExpectedRanges,
+      UnorderedElementsAre(
+STRING_LOCATION_PAIR(TReq, getType()->getTypeLoc().getSourceRange()),
+STRING_LOCATION_PAIR(TReq, getType()->getTypeLoc().getLocalSourceRange()),
+STRING_LOCATION_PAIR(TReq,
+  getType()->getTypeLoc().getAs<clang::DependentTemplateSpecializationTypeLoc>().getQualifierLoc().getTypeLoc().getLocalSourceRange()),
+STRING_LOCATION_PAIR(TReq,
+  getType()->getTypeLoc().getAs<clang::DependentTemplateSpecializationTypeLoc>().getQualifierLoc().getTypeLoc().getSourceRange()),
+STRING_LOCATION_PAIR(TReq,
+  getType()->getTypeLoc().getAs<clang::DependentTemplateSpecializationTypeLoc>().getQualifierLoc().getLocalSourceRange()),
+STRING_LOCATION_PAIR(TReq,
+  getType()->getTypeLoc().getAs<clang::DependentTemplateSpecializationTypeLoc>().getQualifierLoc().getSourceRange())
+    ));
+  // clang-format on
+}
+
+TEST(Introspection, SourceLocations_ExprRequirement) {
+  if (!NodeIntrospection::hasIntrospectionSupport())
+    return;
+  auto AST =
+      buildASTFromCodeWithArgs(R"cpp(
+namespace ns {
+template <typename T, typename U>
+concept binary_concept = true;
+template <typename T>
+concept unary_concept = true;
+}
+
+template<typename T>
+concept CastableToString = requires(T a) {
+  { a } noexcept -> ns::binary_concept<int>;
+};
+)cpp",
+                               {"-std=c++20"}, "foo.cpp", "clang-tool",
+                               std::make_shared<PCHContainerOperations>());
+  auto &Ctx = AST->getASTContext();
+  auto &TU = *Ctx.getTranslationUnitDecl();
+
+  auto BoundNodes = ast_matchers::match(
+      decl(hasDescendant(namedDecl(hasName("CastableToString"),
+                                   has(expr().bind("requiresExpr"))))),
+      TU, Ctx);
+  // ASSERT, not EXPECT: results below are dereferenced unconditionally.
+  ASSERT_EQ(BoundNodes.size(), 1u);
+
+  const auto *RE = BoundNodes[0].getNodeAs<RequiresExpr>("requiresExpr");
+  ASSERT_NE(RE, nullptr);
+  auto Req = RE->getRequirements()[0];
+
+  auto Result = NodeIntrospection::GetLocations(Req);
+
+  auto ExpectedLocations =
+      FormatExpected<SourceLocation>(Result.LocationAccessors);
+
+  llvm::sort(ExpectedLocations);
+
+  auto ER = dyn_cast<concepts::ExprRequirement>(Req);
+  ASSERT_NE(ER, nullptr);
+  // clang-format off
+  EXPECT_EQ(
+      llvm::makeArrayRef(ExpectedLocations),
+      (ArrayRef<std::pair<std::string, SourceLocation>>{
+STRING_LOCATION_STDPAIR(ER, getNoexceptLoc()),
+STRING_LOCATION_STDPAIR(ER,
+  getReturnTypeRequirement().getTypeConstraint()->getConceptNameInfo().getBeginLoc()),
+STRING_LOCATION_STDPAIR(ER,
+  getReturnTypeRequirement().getTypeConstraint()->getConceptNameInfo().getEndLoc()),
+STRING_LOCATION_STDPAIR(ER,
+  getReturnTypeRequirement().getTypeConstraint()->getConceptNameInfo().getLoc()),
+STRING_LOCATION_STDPAIR(ER,
+  getReturnTypeRequirement().getTypeConstraint()->getConceptNameLoc()),
+STRING_LOCATION_STDPAIR(ER,
+  getReturnTypeRequirement().getTypeConstraint()->getNestedNameSpecifierLoc().getBeginLoc()),
+STRING_LOCATION_STDPAIR(ER,
+  getReturnTypeRequirement().getTypeConstraint()->getNestedNameSpecifierLoc().getEndLoc()),
+STRING_LOCATION_STDPAIR(ER,
+  getReturnTypeRequirement().getTypeConstraint()->getNestedNameSpecifierLoc().getLocalBeginLoc()),
+STRING_LOCATION_STDPAIR(ER,
+  getReturnTypeRequirement().getTypeConstraint()->getNestedNameSpecifierLoc().getLocalEndLoc())
+  }));
+  // clang-format on
+
+  auto ExpectedRanges = FormatExpected<SourceRange>(Result.RangeAccessors);
+
+  // clang-format off
+  EXPECT_THAT(
+      ExpectedRanges,
+      UnorderedElementsAre(
+STRING_LOCATION_PAIR(ER,
+  getReturnTypeRequirement().getTypeConstraint()->getNestedNameSpecifierLoc().getSourceRange()),
+STRING_LOCATION_PAIR(ER,
+  getReturnTypeRequirement().getTypeConstraint()->getNestedNameSpecifierLoc().getLocalSourceRange()),
+STRING_LOCATION_PAIR(ER,
+  getReturnTypeRequirement().getTypeConstraint()->getConceptNameInfo().getSourceRange())
+          ));
+  // clang-format on
+}
Index: clang/lib/Tooling/EmptyNodeIntrospection.inc.in
===================================================================
--- clang/lib/Tooling/EmptyNodeIntrospection.inc.in
+++ clang/lib/Tooling/EmptyNodeIntrospection.inc.in
@@ -40,6 +40,18 @@
     clang::DeclarationNameInfo const&) {
   return {};
 }
+NodeLocationAccessors NodeIntrospection::GetLocations(
+    clang::ConceptReference const *) {
+  return {};
+}
+NodeLocationAccessors NodeIntrospection::GetLocations(
+    clang::concepts::Requirement const *) {
+  return {};
+}
+NodeLocationAccessors NodeIntrospection::GetLocations(
+    clang::concepts::ExprRequirement::ReturnTypeRequirement const &) {
+  return {};
+}
 NodeLocationAccessors
 NodeIntrospection::GetLocations(clang::DynTypedNode const &) {
   return {};
Index: clang/lib/Tooling/DumpTool/generate_cxx_src_locs.py
===================================================================
--- clang/lib/Tooling/DumpTool/generate_cxx_src_locs.py
+++ clang/lib/Tooling/DumpTool/generate_cxx_src_locs.py
@@ -15,7 +15,9 @@
     RefClades = {"DeclarationNameInfo",
         "NestedNameSpecifierLoc",
         "TemplateArgumentLoc",
-        "TypeLoc"}
+        "TypeLoc",
+        "concepts::ExprRequirement::ReturnTypeRequirement"
+        }
 
     def __init__(self, templateClasses):
         self.templateClasses = templateClasses
@@ -74,7 +76,9 @@
     def GenerateSrcLocMethod(self,
             ClassName, ClassData, CreateLocalRecursionGuard):
 
-        NormalClassName = ClassName
+        NormalClassName = \
+            ClassName.replace("concepts::", "").replace("ExprRequirement::", "")
+
         RecursionGuardParam = ('' if CreateLocalRecursionGuard else \
             ', std::vector<clang::TypeLoc>& TypeLocRecursionGuard')
 
@@ -125,6 +129,8 @@
 
         if 'typeLocs' in ClassData or 'typeSourceInfos' in ClassData \
                 or 'nestedNameLocs' in ClassData \
+                or 'conceptReferences' in ClassData \
+                or 'returnTypeRequirements' in ClassData \
                 or 'declNameInfos' in ClassData:
             if CreateLocalRecursionGuard:
                 self.implementationContent += \
@@ -169,6 +175,26 @@
                     Object.{0}(), Locs, Rngs, TypeLocRecursionGuard);
               """.format(NN)
 
+            if 'conceptReferences' in ClassData:
+                for TC in ClassData['conceptReferences']:
+                    self.implementationContent += \
+                        """
+              if (Object.{0}())
+                GetLocationsImpl(
+                    llvm::makeIntrusiveRefCnt<LocationCall>(Prefix, "{0}",
+                        LocationCall::ReturnsPointer),
+                    Object.{0}(), Locs, Rngs, TypeLocRecursionGuard);
+              """.format(TC)
+
+            if 'returnTypeRequirements' in ClassData:
+                for TC in ClassData['returnTypeRequirements']:
+                    self.implementationContent += \
+                        """
+                GetLocationsImpl(
+                    llvm::makeIntrusiveRefCnt<LocationCall>(Prefix, "{0}"),
+                    Object.{0}(), Locs, Rngs, TypeLocRecursionGuard);
+              """.format(TC)
+
             if 'declNameInfos' in ClassData:
                 for declName in ClassData['declNameInfos']:
 
@@ -228,7 +254,8 @@
             ArgPrefix = ''
         self.implementationContent += \
             'GetLocations{0}(Prefix, {1}Object, Locs, Rngs {2});'.format(
-                CladeName, ArgPrefix, RecursionGuardParam)
+                CladeName.replace("concepts::", "").replace("ExprRequirement::", ""),
+                ArgPrefix, RecursionGuardParam)
 
         if CladeName == "TypeLoc":
             self.implementationContent += \
@@ -251,9 +278,11 @@
                 self.implementationContent += \
                 """
 if (auto Derived = llvm::dyn_cast<clang::{0}>(Object)) {{
-  GetLocations{0}(Prefix, *Derived, Locs, Rngs {1});
+  GetLocations{1}(Prefix, *Derived, Locs, Rngs {2});
 }}
-""".format(ASTClassName, RecursionGuardParam)
+""".format(ASTClassName,
+    ASTClassName.replace("concepts::", "").replace("ExprRequirement::", ""),
+    RecursionGuardParam)
                 continue
 
             self.GenerateBaseTypeLocVisit(ASTClassName, ClassEntries,
@@ -315,6 +344,12 @@
         for CladeName in CladeNames:
             if CladeName == "DeclarationNameInfo":
                 continue
+            if CladeName == "ConceptReference":
+                continue
+            if CladeName == "concepts::Requirement":
+                continue
+            if CladeName == "concepts::ExprRequirement::ReturnTypeRequirement":
+                continue
             self.implementationContent += \
                 """
     if (const auto *N = Node.get<{0}>())
Index: clang/lib/Tooling/DumpTool/ASTSrcLocProcessor.h
===================================================================
--- clang/lib/Tooling/DumpTool/ASTSrcLocProcessor.h
+++ clang/lib/Tooling/DumpTool/ASTSrcLocProcessor.h
@@ -40,7 +40,7 @@
   }
 
   llvm::StringMap<std::string> ClassInheritance;
-  llvm::StringMap<std::vector<StringRef>> ClassesInClade;
+  llvm::StringMap<std::vector<std::string>> ClassesInClade;
   llvm::StringMap<ClassData> ClassEntries;
 
   std::string JsonPath;
Index: clang/lib/Tooling/DumpTool/ASTSrcLocProcessor.cpp
===================================================================
--- clang/lib/Tooling/DumpTool/ASTSrcLocProcessor.cpp
+++ clang/lib/Tooling/DumpTool/ASTSrcLocProcessor.cpp
@@ -31,7 +31,10 @@
                       "clang::Stmt", "clang::Decl", "clang::CXXCtorInitializer",
                       "clang::NestedNameSpecifierLoc",
                       "clang::TemplateArgumentLoc", "clang::CXXBaseSpecifier",
-                      "clang::DeclarationNameInfo", "clang::TypeLoc"))
+                      "clang::DeclarationNameInfo", "clang::TypeLoc",
+                      "clang::ConceptReference", "clang::concepts::Requirement",
+                      "clang::concepts::ExprRequirement::"
+                      "ReturnTypeRequirement"))
                   .bind("nodeClade")),
           optionally(isDerivedFrom(cxxRecordDecl().bind("derivedFrom"))))
           .bind("className"),
@@ -49,7 +52,8 @@
   return Finder->newASTConsumer();
 }
 
-llvm::json::Object toJSON(llvm::StringMap<std::vector<StringRef>> const &Obj) {
+llvm::json::Object
+toJSON(llvm::StringMap<std::vector<std::string>> const &Obj) {
   using llvm::json::toJSON;
 
   llvm::json::Object JsonObj;
@@ -86,6 +90,10 @@
     JsonObj["nestedNameLocs"] = Obj.NestedNameLocs;
   if (!Obj.DeclNameInfos.empty())
     JsonObj["declNameInfos"] = Obj.DeclNameInfos;
+  if (!Obj.ConceptReferences.empty())
+    JsonObj["conceptReferences"] = Obj.ConceptReferences;
+  if (!Obj.ReturnTypeRequirements.empty())
+    JsonObj["returnTypeRequirements"] = Obj.ReturnTypeRequirements;
   return JsonObj;
 }
 
@@ -133,6 +141,25 @@
 }
 
 void ASTSrcLocProcessor::generate() {
+
+  // ConceptSpecializationExpr multiply-inherits from Expr and
+  // ConceptReference. Rather than add generic multiple-inheritance support,
+  // special-case it: append (not overwrite) the ConceptReference accessors
+  // so that accessors declared by ConceptSpecializationExpr itself survive.
+  const ClassData CR = ClassEntries["ConceptReference"];
+  ClassData &CSE = ClassEntries["ConceptSpecializationExpr"];
+  auto Append = [](std::vector<std::string> &To,
+                   const std::vector<std::string> &From) {
+    To.insert(To.end(), From.begin(), From.end());
+  };
+  Append(CSE.ASTClassLocations, CR.ASTClassLocations);
+  Append(CSE.ASTClassRanges, CR.ASTClassRanges);
+  Append(CSE.TemplateParms, CR.TemplateParms);
+  Append(CSE.TypeSourceInfos, CR.TypeSourceInfos);
+  Append(CSE.TypeLocs, CR.TypeLocs);
+  Append(CSE.NestedNameLocs, CR.NestedNameLocs);
+  Append(CSE.DeclNameInfos, CR.DeclNameInfos);
+
   WriteJSON(JsonPath, ::toJSON(ClassInheritance), ::toJSON(ClassesInClade),
             ::toJSON(ClassEntries));
 }
@@ -200,17 +227,34 @@
   const auto *ASTClass =
       Result.Nodes.getNodeAs<clang::CXXRecordDecl>("className");
 
-  StringRef CladeName;
+  bool IsConceptRequirementClade = false;
+  bool IsReturnTypeRequirementClade = false;
+  std::string CladeName;
   if (ASTClass) {
     if (const auto *NodeClade =
             Result.Nodes.getNodeAs<clang::CXXRecordDecl>("nodeClade"))
-      CladeName = NodeClade->getName();
+      CladeName = NodeClade->getName().str();
+    IsConceptRequirementClade = CladeName == "Requirement";
+    if (IsConceptRequirementClade) {
+      CladeName = "concepts::" + CladeName;
+    }
+    IsReturnTypeRequirementClade = CladeName == "ReturnTypeRequirement";
+    if (IsReturnTypeRequirementClade) {
+      CladeName = "concepts::ExprRequirement::" + CladeName;
+    }
   } else {
     ASTClass = Result.Nodes.getNodeAs<clang::CXXRecordDecl>("templateName");
     CladeName = "TypeLoc";
   }
 
-  StringRef ClassName = ASTClass->getName();
+  std::string ClassName = ASTClass->getName().str();
+
+  if (IsConceptRequirementClade) {
+    ClassName = "concepts::" + ClassName;
+  }
+  if (IsReturnTypeRequirementClade) {
+    ClassName = "concepts::ExprRequirement::" + ClassName;
+  }
 
   ClassData CD;
 
@@ -223,12 +267,22 @@
   CD.TypeLocs = CaptureMethods("class clang::TypeLoc", ASTClass, Result);
   CD.NestedNameLocs =
       CaptureMethods("class clang::NestedNameSpecifierLoc", ASTClass, Result);
+  auto NN = CaptureMethods("const class clang::NestedNameSpecifierLoc &",
+                           ASTClass, Result);
+  CD.NestedNameLocs.insert(CD.NestedNameLocs.end(), NN.begin(), NN.end());
   CD.DeclNameInfos =
       CaptureMethods("struct clang::DeclarationNameInfo", ASTClass, Result);
   auto DI = CaptureMethods("const struct clang::DeclarationNameInfo &",
                            ASTClass, Result);
   CD.DeclNameInfos.insert(CD.DeclNameInfos.end(), DI.begin(), DI.end());
 
+  CD.ConceptReferences =
+      CaptureMethods("const class clang::TypeConstraint *", ASTClass, Result);
+
+  CD.ReturnTypeRequirements = CaptureMethods(
+      "const class clang::concepts::ExprRequirement::ReturnTypeRequirement &",
+      ASTClass, Result);
+
   if (const auto *DerivedFrom =
           Result.Nodes.getNodeAs<clang::CXXRecordDecl>("derivedFrom")) {
 
@@ -254,7 +308,13 @@
 
       ClassInheritance[ClassName] = TArgsString.str().str();
     } else {
-      ClassInheritance[ClassName] = DerivedFrom->getName().str();
+      auto DerivedName = DerivedFrom->getName().str();
+
+      if (IsConceptRequirementClade) {
+        DerivedName = "concepts::" + DerivedName;
+      }
+
+      ClassInheritance[ClassName] = DerivedName;
     }
   }
 
Index: clang/lib/Tooling/DumpTool/APIData.h
===================================================================
--- clang/lib/Tooling/DumpTool/APIData.h
+++ clang/lib/Tooling/DumpTool/APIData.h
@@ -23,6 +23,8 @@
   std::vector<std::string> TypeLocs;
   std::vector<std::string> NestedNameLocs;
   std::vector<std::string> DeclNameInfos;
+  std::vector<std::string> ConceptReferences;
+  std::vector<std::string> ReturnTypeRequirements;
 };
 
 } // namespace tooling
Index: clang/include/clang/Tooling/NodeIntrospection.h
===================================================================
--- clang/include/clang/Tooling/NodeIntrospection.h
+++ clang/include/clang/Tooling/NodeIntrospection.h
@@ -15,6 +15,7 @@
 
 #include "clang/AST/ASTTypeTraits.h"
 #include "clang/AST/DeclarationName.h"
+#include "clang/AST/ExprConcepts.h"
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
 #include <set>
 
@@ -27,6 +28,10 @@
 class TemplateArgumentLoc;
 class CXXBaseSpecifier;
 struct DeclarationNameInfo;
+class ConceptReference;
+namespace concepts {
+class Requirement;
+}
 
 namespace tooling {
 
@@ -94,6 +99,10 @@
 NodeLocationAccessors GetLocations(clang::CXXBaseSpecifier const *);
 NodeLocationAccessors GetLocations(clang::TypeLoc const &);
 NodeLocationAccessors GetLocations(clang::DeclarationNameInfo const &);
+NodeLocationAccessors GetLocations(clang::ConceptReference const *);
+NodeLocationAccessors GetLocations(clang::concepts::Requirement const *);
+NodeLocationAccessors
+GetLocations(clang::concepts::ExprRequirement::ReturnTypeRequirement const &);
 NodeLocationAccessors GetLocations(clang::DynTypedNode const &Node);
 } // namespace NodeIntrospection
 } // namespace tooling
Index: clang/include/clang/AST/ASTConcept.h
===================================================================
--- clang/include/clang/AST/ASTConcept.h
+++ clang/include/clang/AST/ASTConcept.h
@@ -123,17 +123,35 @@
   const ASTTemplateArgumentListInfo *ArgsAsWritten;
 
 public:
+  enum ConceptReferenceKind { CK_ConceptReference, CK_TypeConstraint };
+
+private:
+  const ConceptReferenceKind Kind;
+
+public:
+  ConceptReference(ConceptReferenceKind Kind, NestedNameSpecifierLoc NNS,
+                   SourceLocation TemplateKWLoc,
+                   DeclarationNameInfo ConceptNameInfo, NamedDecl *FoundDecl,
+                   ConceptDecl *NamedConcept,
+                   const ASTTemplateArgumentListInfo *ArgsAsWritten)
+      : NestedNameSpec(NNS), TemplateKWLoc(TemplateKWLoc),
+        ConceptName(ConceptNameInfo), FoundDecl(FoundDecl),
+        NamedConcept(NamedConcept), ArgsAsWritten(ArgsAsWritten), Kind(Kind) {}
 
   ConceptReference(NestedNameSpecifierLoc NNS, SourceLocation TemplateKWLoc,
                    DeclarationNameInfo ConceptNameInfo, NamedDecl *FoundDecl,
                    ConceptDecl *NamedConcept,
-                   const ASTTemplateArgumentListInfo *ArgsAsWritten) :
-      NestedNameSpec(NNS), TemplateKWLoc(TemplateKWLoc),
-      ConceptName(ConceptNameInfo), FoundDecl(FoundDecl),
-      NamedConcept(NamedConcept), ArgsAsWritten(ArgsAsWritten) {}
+                   const ASTTemplateArgumentListInfo *ArgsAsWritten)
+      : ConceptReference(CK_ConceptReference, NNS, TemplateKWLoc,
+                         ConceptNameInfo, FoundDecl, NamedConcept,
+                         ArgsAsWritten) {}
+
+  ConceptReference()
+      : NestedNameSpec(), TemplateKWLoc(), ConceptName(), FoundDecl(nullptr),
+        NamedConcept(nullptr), ArgsAsWritten(nullptr),
+        Kind(CK_ConceptReference) {}
 
-  ConceptReference() : NestedNameSpec(), TemplateKWLoc(), ConceptName(),
-      FoundDecl(nullptr), NamedConcept(nullptr), ArgsAsWritten(nullptr) {}
+  ConceptReferenceKind getKind() const { return Kind; }
 
   const NestedNameSpecifierLoc &getNestedNameSpecifierLoc() const {
     return NestedNameSpec;
@@ -176,10 +194,11 @@
                  DeclarationNameInfo ConceptNameInfo, NamedDecl *FoundDecl,
                  ConceptDecl *NamedConcept,
                  const ASTTemplateArgumentListInfo *ArgsAsWritten,
-                 Expr *ImmediatelyDeclaredConstraint) :
-      ConceptReference(NNS, /*TemplateKWLoc=*/SourceLocation(), ConceptNameInfo,
-                       FoundDecl, NamedConcept, ArgsAsWritten),
-      ImmediatelyDeclaredConstraint(ImmediatelyDeclaredConstraint) {}
+                 Expr *ImmediatelyDeclaredConstraint)
+      : ConceptReference(CK_TypeConstraint, NNS,
+                         /*TemplateKWLoc=*/SourceLocation(), ConceptNameInfo,
+                         FoundDecl, NamedConcept, ArgsAsWritten),
+        ImmediatelyDeclaredConstraint(ImmediatelyDeclaredConstraint) {}
 
   /// \brief Get the immediately-declared constraint expression introduced by
   /// this type-constraint, that is - the constraint expression that is added to
@@ -189,6 +208,10 @@
   }
 
   void print(llvm::raw_ostream &OS, PrintingPolicy Policy) const;
+
+  static bool classof(const ConceptReference *R) {
+    return R->getKind() == CK_TypeConstraint;
+  }
 };
 
 } // clang
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to