steveire created this revision.
steveire added a reviewer: aaron.ballman.
steveire requested review of this revision.
Herald added a project: clang.
Herald added a subscriber: cfe-commits.

Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D94881

Files:
  clang/include/clang/ASTMatchers/Dynamic/Parser.h
  clang/lib/ASTMatchers/Dynamic/Marshallers.h
  clang/lib/ASTMatchers/Dynamic/Parser.cpp

Index: clang/lib/ASTMatchers/Dynamic/Parser.cpp
===================================================================
--- clang/lib/ASTMatchers/Dynamic/Parser.cpp
+++ clang/lib/ASTMatchers/Dynamic/Parser.cpp
@@ -20,6 +20,9 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/ManagedStatic.h"
+
+#include "Marshallers.h"
+
 #include <algorithm>
 #include <cassert>
 #include <cerrno>
@@ -52,6 +55,7 @@
 
   /// Some known identifiers.
   static const char* const ID_Bind;
+  static const char *const ID_With; // Chain call on builder matchers.
 
   TokenInfo() = default;
 
@@ -62,6 +66,7 @@
 };
 
 const char* const Parser::TokenInfo::ID_Bind = "bind";
+const char *const Parser::TokenInfo::ID_With = "with"; // .with() chain call.
 
 /// Simple tokenizer for the parser.
 class Parser::CodeTokenizer {
@@ -366,7 +371,14 @@
       }
 
       std::string BindID;
-      if (!parseBindID(BindID))
+      Tokenizer->consumeNextToken(); // consume the period.
+      const TokenInfo BindToken = Tokenizer->consumeNextToken();
+      if (BindToken.Kind == TokenInfo::TK_CodeCompletion) {
+        // A named value is an already-built matcher: only .bind() can follow.
+        addCompletion(BindToken, MatcherCompletion("bind(\"", "bind", 1));
+        return false;
+      }
+      if (!parseBindID(BindToken, BindID))
         return false;
 
       assert(NamedValue.isMatcher());
@@ -405,31 +417,34 @@
 
   Tokenizer->SkipNewlines();
 
+  assert(NameToken.Kind == TokenInfo::TK_Ident);
+  const TokenInfo OpenToken = Tokenizer->consumeNextToken();
+  if (OpenToken.Kind != TokenInfo::TK_OpenParen) {
+    Error->addError(OpenToken.Range, Error->ET_ParserNoOpenParen)
+        << OpenToken.Text;
+    return false;
+  }
+
+  llvm::Optional<MatcherCtor> Ctor = S->lookupMatcherCtor(NameToken.Text);
+
   // Parse as a matcher expression.
-  return parseMatcherExpressionImpl(NameToken, Value);
+  return parseMatcherExpressionImpl(NameToken, OpenToken, Ctor, Value);
 }
 
-bool Parser::parseBindID(std::string &BindID) {
+bool Parser::parseBindID(TokenInfo BindToken, std::string &BindID) {
   // Parse .bind("foo")
-  assert(Tokenizer->peekNextToken().Kind == TokenInfo::TK_Period);
-  Tokenizer->consumeNextToken(); // consume the period.
-  const TokenInfo BindToken = Tokenizer->consumeNextToken();
-  if (BindToken.Kind == TokenInfo::TK_CodeCompletion) {
-    addCompletion(BindToken, MatcherCompletion("bind(\"", "bind", 1));
-    return false;
-  }
-
+  // BindToken is whatever followed the '.', so it is validated below.
   const TokenInfo OpenToken = Tokenizer->consumeNextToken();
   const TokenInfo IDToken = Tokenizer->consumeNextTokenIgnoreNewlines();
   const TokenInfo CloseToken = Tokenizer->consumeNextTokenIgnoreNewlines();
 
   // TODO: We could use different error codes for each/some to be more
   //       explicit about the syntax error.
   if (BindToken.Kind != TokenInfo::TK_Ident ||
       BindToken.Text != TokenInfo::ID_Bind) {
     Error->addError(BindToken.Range, Error->ET_ParserMalformedBindExpr);
     return false;
   }
   if (OpenToken.Kind != TokenInfo::TK_OpenParen) {
     Error->addError(OpenToken.Range, Error->ET_ParserMalformedBindExpr);
     return false;
@@ -446,28 +453,148 @@
   return true;
 }
 
+/// Parse the argument list of a builder matcher (one whose descriptor reports
+/// isBuilderMatcher()), then an optional .bind("id") or .with(...) chain call.
+/// Each argument names a node matcher and is passed on as that node kind.
+/// \return \c true on success, in which case \c Value holds the matcher.
+bool Parser::parseMatcherBuilder(MatcherCtor Ctor, const TokenInfo &NameToken,
+                                 const TokenInfo &OpenToken,
+                                 VariantValue *Value) {
+  std::vector<ParserValue> Args;
+  TokenInfo EndToken;
+
+  Tokenizer->SkipNewlines();
+
+  {
+    ScopedContextEntry SCE(this, Ctor);
+
+    while (Tokenizer->nextTokenKind() != TokenInfo::TK_Eof) {
+      if (Tokenizer->nextTokenKind() == TokenInfo::TK_CloseParen) {
+        // End of args.
+        EndToken = Tokenizer->consumeNextToken();
+        break;
+      }
+      if (!Args.empty()) {
+        // We must find a , token to continue.
+        const TokenInfo CommaToken = Tokenizer->consumeNextToken();
+        if (CommaToken.Kind != TokenInfo::TK_Comma) {
+          Error->addError(CommaToken.Range, Error->ET_ParserNoComma)
+              << CommaToken.Text;
+          return false;
+        }
+      }
+
+      Diagnostics::Context Ctx(Diagnostics::Context::MatcherArg, Error,
+                               NameToken.Text, NameToken.Range,
+                               Args.size() + 1);
+      ParserValue ArgValue;
+      Tokenizer->SkipNewlines();
+      const TokenInfo NodeMatcherToken = Tokenizer->consumeNextToken();
+      ArgValue.Text = NodeMatcherToken.Text;
+      ArgValue.Range = NodeMatcherToken.Range;
+
+      // The argument must name a registered matcher...
+      auto MappedMatcher = S->lookupMatcherCtor(ArgValue.Text);
+      if (!MappedMatcher || !*MappedMatcher) {
+        Error->addError(NodeMatcherToken.Range,
+                        Error->ET_RegistryMatcherNotFound)
+            << NodeMatcherToken.Text;
+        return false;
+      }
+
+      // ...and that matcher must be a node matcher, i.e. have a node kind.
+      const ASTNodeKind NodeKind = (*MappedMatcher)->nodeMatcherType();
+      if (NodeKind.isNone()) {
+        Error->addError(NodeMatcherToken.Range, Error->ET_RegistryWrongArgType)
+            << (Args.size() + 1) << "node matcher" << NodeMatcherToken.Text;
+        return false;
+      }
+
+      ArgValue.Value = NodeKind;
+
+      Tokenizer->SkipNewlines();
+      Args.push_back(ArgValue);
+
+      SCE.nextArg();
+    }
+  }
+
+  if (EndToken.Kind == TokenInfo::TK_Eof) {
+    Error->addError(OpenToken.Range, Error->ET_ParserNoCloseParen);
+    return false;
+  }
+
+  auto BuiltCtor = Ctor->buildMatcherCtor(NameToken.Range, Args, Error);
+  if (!BuiltCtor) {
+    // NOTE(review): presumably buildMatcherCtor() has already reported the
+    // failure through Error; confirm no diagnostic is lost here.
+    return false;
+  }
+
+  std::string BindID;
+  if (Tokenizer->peekNextToken().Kind == TokenInfo::TK_Period) {
+    Tokenizer->consumeNextToken(); // consume the period.
+    const TokenInfo ChainCallToken = Tokenizer->consumeNextToken();
+    if (ChainCallToken.Kind == TokenInfo::TK_CodeCompletion) {
+      addCompletion(ChainCallToken, MatcherCompletion("bind(\"", "bind", 1));
+      // .with(...) is parsed like a matcher call: don't pre-insert a quote.
+      addCompletion(ChainCallToken, MatcherCompletion("with(", "with", 1));
+      return false;
+    }
+    if (ChainCallToken.Kind != TokenInfo::TK_Ident ||
+        (ChainCallToken.Text != TokenInfo::ID_Bind &&
+         ChainCallToken.Text != TokenInfo::ID_With)) {
+      // TODO: Change error diagnostic to also account for with
+      Error->addError(ChainCallToken.Range, Error->ET_ParserMalformedBindExpr);
+      return false;
+    }
+    if (ChainCallToken.Text == TokenInfo::ID_With) {
+      // .with(...) is parsed as a call of the matcher built above.
+      Tokenizer->SkipNewlines();
+      const TokenInfo WithOpenToken = Tokenizer->consumeNextToken();
+      if (WithOpenToken.Kind != TokenInfo::TK_OpenParen) {
+        Error->addError(WithOpenToken.Range, Error->ET_ParserNoOpenParen)
+            << WithOpenToken.Text;
+        return false;
+      }
+      return parseMatcherExpressionImpl(NameToken, WithOpenToken,
+                                        BuiltCtor.get(), Value);
+    }
+    if (!parseBindID(ChainCallToken, BindID))
+      return false;
+  }
+
+  Diagnostics::Context Ctx(Diagnostics::Context::ConstructMatcher, Error,
+                           NameToken.Text, NameToken.Range);
+  SourceRange MatcherRange = NameToken.Range;
+  MatcherRange.End = EndToken.Range.End;
+  VariantMatcher Result = S->actOnMatcherExpression(
+      BuiltCtor.get(), MatcherRange, BindID, {}, Error);
+  if (Result.isNull())
+    return false;
+
+  *Value = Result;
+  return true;
+}
+
 /// Parse and validate a matcher expression.
 /// \return \c true on success, in which case \c Value has the matcher parsed.
 ///   If the input is malformed, or some argument has an error, it
 ///   returns \c false.
 bool Parser::parseMatcherExpressionImpl(const TokenInfo &NameToken,
+                                        const TokenInfo &OpenToken,
+                                        llvm::Optional<MatcherCtor> Ctor,
                                         VariantValue *Value) {
-  assert(NameToken.Kind == TokenInfo::TK_Ident);
-  const TokenInfo OpenToken = Tokenizer->consumeNextToken();
-  if (OpenToken.Kind != TokenInfo::TK_OpenParen) {
-    Error->addError(OpenToken.Range, Error->ET_ParserNoOpenParen)
-        << OpenToken.Text;
-    return false;
-  }
-
-  llvm::Optional<MatcherCtor> Ctor = S->lookupMatcherCtor(NameToken.Text);
-
   if (!Ctor) {
     Error->addError(NameToken.Range, Error->ET_RegistryMatcherNotFound)
         << NameToken.Text;
     // Do not return here. We need to continue to give completion suggestions.
   }
 
+  if (Ctor && *Ctor && (*Ctor)->isBuilderMatcher()) {
+    return parseMatcherBuilder(*Ctor, NameToken, OpenToken, Value);
+  }
+
   std::vector<ParserValue> Args;
   TokenInfo EndToken;
 
@@ -516,7 +635,15 @@
 
   std::string BindID;
   if (Tokenizer->peekNextToken().Kind == TokenInfo::TK_Period) {
-    if (!parseBindID(BindID))
+    Tokenizer->consumeNextToken(); // consume the period.
+    const TokenInfo BindToken = Tokenizer->consumeNextToken();
+    if (BindToken.Kind == TokenInfo::TK_CodeCompletion) {
+      // Only builder matchers support .with(); regular matchers take .bind().
+      addCompletion(BindToken, MatcherCompletion("bind(\"", "bind", 1));
+      return false;
+    }
+
+    if (!parseBindID(BindToken, BindID))
       return false;
   }
 
Index: clang/lib/ASTMatchers/Dynamic/Marshallers.h
===================================================================
--- clang/lib/ASTMatchers/Dynamic/Marshallers.h
+++ clang/lib/ASTMatchers/Dynamic/Marshallers.h
@@ -309,6 +309,23 @@
                                 ArrayRef<ParserValue> Args,
                                 Diagnostics *Error) const = 0;
 
+  /// Returns the node kind this descriptor matches when it is a node
+  /// matcher, or a null \c ASTNodeKind otherwise (the default).
+  virtual ASTNodeKind nodeMatcherType() const { return ASTNodeKind(); }
+
+  /// Returns whether the parser should treat this as a builder matcher and
+  /// call \c buildMatcherCtor() with its arguments. Defaults to \c false.
+  virtual bool isBuilderMatcher() const { return false; }
+
+  /// For builder matchers, returns a new descriptor built from \p Args; the
+  /// parser constructs the final matcher from it. A null result is treated
+  /// as failure. The default implementation returns null.
+  virtual std::unique_ptr<MatcherDescriptor>
+  buildMatcherCtor(SourceRange NameRange, ArrayRef<ParserValue> Args,
+                   Diagnostics *Error) const {
+    return {};
+  }
+
   /// Returns whether the matcher is variadic. Variadic matchers can take any
   /// number of arguments, but they must be of the same type.
   virtual bool isVariadic() const = 0;
@@ -610,6 +627,8 @@
     }
   }
 
+  ASTNodeKind nodeMatcherType() const override { return DerivedKind; }
+
 private:
   const ASTNodeKind DerivedKind;
 };
Index: clang/include/clang/ASTMatchers/Dynamic/Parser.h
===================================================================
--- clang/include/clang/ASTMatchers/Dynamic/Parser.h
+++ clang/include/clang/ASTMatchers/Dynamic/Parser.h
@@ -231,9 +231,19 @@
          const NamedValueMap *NamedValues,
          Diagnostics *Error);
 
-  bool parseBindID(std::string &BindID);
+  /// Parses the ("id") part of a .bind("id") call. \p BindToken is the
+  /// token that followed the '.'.
+  bool parseBindID(TokenInfo BindToken, std::string &BindID);
   bool parseExpressionImpl(VariantValue *Value);
+  /// Parses the arguments of a builder matcher and any chained .bind() or
+  /// .with() call. \p OpenToken is the already-consumed '('.
+  bool parseMatcherBuilder(MatcherCtor Ctor, const TokenInfo &NameToken,
+                           const TokenInfo &OpenToken, VariantValue *Value);
+  /// Parses a matcher call whose '(' (\p OpenToken) has already been consumed.
+  /// \p Ctor is the matcher's constructor, or None if the name is unknown.
   bool parseMatcherExpressionImpl(const TokenInfo &NameToken,
+                                  const TokenInfo &OpenToken,
+                                  llvm::Optional<MatcherCtor> Ctor,
                                   VariantValue *Value);
   bool parseIdentifierPrefixImpl(VariantValue *Value);
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to