steveire updated this revision to Diff 320593.
steveire added a comment.

Update


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D94880/new/

https://reviews.llvm.org/D94880

Files:
  clang/include/clang/ASTMatchers/Dynamic/Diagnostics.h
  clang/include/clang/ASTMatchers/Dynamic/Parser.h
  clang/include/clang/ASTMatchers/Dynamic/Registry.h
  clang/include/clang/ASTMatchers/Dynamic/VariantValue.h
  clang/lib/ASTMatchers/Dynamic/Diagnostics.cpp
  clang/lib/ASTMatchers/Dynamic/Marshallers.h
  clang/lib/ASTMatchers/Dynamic/Parser.cpp
  clang/lib/ASTMatchers/Dynamic/Registry.cpp
  clang/lib/ASTMatchers/Dynamic/VariantValue.cpp
  clang/unittests/ASTMatchers/Dynamic/ParserTest.cpp

Index: clang/unittests/ASTMatchers/Dynamic/ParserTest.cpp
===================================================================
--- clang/unittests/ASTMatchers/Dynamic/ParserTest.cpp
+++ clang/unittests/ASTMatchers/Dynamic/ParserTest.cpp
@@ -32,6 +32,17 @@
     return M.getID().second;
   }
 
+  bool isBuilderMatcher(MatcherCtor) const override { return false; }
+
+  ASTNodeKind nodeMatcherType(MatcherCtor) const override { return {}; }
+
+  internal::MatcherDescriptorPtr
+  buildMatcherCtor(MatcherCtor, SourceRange NameRange,
+                   ArrayRef<ParserValue> Args,
+                   Diagnostics *Error) const override {
+    return internal::MatcherDescriptorPtr{nullptr};
+  }
+
   void parse(StringRef Code) {
     Diagnostics Error;
     VariantValue Value;
@@ -329,7 +340,7 @@
             "1:5: Invalid token <(> found when looking for a value.",
             ParseWithError("Foo(("));
   EXPECT_EQ("1:7: Expected end of code.", ParseWithError("expr()a"));
-  EXPECT_EQ("1:11: Malformed bind() expression.",
+  EXPECT_EQ("1:11: Period not followed by valid chained call.",
             ParseWithError("isArrow().biind"));
   EXPECT_EQ("1:15: Malformed bind() expression.",
             ParseWithError("isArrow().bind"));
@@ -340,6 +351,20 @@
   EXPECT_EQ("1:1: Error building matcher isArrow.\n"
             "1:1: Matcher does not support binding.",
             ParseWithError("isArrow().bind(\"foo\")"));
+  EXPECT_EQ("1:1: Error building matcher isArrow.\n"
+            "1:11: Matcher does not support with call.",
+            ParseWithError("isArrow().with"));
+  EXPECT_EQ(
+      "1:22: Error parsing matcher. Found token <EOF> while looking for '('.",
+      ParseWithError("mapAnyOf(ifStmt).with"));
+  EXPECT_EQ(
+      "1:22: Error parsing matcher. Found end-of-code while looking for ')'.",
+      ParseWithError("mapAnyOf(ifStmt).with("));
+  EXPECT_EQ("1:1: Failed to build matcher: mapAnyOf.",
+            ParseWithError("mapAnyOf()"));
+  EXPECT_EQ("1:1: Error parsing argument 1 for matcher mapAnyOf.\n1:1: Failed "
+            "to build matcher: mapAnyOf.",
+            ParseWithError("mapAnyOf(\"foo\")"));
   EXPECT_EQ("Input value has unresolved overloaded type: "
             "Matcher<DoStmt|ForStmt|WhileStmt|CXXForRangeStmt|FunctionDecl>",
             ParseMatcherWithError("hasBody(stmt())"));
@@ -470,7 +495,8 @@
 )matcher";
     M = Parser::parseMatcherExpression(Code, nullptr, &NamedValues, &Error);
     EXPECT_FALSE(M.hasValue());
-    EXPECT_EQ("1:11: Malformed bind() expression.", Error.toStringFull());
+    EXPECT_EQ("1:11: Period not followed by valid chained call.",
+              Error.toStringFull());
   }
 
   {
@@ -515,6 +541,29 @@
   ASSERT_EQ(1u, Comps.size());
   EXPECT_EQ("bind(\"", Comps[0].TypedText);
   EXPECT_EQ("bind", Comps[0].MatcherDecl);
+
+  Code = "mapAny";
+  Comps = Parser::completeExpression(Code, 6);
+  ASSERT_EQ(1u, Comps.size());
+  EXPECT_EQ("Of(", Comps[0].TypedText);
+  EXPECT_EQ("Matcher<NestedNameSpecifierLoc|QualType|TypeLoc|...> "
+            "mapAnyOf(NestedNameSpecifierLoc|QualType|TypeLoc|"
+            "NestedNameSpecifier|Decl|Stmt|Type...)",
+            Comps[0].MatcherDecl);
+
+  Code = "mapAnyOf(ifStmt).";
+  Comps = Parser::completeExpression(Code, 17);
+  ASSERT_EQ(2u, Comps.size());
+  EXPECT_EQ("bind(\"", Comps[0].TypedText);
+  EXPECT_EQ("bind", Comps[0].MatcherDecl);
+  EXPECT_EQ("with(\"", Comps[1].TypedText);
+  EXPECT_EQ("with", Comps[1].MatcherDecl);
+
+  Code = "mapAnyOf(ifS";
+  Comps = Parser::completeExpression(Code, 12);
+  ASSERT_EQ(1u, Comps.size());
+  EXPECT_EQ("tmt", Comps[0].TypedText);
+  EXPECT_EQ("ifStmt", Comps[0].MatcherDecl);
 }
 
 TEST(ParserTest, CompletionNamedValues) {
Index: clang/lib/ASTMatchers/Dynamic/VariantValue.cpp
===================================================================
--- clang/lib/ASTMatchers/Dynamic/VariantValue.cpp
+++ clang/lib/ASTMatchers/Dynamic/VariantValue.cpp
@@ -22,9 +22,11 @@
 std::string ArgKind::asString() const {
   switch (getArgKind()) {
   case AK_Matcher:
-    return (Twine("Matcher<") + MatcherKind.asStringRef() + ">").str();
+    return (Twine("Matcher<") + NodeKind.asStringRef() + ">").str();
   case AK_Boolean:
     return "boolean";
+  case AK_Node:
+    return NodeKind.asStringRef().str();
   case AK_Double:
     return "double";
   case AK_Unsigned:
@@ -38,13 +40,13 @@
 bool ArgKind::isConvertibleTo(ArgKind To, unsigned *Specificity) const {
   if (K != To.K)
     return false;
-  if (K != AK_Matcher) {
+  if (K != AK_Matcher && K != AK_Node) {
     if (Specificity)
       *Specificity = 1;
     return true;
   }
   unsigned Distance;
-  if (!MatcherKind.isBaseOf(To.MatcherKind, &Distance))
+  if (!NodeKind.isBaseOf(To.NodeKind, &Distance))
     return false;
 
   if (Specificity)
@@ -107,8 +109,8 @@
   }
 
   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity) const override {
-    return ArgKind(Matcher.getSupportedKind())
-        .isConvertibleTo(Kind, Specificity);
+    return ArgKind::MakeMatcherArg(Matcher.getSupportedKind())
+        .isConvertibleTo(ArgKind::MakeMatcherArg(Kind), Specificity);
   }
 
 private:
@@ -167,8 +169,9 @@
     unsigned MaxSpecificity = 0;
     for (const DynTypedMatcher &Matcher : Matchers) {
       unsigned ThisSpecificity;
-      if (ArgKind(Matcher.getSupportedKind())
-              .isConvertibleTo(Kind, &ThisSpecificity)) {
+      if (ArgKind::MakeMatcherArg(Matcher.getSupportedKind())
+              .isConvertibleTo(ArgKind::MakeMatcherArg(Kind),
+                               &ThisSpecificity)) {
         MaxSpecificity = std::max(MaxSpecificity, ThisSpecificity);
       }
     }
@@ -446,6 +449,11 @@
     if (!isMatcher())
       return false;
     return getMatcher().isConvertibleTo(Kind.getMatcherKind(), Specificity);
+
+  case ArgKind::AK_Node:
+    if (!isNodeKind())
+      return false;
+    return getMatcher().isConvertibleTo(Kind.getNodeKind(), Specificity);
   }
   llvm_unreachable("Invalid Type");
 }
Index: clang/lib/ASTMatchers/Dynamic/Registry.cpp
===================================================================
--- clang/lib/ASTMatchers/Dynamic/Registry.cpp
+++ clang/lib/ASTMatchers/Dynamic/Registry.cpp
@@ -102,6 +102,9 @@
   // Other:
   // equalsNode
 
+  registerMatcher("mapAnyOf",
+                  std::make_unique<internal::MapAnyOfBuilderDescriptor>());
+
   REGISTER_OVERLOADED_2(callee);
   REGISTER_OVERLOADED_2(hasAnyCapture);
   REGISTER_OVERLOADED_2(hasPrefix);
@@ -562,6 +565,26 @@
 
 static llvm::ManagedStatic<RegistryMaps> RegistryData;
 
+internal::MatcherDescriptorPtr::MatcherDescriptorPtr(MatcherDescriptor *Ptr)
+    : Ptr(Ptr) {}
+
+internal::MatcherDescriptorPtr::~MatcherDescriptorPtr() { delete Ptr; }
+
+bool Registry::isBuilderMatcher(MatcherCtor Ctor) {
+  return Ctor->isBuilderMatcher();
+}
+
+ASTNodeKind Registry::nodeMatcherType(MatcherCtor Ctor) {
+  return Ctor->nodeMatcherType();
+}
+
+internal::MatcherDescriptorPtr
+Registry::buildMatcherCtor(MatcherCtor Ctor, SourceRange NameRange,
+                           ArrayRef<ParserValue> Args, Diagnostics *Error) {
+  return internal::MatcherDescriptorPtr(
+      Ctor->buildMatcherCtor(NameRange, Args, Error).release());
+}
+
 // static
 llvm::Optional<MatcherCtor> Registry::lookupMatcherCtor(StringRef MatcherName) {
   auto it = RegistryData->constructors().find(MatcherName);
@@ -599,7 +622,10 @@
 
   // Starting with the above seed of acceptable top-level matcher types, compute
   // the acceptable type set for the argument indicated by each context element.
-  std::set<ArgKind> TypeSet(std::begin(InitialTypes), std::end(InitialTypes));
+  std::set<ArgKind> TypeSet;
+  for (auto IT : InitialTypes) {
+    TypeSet.insert(ArgKind::MakeMatcherArg(IT));
+  }
   for (const auto &CtxEntry : Context) {
     MatcherCtor Ctor = CtxEntry.first;
     unsigned ArgNumber = CtxEntry.second;
@@ -611,7 +637,9 @@
         Ctor->getArgKinds(Kind.getMatcherKind(), ArgNumber, NextTypeSet);
     }
     TypeSet.clear();
-    TypeSet.insert(NextTypeSet.begin(), NextTypeSet.end());
+    for (auto IT : NextTypeSet) {
+      TypeSet.insert(IT);
+    }
   }
   return std::vector<ArgKind>(TypeSet.begin(), TypeSet.end());
 }
@@ -630,20 +658,40 @@
     bool IsPolymorphic = Matcher.isPolymorphic();
     std::vector<std::vector<ArgKind>> ArgsKinds(NumArgs);
     unsigned MaxSpecificity = 0;
+    bool NodeArgs = false;
     for (const ArgKind& Kind : AcceptedTypes) {
-      if (Kind.getArgKind() != Kind.AK_Matcher)
+      if (Kind.getArgKind() != Kind.AK_Matcher &&
+          Kind.getArgKind() != Kind.AK_Node) {
         continue;
-      unsigned Specificity;
-      ASTNodeKind LeastDerivedKind;
-      if (Matcher.isConvertibleTo(Kind.getMatcherKind(), &Specificity,
-                                  &LeastDerivedKind)) {
-        if (MaxSpecificity < Specificity)
-          MaxSpecificity = Specificity;
-        RetKinds.insert(LeastDerivedKind);
-        for (unsigned Arg = 0; Arg != NumArgs; ++Arg)
-          Matcher.getArgKinds(Kind.getMatcherKind(), Arg, ArgsKinds[Arg]);
-        if (IsPolymorphic)
-          break;
+      }
+
+      if (Kind.getArgKind() == Kind.AK_Node) {
+        NodeArgs = true;
+        unsigned Specificity;
+        ASTNodeKind LeastDerivedKind;
+        if (Matcher.isConvertibleTo(Kind.getNodeKind(), &Specificity,
+                                    &LeastDerivedKind)) {
+          if (MaxSpecificity < Specificity)
+            MaxSpecificity = Specificity;
+          RetKinds.insert(LeastDerivedKind);
+          for (unsigned Arg = 0; Arg != NumArgs; ++Arg)
+            Matcher.getArgKinds(Kind.getNodeKind(), Arg, ArgsKinds[Arg]);
+          if (IsPolymorphic)
+            break;
+        }
+      } else {
+        unsigned Specificity;
+        ASTNodeKind LeastDerivedKind;
+        if (Matcher.isConvertibleTo(Kind.getMatcherKind(), &Specificity,
+                                    &LeastDerivedKind)) {
+          if (MaxSpecificity < Specificity)
+            MaxSpecificity = Specificity;
+          RetKinds.insert(LeastDerivedKind);
+          for (unsigned Arg = 0; Arg != NumArgs; ++Arg)
+            Matcher.getArgKinds(Kind.getMatcherKind(), Arg, ArgsKinds[Arg]);
+          if (IsPolymorphic)
+            break;
+        }
       }
     }
 
@@ -651,42 +699,49 @@
       std::string Decl;
       llvm::raw_string_ostream OS(Decl);
 
-      if (IsPolymorphic) {
-        OS << "Matcher<T> " << Name << "(Matcher<T>";
+      std::string TypedText = std::string(Name);
+
+      if (NodeArgs) {
+        OS << Name;
       } else {
-        OS << "Matcher<" << RetKinds << "> " << Name << "(";
-        for (const std::vector<ArgKind> &Arg : ArgsKinds) {
-          if (&Arg != &ArgsKinds[0])
-            OS << ", ";
-
-          bool FirstArgKind = true;
-          std::set<ASTNodeKind> MatcherKinds;
-          // Two steps. First all non-matchers, then matchers only.
-          for (const ArgKind &AK : Arg) {
-            if (AK.getArgKind() == ArgKind::AK_Matcher) {
-              MatcherKinds.insert(AK.getMatcherKind());
-            } else {
+
+        if (IsPolymorphic) {
+          OS << "Matcher<T> " << Name << "(Matcher<T>";
+        } else {
+          OS << "Matcher<" << RetKinds << "> " << Name << "(";
+          for (const std::vector<ArgKind> &Arg : ArgsKinds) {
+            if (&Arg != &ArgsKinds[0])
+              OS << ", ";
+
+            bool FirstArgKind = true;
+            std::set<ASTNodeKind> MatcherKinds;
+            // Two steps. First all non-matchers, then matchers only.
+            for (const ArgKind &AK : Arg) {
+              if (AK.getArgKind() == ArgKind::AK_Matcher) {
+                MatcherKinds.insert(AK.getMatcherKind());
+              } else {
+                if (!FirstArgKind)
+                  OS << "|";
+                FirstArgKind = false;
+                OS << AK.asString();
+              }
+            }
+            if (!MatcherKinds.empty()) {
               if (!FirstArgKind) OS << "|";
-              FirstArgKind = false;
-              OS << AK.asString();
+              OS << "Matcher<" << MatcherKinds << ">";
             }
           }
-          if (!MatcherKinds.empty()) {
-            if (!FirstArgKind) OS << "|";
-            OS << "Matcher<" << MatcherKinds << ">";
-          }
         }
+        if (Matcher.isVariadic())
+          OS << "...";
+        OS << ")";
+
+        TypedText += "(";
+        if (ArgsKinds.empty())
+          TypedText += ")";
+        else if (ArgsKinds[0][0].getArgKind() == ArgKind::AK_String)
+          TypedText += "\"";
       }
-      if (Matcher.isVariadic())
-        OS << "...";
-      OS << ")";
-
-      std::string TypedText = std::string(Name);
-      TypedText += "(";
-      if (ArgsKinds.empty())
-        TypedText += ")";
-      else if (ArgsKinds[0][0].getArgKind() == ArgKind::AK_String)
-        TypedText += "\"";
 
       Completions.emplace_back(TypedText, OS.str(), MaxSpecificity);
     }
Index: clang/lib/ASTMatchers/Dynamic/Parser.cpp
===================================================================
--- clang/lib/ASTMatchers/Dynamic/Parser.cpp
+++ clang/lib/ASTMatchers/Dynamic/Parser.cpp
@@ -20,6 +20,7 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/ManagedStatic.h"
+
 #include <algorithm>
 #include <cassert>
 #include <cerrno>
@@ -52,6 +53,7 @@
 
   /// Some known identifiers.
   static const char* const ID_Bind;
+  static const char *const ID_With;
 
   TokenInfo() = default;
 
@@ -62,6 +64,7 @@
 };
 
 const char* const Parser::TokenInfo::ID_Bind = "bind";
+const char *const Parser::TokenInfo::ID_With = "with";
 
 /// Simple tokenizer for the parser.
 class Parser::CodeTokenizer {
@@ -366,6 +369,29 @@
       }
 
       std::string BindID;
+      Tokenizer->consumeNextToken(); // Consume the period.
+      const TokenInfo ChainCallToken = Tokenizer->consumeNextToken();
+      if (ChainCallToken.Kind == TokenInfo::TK_CodeCompletion) {
+        addCompletion(ChainCallToken, MatcherCompletion("bind(\"", "bind", 1));
+        return false;
+      }
+
+      if (ChainCallToken.Kind != TokenInfo::TK_Ident ||
+          (ChainCallToken.Text != TokenInfo::ID_Bind &&
+           ChainCallToken.Text != TokenInfo::ID_With)) {
+        Error->addError(ChainCallToken.Range,
+                        Error->ET_ParserMalformedChainedExpr);
+        return false;
+      }
+      if (ChainCallToken.Text == TokenInfo::ID_With) {
+
+        Diagnostics::Context Ctx(Diagnostics::Context::ConstructMatcher, Error,
+                                 NameToken.Text, NameToken.Range);
+
+        Error->addError(ChainCallToken.Range,
+                        Error->ET_RegistryMatcherNoWithSupport);
+        return false;
+      }
       if (!parseBindID(BindID))
         return false;
 
@@ -405,31 +431,28 @@
 
   Tokenizer->SkipNewlines();
 
+  assert(NameToken.Kind == TokenInfo::TK_Ident);
+  TokenInfo OpenToken = Tokenizer->consumeNextToken();
+  if (OpenToken.Kind != TokenInfo::TK_OpenParen) {
+    Error->addError(OpenToken.Range, Error->ET_ParserNoOpenParen)
+        << OpenToken.Text;
+    return false;
+  }
+
+  llvm::Optional<MatcherCtor> Ctor = S->lookupMatcherCtor(NameToken.Text);
+
   // Parse as a matcher expression.
-  return parseMatcherExpressionImpl(NameToken, Value);
+  return parseMatcherExpressionImpl(NameToken, OpenToken, Ctor, Value);
 }
 
 bool Parser::parseBindID(std::string &BindID) {
-  // Parse .bind("foo")
-  assert(Tokenizer->peekNextToken().Kind == TokenInfo::TK_Period);
-  Tokenizer->consumeNextToken(); // consume the period.
-  const TokenInfo BindToken = Tokenizer->consumeNextToken();
-  if (BindToken.Kind == TokenInfo::TK_CodeCompletion) {
-    addCompletion(BindToken, MatcherCompletion("bind(\"", "bind", 1));
-    return false;
-  }
-
+  // Parse argument to .bind("foo")
   const TokenInfo OpenToken = Tokenizer->consumeNextToken();
   const TokenInfo IDToken = Tokenizer->consumeNextTokenIgnoreNewlines();
   const TokenInfo CloseToken = Tokenizer->consumeNextTokenIgnoreNewlines();
 
   // TODO: We could use different error codes for each/some to be more
   //       explicit about the syntax error.
-  if (BindToken.Kind != TokenInfo::TK_Ident ||
-      BindToken.Text != TokenInfo::ID_Bind) {
-    Error->addError(BindToken.Range, Error->ET_ParserMalformedBindExpr);
-    return false;
-  }
   if (OpenToken.Kind != TokenInfo::TK_OpenParen) {
     Error->addError(OpenToken.Range, Error->ET_ParserMalformedBindExpr);
     return false;
@@ -446,28 +469,177 @@
   return true;
 }
 
+bool Parser::parseMatcherBuilder(MatcherCtor Ctor, const TokenInfo &NameToken,
+                                 const TokenInfo &OpenToken,
+                                 VariantValue *Value) {
+  std::vector<ParserValue> Args;
+  TokenInfo EndToken;
+
+  Tokenizer->SkipNewlines();
+
+  {
+    ScopedContextEntry SCE(this, Ctor);
+
+    while (Tokenizer->nextTokenKind() != TokenInfo::TK_Eof) {
+      if (Tokenizer->nextTokenKind() == TokenInfo::TK_CloseParen) {
+        // End of args.
+        EndToken = Tokenizer->consumeNextToken();
+        break;
+      }
+      if (!Args.empty()) {
+        // We must find a , token to continue.
+        TokenInfo CommaToken = Tokenizer->consumeNextToken();
+        if (CommaToken.Kind != TokenInfo::TK_Comma) {
+          Error->addError(CommaToken.Range, Error->ET_ParserNoComma)
+              << CommaToken.Text;
+          return false;
+        }
+      }
+
+      Diagnostics::Context Ctx(Diagnostics::Context::MatcherArg, Error,
+                               NameToken.Text, NameToken.Range,
+                               Args.size() + 1);
+      ParserValue ArgValue;
+      Tokenizer->SkipNewlines();
+
+      if (Tokenizer->peekNextToken().Kind == TokenInfo::TK_CodeCompletion) {
+        addExpressionCompletions();
+        return false;
+      }
+
+      TokenInfo NodeMatcherToken = Tokenizer->consumeNextToken();
+
+      if (NodeMatcherToken.Kind != TokenInfo::TK_Ident) {
+        Error->addError(NameToken.Range, Error->ET_ParserFailedToBuildMatcher)
+            << NameToken.Text;
+        return false;
+      }
+
+      ArgValue.Text = NodeMatcherToken.Text;
+      ArgValue.Range = NodeMatcherToken.Range;
+
+      llvm::Optional<MatcherCtor> MappedMatcher =
+          S->lookupMatcherCtor(ArgValue.Text);
+
+      if (!MappedMatcher) {
+        Error->addError(NodeMatcherToken.Range,
+                        Error->ET_RegistryMatcherNotFound)
+            << NodeMatcherToken.Text;
+        return false;
+      }
+
+      ASTNodeKind NK = S->nodeMatcherType(*MappedMatcher);
+
+      if (NK.isNone()) {
+        Error->addError(NodeMatcherToken.Range,
+                        Error->ET_RegistryNonNodeMatcher)
+            << NodeMatcherToken.Text;
+        return false;
+      }
+
+      ArgValue.Value = NK;
+
+      Tokenizer->SkipNewlines();
+      Args.push_back(ArgValue);
+
+      SCE.nextArg();
+    }
+  }
+
+  if (EndToken.Kind == TokenInfo::TK_Eof) {
+    Error->addError(OpenToken.Range, Error->ET_ParserNoCloseParen);
+    return false;
+  }
+
+  internal::MatcherDescriptorPtr BuiltCtor =
+      S->buildMatcherCtor(Ctor, NameToken.Range, Args, Error);
+
+  if (!BuiltCtor.get()) {
+    Error->addError(NameToken.Range, Error->ET_ParserFailedToBuildMatcher)
+        << NameToken.Text;
+    return false;
+  }
+
+  std::string BindID;
+  if (Tokenizer->peekNextToken().Kind == TokenInfo::TK_Period) {
+    Tokenizer->consumeNextToken(); // Consume the period.
+    TokenInfo ChainCallToken = Tokenizer->consumeNextToken();
+    if (ChainCallToken.Kind == TokenInfo::TK_CodeCompletion) {
+      addCompletion(ChainCallToken, MatcherCompletion("bind(\"", "bind", 1));
+      addCompletion(ChainCallToken, MatcherCompletion("with(\"", "with", 1));
+      return false;
+    }
+    if (ChainCallToken.Kind != TokenInfo::TK_Ident ||
+        (ChainCallToken.Text != TokenInfo::ID_Bind &&
+         ChainCallToken.Text != TokenInfo::ID_With)) {
+      Error->addError(ChainCallToken.Range,
+                      Error->ET_ParserMalformedChainedExpr);
+      return false;
+    }
+    if (ChainCallToken.Text == TokenInfo::ID_Bind) {
+      if (!parseBindID(BindID))
+        return false;
+      Diagnostics::Context Ctx(Diagnostics::Context::ConstructMatcher, Error,
+                               NameToken.Text, NameToken.Range);
+      SourceRange MatcherRange = NameToken.Range;
+      MatcherRange.End = ChainCallToken.Range.End;
+      VariantMatcher Result = S->actOnMatcherExpression(
+          BuiltCtor.get(), MatcherRange, BindID, {}, Error);
+      if (Result.isNull())
+        return false;
+
+      *Value = Result;
+      return true;
+    } else if (ChainCallToken.Text == TokenInfo::ID_With) {
+      Tokenizer->SkipNewlines();
+
+      if (Tokenizer->nextTokenKind() != TokenInfo::TK_OpenParen) {
+        StringRef ErrTxt = Tokenizer->nextTokenKind() == TokenInfo::TK_Eof
+                               ? StringRef("EOF")
+                               : Tokenizer->peekNextToken().Text;
+        Error->addError(Tokenizer->peekNextToken().Range,
+                        Error->ET_ParserNoOpenParen)
+            << ErrTxt;
+        return false;
+      }
+
+      const TokenInfo WithOpenToken = Tokenizer->consumeNextToken();
+
+      return parseMatcherExpressionImpl(NameToken, WithOpenToken,
+                                        BuiltCtor.get(), Value);
+    }
+  }
+
+  Diagnostics::Context Ctx(Diagnostics::Context::ConstructMatcher, Error,
+                           NameToken.Text, NameToken.Range);
+  SourceRange MatcherRange = NameToken.Range;
+  MatcherRange.End = EndToken.Range.End;
+  VariantMatcher Result = S->actOnMatcherExpression(
+      BuiltCtor.get(), MatcherRange, BindID, {}, Error);
+  if (Result.isNull())
+    return false;
+
+  *Value = Result;
+  return true;
+}
+
 /// Parse and validate a matcher expression.
 /// \return \c true on success, in which case \c Value has the matcher parsed.
 ///   If the input is malformed, or some argument has an error, it
 ///   returns \c false.
 bool Parser::parseMatcherExpressionImpl(const TokenInfo &NameToken,
+                                        const TokenInfo &OpenToken,
+                                        llvm::Optional<MatcherCtor> Ctor,
                                         VariantValue *Value) {
-  assert(NameToken.Kind == TokenInfo::TK_Ident);
-  const TokenInfo OpenToken = Tokenizer->consumeNextToken();
-  if (OpenToken.Kind != TokenInfo::TK_OpenParen) {
-    Error->addError(OpenToken.Range, Error->ET_ParserNoOpenParen)
-        << OpenToken.Text;
-    return false;
-  }
-
-  llvm::Optional<MatcherCtor> Ctor = S->lookupMatcherCtor(NameToken.Text);
-
   if (!Ctor) {
     Error->addError(NameToken.Range, Error->ET_RegistryMatcherNotFound)
         << NameToken.Text;
     // Do not return here. We need to continue to give completion suggestions.
   }
 
+  if (Ctor && *Ctor && S->isBuilderMatcher(*Ctor))
+    return parseMatcherBuilder(*Ctor, NameToken, OpenToken, Value);
+
   std::vector<ParserValue> Args;
   TokenInfo EndToken;
 
@@ -516,6 +688,33 @@
 
   std::string BindID;
   if (Tokenizer->peekNextToken().Kind == TokenInfo::TK_Period) {
+    Tokenizer->consumeNextToken();
+    TokenInfo ChainCallToken = Tokenizer->consumeNextToken();
+    if (ChainCallToken.Kind == TokenInfo::TK_CodeCompletion) {
+      addCompletion(ChainCallToken, MatcherCompletion("bind(\"", "bind", 1));
+      return false;
+    }
+
+    if (ChainCallToken.Kind != TokenInfo::TK_Ident) {
+      Error->addError(ChainCallToken.Range,
+                      Error->ET_ParserMalformedChainedExpr);
+      return false;
+    }
+    if (ChainCallToken.Text == TokenInfo::ID_With) {
+
+      Diagnostics::Context Ctx(Diagnostics::Context::ConstructMatcher, Error,
+                               NameToken.Text, NameToken.Range);
+
+      Error->addError(ChainCallToken.Range,
+                      Error->ET_RegistryMatcherNoWithSupport);
+      return false;
+    }
+    if (ChainCallToken.Text != TokenInfo::ID_Bind) {
+      Error->addError(ChainCallToken.Range,
+                      Error->ET_ParserMalformedChainedExpr);
+      return false;
+    }
+
     if (!parseBindID(BindID))
       return false;
   }
@@ -657,6 +856,21 @@
   return Registry::getMatcherCompletions(AcceptedTypes);
 }
 
+bool Parser::RegistrySema::isBuilderMatcher(MatcherCtor Ctor) const {
+  return Registry::isBuilderMatcher(Ctor);
+}
+
+ASTNodeKind Parser::RegistrySema::nodeMatcherType(MatcherCtor Ctor) const {
+  return Registry::nodeMatcherType(Ctor);
+}
+
+internal::MatcherDescriptorPtr
+Parser::RegistrySema::buildMatcherCtor(MatcherCtor Ctor, SourceRange NameRange,
+                                       ArrayRef<ParserValue> Args,
+                                       Diagnostics *Error) const {
+  return Registry::buildMatcherCtor(Ctor, NameRange, Args, Error);
+}
+
 bool Parser::parseExpression(StringRef &Code, Sema *S,
                              const NamedValueMap *NamedValues,
                              VariantValue *Value, Diagnostics *Error) {
Index: clang/lib/ASTMatchers/Dynamic/Marshallers.h
===================================================================
--- clang/lib/ASTMatchers/Dynamic/Marshallers.h
+++ clang/lib/ASTMatchers/Dynamic/Marshallers.h
@@ -93,7 +93,7 @@
   }
 
   static ArgKind getKind() {
-    return ArgKind(ASTNodeKind::getFromNodeKind<T>());
+    return ArgKind::MakeMatcherArg(ASTNodeKind::getFromNodeKind<T>());
   }
 
   static llvm::Optional<std::string> getBestGuess(const VariantValue &) {
@@ -309,6 +309,14 @@
                                 ArrayRef<ParserValue> Args,
                                 Diagnostics *Error) const = 0;
 
+  virtual ASTNodeKind nodeMatcherType() const = 0;
+
+  virtual bool isBuilderMatcher() const = 0;
+
+  virtual std::unique_ptr<MatcherDescriptor>
+  buildMatcherCtor(SourceRange NameRange, ArrayRef<ParserValue> Args,
+                   Diagnostics *Error) const = 0;
+
   /// Returns whether the matcher is variadic. Variadic matchers can take any
   /// number of arguments, but they must be of the same type.
   virtual bool isVariadic() const = 0;
@@ -343,7 +351,8 @@
                                    ASTNodeKind Kind, unsigned *Specificity,
                                    ASTNodeKind *LeastDerivedKind) {
   for (const ASTNodeKind &NodeKind : RetKinds) {
-    if (ArgKind(NodeKind).isConvertibleTo(Kind, Specificity)) {
+    if (ArgKind::MakeMatcherArg(NodeKind).isConvertibleTo(
+            ArgKind::MakeMatcherArg(Kind), Specificity)) {
       if (LeastDerivedKind)
         *LeastDerivedKind = NodeKind;
       return true;
@@ -386,6 +395,16 @@
     return Marshaller(Func, MatcherName, NameRange, Args, Error);
   }
 
+  bool isBuilderMatcher() const override { return false; }
+
+  std::unique_ptr<MatcherDescriptor>
+  buildMatcherCtor(SourceRange, ArrayRef<ParserValue>,
+                   Diagnostics *) const override {
+    return {};
+  }
+
+  ASTNodeKind nodeMatcherType() const override { return ASTNodeKind(); }
+
   bool isVariadic() const override { return false; }
   unsigned getNumArgs() const override { return ArgKinds.size(); }
 
@@ -561,6 +580,18 @@
     return Func(MatcherName, NameRange, Args, Error);
   }
 
+  bool isBuilderMatcher() const override { return false; }
+
+  std::unique_ptr<MatcherDescriptor>
+  buildMatcherCtor(SourceRange, ArrayRef<ParserValue>,
+                   Diagnostics *) const override {
+    return {};
+  }
+
+  ASTNodeKind nodeMatcherType() const override {
+    return ArgsKind.getMatcherKind();
+  }
+
   bool isVariadic() const override { return true; }
   unsigned getNumArgs() const override { return 0; }
 
@@ -610,6 +641,8 @@
     }
   }
 
+  ASTNodeKind nodeMatcherType() const override { return DerivedKind; }
+
 private:
   const ASTNodeKind DerivedKind;
 };
@@ -786,6 +819,15 @@
     return false;
   }
 
+  bool isBuilderMatcher() const override { return false; }
+  std::unique_ptr<MatcherDescriptor>
+  buildMatcherCtor(SourceRange, ArrayRef<ParserValue>,
+                   Diagnostics *) const override {
+    return {};
+  }
+
+  ASTNodeKind nodeMatcherType() const override { return ASTNodeKind(); }
+
 private:
   std::vector<std::unique_ptr<MatcherDescriptor>> Overloads;
 };
@@ -856,6 +898,15 @@
                   ArgTypeTraits<llvm::Regex::RegexFlags>::get(Args[1].Value)));
   }
 
+  bool isBuilderMatcher() const override { return false; }
+  std::unique_ptr<MatcherDescriptor>
+  buildMatcherCtor(SourceRange, ArrayRef<ParserValue>,
+                   Diagnostics *) const override {
+    return {};
+  }
+
+  ASTNodeKind nodeMatcherType() const override { return ASTNodeKind(); }
+
 private:
   ReturnType (*const WithFlags)(StringRef, llvm::Regex::RegexFlags);
   ReturnType (*const NoFlags)(StringRef);
@@ -904,7 +955,7 @@
 
   void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
                    std::vector<ArgKind> &Kinds) const override {
-    Kinds.push_back(ThisKind);
+    Kinds.push_back(ArgKind::MakeMatcherArg(ThisKind));
   }
 
   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
@@ -918,6 +969,15 @@
 
   bool isPolymorphic() const override { return true; }
 
+  bool isBuilderMatcher() const override { return false; }
+  std::unique_ptr<MatcherDescriptor>
+  buildMatcherCtor(SourceRange, ArrayRef<ParserValue>,
+                   Diagnostics *) const override {
+    return {};
+  }
+
+  ASTNodeKind nodeMatcherType() const override { return ASTNodeKind(); }
+
 private:
   const unsigned MinCount;
   const unsigned MaxCount;
@@ -976,7 +1036,7 @@
 
   void getArgKinds(ASTNodeKind ThisKind, unsigned,
                    std::vector<ArgKind> &Kinds) const override {
-    Kinds.push_back(ThisKind);
+    Kinds.push_back(ArgKind::MakeMatcherArg(ThisKind));
   }
 
   bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity,
@@ -987,6 +1047,67 @@
       *LeastDerivedKind = CladeNodeKind;
     return true;
   }
+
+  bool isBuilderMatcher() const override { return false; }
+  std::unique_ptr<MatcherDescriptor>
+  buildMatcherCtor(SourceRange, ArrayRef<ParserValue>,
+                   Diagnostics *) const override {
+    return {};
+  }
+
+  ASTNodeKind nodeMatcherType() const override { return ASTNodeKind(); }
+};
+
+class MapAnyOfBuilderDescriptor : public MatcherDescriptor {
+public:
+  VariantMatcher create(SourceRange, ArrayRef<ParserValue>,
+                        Diagnostics *) const override {
+    return {};
+  }
+
+  bool isBuilderMatcher() const override { return true; }
+
+  std::unique_ptr<MatcherDescriptor>
+  buildMatcherCtor(SourceRange, ArrayRef<ParserValue> Args,
+                   Diagnostics *) const override {
+
+    std::vector<ASTNodeKind> NodeKinds;
+    for (auto Arg : Args) {
+      if (!Arg.Value.isNodeKind())
+        return {};
+      NodeKinds.push_back(Arg.Value.getNodeKind());
+    }
+
+    if (NodeKinds.empty())
+      return {};
+
+    ASTNodeKind CladeNodeKind = NodeKinds.front().getCladeKind();
+
+    return std::make_unique<MapAnyOfMatcherDescriptor>(CladeNodeKind,
+                                                       NodeKinds);
+  }
+
+  bool isVariadic() const override { return true; }
+
+  unsigned getNumArgs() const override { return 0; }
+
+  void getArgKinds(ASTNodeKind ThisKind, unsigned ArgNo,
+                   std::vector<ArgKind> &ArgKinds) const override {
+    ArgKinds.push_back(ArgKind::MakeNodeArg(ThisKind));
+    return;
+  }
+  bool isConvertibleTo(ASTNodeKind Kind, unsigned *Specificity = nullptr,
+                       ASTNodeKind *LeastDerivedKind = nullptr) const override {
+    if (Specificity)
+      *Specificity = 1;
+    if (LeastDerivedKind)
+      *LeastDerivedKind = Kind;
+    return true;
+  }
+
+  bool isPolymorphic() const override { return false; }
+
+  ASTNodeKind nodeMatcherType() const override { return ASTNodeKind(); }
 };
 
 /// Helper functions to select the appropriate marshaller functions.
Index: clang/lib/ASTMatchers/Dynamic/Diagnostics.cpp
===================================================================
--- clang/lib/ASTMatchers/Dynamic/Diagnostics.cpp
+++ clang/lib/ASTMatchers/Dynamic/Diagnostics.cpp
@@ -100,6 +100,10 @@
     return "Value not found: $0";
   case Diagnostics::ET_RegistryUnknownEnumWithReplace:
     return "Unknown value '$1' for arg $0; did you mean '$2'";
+  case Diagnostics::ET_RegistryNonNodeMatcher:
+    return "Matcher not a node matcher: $0";
+  case Diagnostics::ET_RegistryMatcherNoWithSupport:
+    return "Matcher does not support with call.";
 
   case Diagnostics::ET_ParserStringError:
     return "Error parsing string token: <$0>";
@@ -123,6 +127,10 @@
     return "Error parsing numeric literal: <$0>";
   case Diagnostics::ET_ParserOverloadedType:
     return "Input value has unresolved overloaded type: $0";
+  case Diagnostics::ET_ParserMalformedChainedExpr:
+    return "Period not followed by valid chained call.";
+  case Diagnostics::ET_ParserFailedToBuildMatcher:
+    return "Failed to build matcher: $0.";
 
   case Diagnostics::ET_None:
     return "<N/A>";
Index: clang/include/clang/ASTMatchers/Dynamic/VariantValue.h
===================================================================
--- clang/include/clang/ASTMatchers/Dynamic/VariantValue.h
+++ clang/include/clang/ASTMatchers/Dynamic/VariantValue.h
@@ -33,23 +33,34 @@
 /// It supports all types that VariantValue can contain.
 class ArgKind {
  public:
-  enum Kind {
-    AK_Matcher,
-    AK_Boolean,
-    AK_Double,
-    AK_Unsigned,
-    AK_String
-  };
-  /// Constructor for non-matcher types.
-  ArgKind(Kind K) : K(K) { assert(K != AK_Matcher); }
-
-  /// Constructor for matcher types.
-  ArgKind(ASTNodeKind MatcherKind) : K(AK_Matcher), MatcherKind(MatcherKind) {}
+   enum Kind {
+     AK_Matcher,
+     AK_Node,
+     AK_Boolean,
+     AK_Double,
+     AK_Unsigned,
+     AK_String
+   };
+   /// Constructor for kinds other than AK_Matcher and AK_Node.
+   ArgKind(Kind K) : K(K) { assert(K != AK_Matcher && K != AK_Node); }
+
+   /// Creates the ArgKind for a matcher argument of kind \p MatcherKind.
+   static ArgKind MakeMatcherArg(ASTNodeKind MatcherKind) {
+     return ArgKind{AK_Matcher, MatcherKind};
+   }
+
+   static ArgKind MakeNodeArg(ASTNodeKind MatcherKind) {
+     return ArgKind{AK_Node, MatcherKind};
+   }
 
   Kind getArgKind() const { return K; }
   ASTNodeKind getMatcherKind() const {
     assert(K == AK_Matcher);
-    return MatcherKind;
+    return NodeKind;
+  }
+  ASTNodeKind getNodeKind() const {
+    assert(K == AK_Node);
+    return NodeKind;
   }
 
   /// Determines if this type can be converted to \p To.
@@ -61,8 +72,9 @@
   bool isConvertibleTo(ArgKind To, unsigned *Specificity) const;
 
   bool operator<(const ArgKind &Other) const {
-    if (K == AK_Matcher && Other.K == AK_Matcher)
-      return MatcherKind < Other.MatcherKind;
+    if ((K == AK_Matcher && Other.K == AK_Matcher) ||
+        (K == AK_Node && Other.K == AK_Node))
+      return NodeKind < Other.NodeKind;
     return K < Other.K;
   }
 
@@ -70,8 +82,10 @@
   std::string asString() const;
 
 private:
+  ArgKind(Kind K, ASTNodeKind NK) : K(K), NodeKind(NK) {}
+
   Kind K;
-  ASTNodeKind MatcherKind;
+  ASTNodeKind NodeKind;
 };
 
 using ast_matchers::internal::DynTypedMatcher;
Index: clang/include/clang/ASTMatchers/Dynamic/Registry.h
===================================================================
--- clang/include/clang/ASTMatchers/Dynamic/Registry.h
+++ clang/include/clang/ASTMatchers/Dynamic/Registry.h
@@ -33,6 +33,41 @@
 
 class MatcherDescriptor;
 
+/// A smart (owning) pointer for MatcherDescriptor. We can't use unique_ptr
+/// because MatcherDescriptor is only forward declared here.
+///
+/// The type is move-only. Moving transfers the pointer rather than duplicating
+/// it, so a descriptor is never held by two MatcherDescriptorPtr objects.
+class MatcherDescriptorPtr {
+public:
+  explicit MatcherDescriptorPtr(MatcherDescriptor *);
+  ~MatcherDescriptorPtr();
+
+  /// Takes over the pointer held by \p Other and leaves \p Other null.
+  MatcherDescriptorPtr(MatcherDescriptorPtr &&Other) noexcept
+      : Ptr(Other.Ptr) {
+    Other.Ptr = nullptr;
+  }
+
+  /// Exchanges the held pointers, so the pointer this object held before the
+  /// assignment ends up in \p Other instead of being dropped.
+  MatcherDescriptorPtr &operator=(MatcherDescriptorPtr &&Other) noexcept {
+    MatcherDescriptor *Previous = Ptr;
+    Ptr = Other.Ptr;
+    Other.Ptr = Previous;
+    return *this;
+  }
+
+  MatcherDescriptorPtr(const MatcherDescriptorPtr &) = delete;
+  MatcherDescriptorPtr &operator=(const MatcherDescriptorPtr &) = delete;
+
+  /// \returns the held pointer without giving up ownership; may be null.
+  MatcherDescriptor *get() { return Ptr; }
+
+private:
+  MatcherDescriptor *Ptr;
+};
+
 } // namespace internal
 
 using MatcherCtor = const internal::MatcherDescriptor *;
@@ -66,6 +101,14 @@
 public:
   Registry() = delete;
 
+  static bool isBuilderMatcher(MatcherCtor Ctor);
+
+  static ASTNodeKind nodeMatcherType(MatcherCtor);
+
+  static internal::MatcherDescriptorPtr
+  buildMatcherCtor(MatcherCtor, SourceRange NameRange,
+                   ArrayRef<ParserValue> Args, Diagnostics *Error);
+
   /// Look up a matcher in the registry by name,
   ///
   /// \return An opaque value which may be used to refer to the matcher
Index: clang/include/clang/ASTMatchers/Dynamic/Parser.h
===================================================================
--- clang/include/clang/ASTMatchers/Dynamic/Parser.h
+++ clang/include/clang/ASTMatchers/Dynamic/Parser.h
@@ -100,6 +100,14 @@
     virtual llvm::Optional<MatcherCtor>
     lookupMatcherCtor(StringRef MatcherName) = 0;
 
+    virtual bool isBuilderMatcher(MatcherCtor) const = 0;
+
+    virtual ASTNodeKind nodeMatcherType(MatcherCtor) const = 0;
+
+    virtual internal::MatcherDescriptorPtr
+    buildMatcherCtor(MatcherCtor, SourceRange NameRange,
+                     ArrayRef<ParserValue> Args, Diagnostics *Error) const = 0;
+
     /// Compute the list of completion types for \p Context.
     ///
     /// Each element of \p Context represents a matcher invocation, going from
@@ -142,6 +150,15 @@
     std::vector<ArgKind> getAcceptedCompletionTypes(
         llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> Context) override;
 
+    bool isBuilderMatcher(MatcherCtor Ctor) const override;
+
+    ASTNodeKind nodeMatcherType(MatcherCtor) const override;
+
+    internal::MatcherDescriptorPtr
+    buildMatcherCtor(MatcherCtor, SourceRange NameRange,
+                     ArrayRef<ParserValue> Args,
+                     Diagnostics *Error) const override;
+
     std::vector<MatcherCompletion>
     getMatcherCompletions(llvm::ArrayRef<ArgKind> AcceptedTypes) override;
   };
@@ -233,7 +250,11 @@
 
   bool parseBindID(std::string &BindID);
   bool parseExpressionImpl(VariantValue *Value);
+  bool parseMatcherBuilder(MatcherCtor Ctor, const TokenInfo &NameToken,
+                           const TokenInfo &OpenToken, VariantValue *Value);
   bool parseMatcherExpressionImpl(const TokenInfo &NameToken,
+                                  const TokenInfo &OpenToken,
+                                  llvm::Optional<MatcherCtor> Ctor,
                                   VariantValue *Value);
   bool parseIdentifierPrefixImpl(VariantValue *Value);
 
Index: clang/include/clang/ASTMatchers/Dynamic/Diagnostics.h
===================================================================
--- clang/include/clang/ASTMatchers/Dynamic/Diagnostics.h
+++ clang/include/clang/ASTMatchers/Dynamic/Diagnostics.h
@@ -66,6 +66,8 @@
     ET_RegistryAmbiguousOverload = 5,
     ET_RegistryValueNotFound = 6,
     ET_RegistryUnknownEnumWithReplace = 7,
+    ET_RegistryNonNodeMatcher = 8,
+    ET_RegistryMatcherNoWithSupport = 9,
 
     ET_ParserStringError = 100,
     ET_ParserNoOpenParen = 101,
@@ -77,7 +79,9 @@
     ET_ParserMalformedBindExpr = 107,
     ET_ParserTrailingCode = 108,
     ET_ParserNumberError = 109,
-    ET_ParserOverloadedType = 110
+    ET_ParserOverloadedType = 110,
+    ET_ParserMalformedChainedExpr = 111,
+    ET_ParserFailedToBuildMatcher = 112
   };
 
   /// Helper stream class.
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to