ymandel updated this revision to Diff 199465.
ymandel edited the summary of this revision.
ymandel added a comment.

Response to comments.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D61335/new/

https://reviews.llvm.org/D61335

Files:
  clang/include/clang/Tooling/Refactoring/Transformer.h
  clang/lib/Tooling/Refactoring/Transformer.cpp
  clang/unittests/Tooling/TransformerTest.cpp

Index: clang/unittests/Tooling/TransformerTest.cpp
===================================================================
--- clang/unittests/Tooling/TransformerTest.cpp
+++ clang/unittests/Tooling/TransformerTest.cpp
@@ -116,7 +116,8 @@
     };
   }
 
-  void testRule(RewriteRule Rule, StringRef Input, StringRef Expected) {
+  template <typename R>
+  void testRule(R Rule, StringRef Input, StringRef Expected) {
     Transformer T(std::move(Rule), consumer());
     T.registerMatchers(&MatchFinder);
     compareSnippets(Expected, rewrite(Input));
@@ -147,7 +148,7 @@
                                          .bind(StringExpr)),
                                   callee(cxxMethodDecl(hasName("c_str")))))),
       change<clang::Expr>("REPLACED"));
-  R.Explanation = text("Use size() method directly on string.");
+  R.Cases[0].Explanation = text("Use size() method directly on string.");
   return R;
 }
 
@@ -375,6 +376,92 @@
            Input, Expected);
 }
 
+TEST_F(TransformerTest, OrderedRuleUnrelated) {
+  StringRef Flag = "flag";
+  RewriteRule FlagRule = makeRule(
+      cxxMemberCallExpr(on(expr(hasType(cxxRecordDecl(
+                                    hasName("proto::ProtoCommandLineFlag"))))
+                               .bind(Flag)),
+                        unless(callee(cxxMethodDecl(hasName("GetProto"))))),
+      change<clang::Expr>(Flag, "PROTO"));
+
+  std::string Input = R"cc(
+    proto::ProtoCommandLineFlag flag;
+    int x = flag.foo();
+    int y = flag.GetProto().foo();
+    int f(string s) { return strlen(s.c_str()); }
+  )cc";
+  std::string Expected = R"cc(
+    proto::ProtoCommandLineFlag flag;
+    int x = PROTO.foo();
+    int y = flag.GetProto().foo();
+    int f(string s) { return REPLACED; }
+  )cc";
+
+  testRule(applyFirst({ruleStrlenSize(), FlagRule}), Input, Expected);
+}
+
+// Version of ruleStrlenSizeAny that inserts a method with a different name than
+// ruleStrlenSize, so we can tell their effect apart.
+RewriteRule ruleStrlenSizeDistinct() {
+  StringRef S;
+  return makeRule(
+      callExpr(callee(functionDecl(hasName("strlen"))),
+               hasArgument(0, cxxMemberCallExpr(
+                                  on(expr().bind(S)),
+                                  callee(cxxMethodDecl(hasName("c_str")))))),
+      change<clang::Expr>("DISTINCT"));
+}
+
+TEST_F(TransformerTest, OrderedRuleRelated) {
+  std::string Input = R"cc(
+    namespace foo {
+    struct mystring {
+      char* c_str();
+    };
+    int f(mystring s) { return strlen(s.c_str()); }
+    }  // namespace foo
+    int g(string s) { return strlen(s.c_str()); }
+  )cc";
+  std::string Expected = R"cc(
+    namespace foo {
+    struct mystring {
+      char* c_str();
+    };
+    int f(mystring s) { return DISTINCT; }
+    }  // namespace foo
+    int g(string s) { return REPLACED; }
+  )cc";
+
+  testRule(applyFirst({ruleStrlenSize(), ruleStrlenSizeDistinct()}), Input,
+           Expected);
+}
+
+// Change the order of the rules to get a different result.
+TEST_F(TransformerTest, OrderedRuleRelatedSwapped) {
+  std::string Input = R"cc(
+    namespace foo {
+    struct mystring {
+      char* c_str();
+    };
+    int f(mystring s) { return strlen(s.c_str()); }
+    }  // namespace foo
+    int g(string s) { return strlen(s.c_str()); }
+  )cc";
+  std::string Expected = R"cc(
+    namespace foo {
+    struct mystring {
+      char* c_str();
+    };
+    int f(mystring s) { return DISTINCT; }
+    }  // namespace foo
+    int g(string s) { return DISTINCT; }
+  )cc";
+
+  testRule(applyFirst({ruleStrlenSizeDistinct(), ruleStrlenSize()}), Input,
+           Expected);
+}
+
 //
 // Negative tests (where we expect no transformation to occur).
 //
Index: clang/lib/Tooling/Refactoring/Transformer.cpp
===================================================================
--- clang/lib/Tooling/Refactoring/Transformer.cpp
+++ clang/lib/Tooling/Refactoring/Transformer.cpp
@@ -28,6 +28,7 @@
 using namespace tooling;
 
 using ast_matchers::MatchFinder;
+using ast_matchers::internal::DynTypedMatcher;
 using ast_type_traits::ASTNodeKind;
 using ast_type_traits::DynTypedNode;
 using llvm::Error;
@@ -171,18 +172,113 @@
   return Transformations;
 }
 
-RewriteRule tooling::makeRule(ast_matchers::internal::DynTypedMatcher M,
+RewriteRule tooling::makeRule(DynTypedMatcher M,
                               SmallVector<ASTEdit, 1> Edits) {
+  return RewriteRule{
+      {RewriteRule::Case{std::move(M), std::move(Edits), nullptr}}};
+}
+
+// Determines whether A is a base type of B in the class hierarchy, including
+// the implicit relationship of Type and QualType.
+static bool isBaseOf(ASTNodeKind A, ASTNodeKind B) {
+  static auto TypeKind = ASTNodeKind::getFromNodeKind<Type>();
+  static auto QualKind = ASTNodeKind::getFromNodeKind<QualType>();
+  /// Mimic the implicit conversions of Matcher<>.
+  /// - From Matcher<Type> to Matcher<QualType>
+  /// - From Matcher<Base> to Matcher<Derived>
+  return (A.isSame(TypeKind) && B.isSame(QualKind)) || A.isBaseOf(B);
+}
+
+// Tries to find a common kind to which all of the cases' matchers can be
+// converted; returns a default-constructed ASTNodeKind if none is found.
+static ASTNodeKind
+findCommonKind(const SmallVectorImpl<RewriteRule::Case> &Cases) {
+  assert(!Cases.empty() && "Rule must have at least one case.");
+  ASTNodeKind JoinKind = Cases[0].Matcher.getSupportedKind();
+  // Find a (least) Kind K, for which M.canConvertTo(K) holds, for all matchers
+  // M in Cases.
+  for (const auto &Case : Cases) {
+    auto K = Case.Matcher.getSupportedKind();
+    if (isBaseOf(JoinKind, K)) {
+      JoinKind = K;
+      continue;
+    }
+    if (K.isSame(JoinKind) || isBaseOf(K, JoinKind))
+      // K is JoinKind or a base of it, so JoinKind is already the lowest.
+      continue;
+    // K and JoinKind are unrelated -- there is no least common kind.
+    return ASTNodeKind();
+  }
+  return JoinKind;
+}
+
+// Binds each case's matcher to a unique (and deterministic) tag formed by
+// appending the case's index to `TagBase`.
+static std::vector<DynTypedMatcher>
+taggedMatchers(StringRef TagBase,
+               const SmallVectorImpl<RewriteRule::Case> &Cases) {
+  std::vector<DynTypedMatcher> Matchers;
+  Matchers.reserve(Cases.size());
+  size_t count = 0;
+  for (const auto &Case : Cases) {
+    std::string Tag = (TagBase + Twine(count)).str();
+    ++count;
+    auto M = Case.Matcher.tryBind(Tag);
+    assert(M && "RewriteRule matchers should be bindable.");
+    Matchers.push_back(*std::move(M));
+  }
+  return Matchers;
+}
+
+// Concatenates the cases of `Rules`, in order, into a single rule. The actual
+// work to combine these into an ordered choice is deferred to matcher
+// registration.
+RewriteRule tooling::applyFirst(ArrayRef<RewriteRule> Rules) {
+  RewriteRule R;
+  for (auto &Rule : Rules) {
+    R.Cases.append(Rule.Cases.begin(), Rule.Cases.end());
+  }
+  return R;
+}
+
+static DynTypedMatcher joinCaseMatchers(const RewriteRule &Rule) {
+  assert(!Rule.Cases.empty() && "Rule must have at least one case.");
+  if (Rule.Cases.size() == 1)
+    return Rule.Cases[0].Matcher;
+
+  auto CommonKind = findCommonKind(Rule.Cases);
+  assert(!CommonKind.isNone() && "Cases must have compatible matchers.");
+  return DynTypedMatcher::constructVariadic(
+      DynTypedMatcher::VO_AnyOf, CommonKind, taggedMatchers("Tag", Rule.Cases));
+}
+
+DynTypedMatcher tooling::buildMatcher(const RewriteRule &Rule) {
+  DynTypedMatcher M = joinCaseMatchers(Rule);
   M.setAllowBind(true);
   // `tryBind` is guaranteed to succeed, because `AllowBind` was set to true.
-  return RewriteRule{*M.tryBind(RewriteRule::RootId), std::move(Edits),
-                     nullptr};
+  return *M.tryBind(RewriteRule::RootId);
+}
+
+// Finds the case that was "selected" -- that is, whose matcher triggered the
+// `MatchResult`.
+const RewriteRule::Case &tooling::findSelectedCase(const MatchResult &Result,
+                                                   const RewriteRule &Rule) {
+  if (Rule.Cases.size() == 1)
+    return Rule.Cases[0];
+
+  auto &NodesMap = Result.Nodes.getMap();
+  for (size_t i = 0, N = Rule.Cases.size(); i < N; ++i) {
+    std::string Tag = ("Tag" + Twine(i)).str();
+    if (NodesMap.find(Tag) != NodesMap.end())
+      return Rule.Cases[i];
+  }
+  llvm_unreachable("No tag found for this rule.");
 }
 
 constexpr llvm::StringLiteral RewriteRule::RootId;
 
 void Transformer::registerMatchers(MatchFinder *MatchFinder) {
-  MatchFinder->addDynamicMatcher(Rule.Matcher, this);
+  MatchFinder->addDynamicMatcher(buildMatcher(Rule), this);
 }
 
 void Transformer::run(const MatchResult &Result) {
@@ -197,7 +293,8 @@
       Root->second.getSourceRange().getBegin());
   assert(RootLoc.isValid() && "Invalid location for Root node of match.");
 
-  auto Transformations = translateEdits(Result, Rule.Edits);
+  auto Transformations =
+      translateEdits(Result, findSelectedCase(Result, Rule).Edits);
   if (!Transformations) {
     Consumer(Transformations.takeError());
     return;
Index: clang/include/clang/Tooling/Refactoring/Transformer.h
===================================================================
--- clang/include/clang/Tooling/Refactoring/Transformer.h
+++ clang/include/clang/Tooling/Refactoring/Transformer.h
@@ -145,8 +145,8 @@
 
 /// Description of a source-code transformation.
 //
-// A *rewrite rule* describes a transformation of source code. It has the
-// following components:
+// A *rewrite rule* describes a transformation of source code. A simple rule
+// contains each of the following components:
 //
 // * Matcher: the pattern term, expressed as clang matchers (with Transformer
 //   extensions).
@@ -156,30 +156,31 @@
 // * Explanation: explanation of the rewrite.  This will be displayed to the
 //   user, where possible; for example, in clang-tidy diagnostics.
 //
-// Rules have an additional, implicit, component: the parameters. These are
-// portions of the pattern which are left unspecified, yet named so that we can
-// reference them in the replacement term.  The structure of parameters can be
-// partially or even fully specified, in which case they serve just to identify
-// matched nodes for later reference rather than abstract over portions of the
-// AST.  However, in all cases, we refer to named portions of the pattern as
-// parameters.
+// However, rules can also consist of multiple (sub)rules, so the above
+// components are gathered as a `Case` and rules are defined as an ordered list
+// of cases.
 //
-// The \c Transformer class should be used to apply the rewrite rule and obtain
-// the corresponding replacements.
+// Rule cases have an additional, implicit, component: the parameters. These are
+// portions of the pattern which are left unspecified, yet bound in the pattern
+// so that we can reference them in the edits.
+//
+// The \c Transformer class can be used to apply the rewrite rule and obtain the
+// corresponding replacements.
 struct RewriteRule {
-  // `Matcher` describes the context of this rule. It should always be bound to
-  // at least `RootId`.
-  ast_matchers::internal::DynTypedMatcher Matcher;
-  SmallVector<ASTEdit, 1> Edits;
-  TextGenerator Explanation;
+  struct Case {
+    ast_matchers::internal::DynTypedMatcher Matcher;
+    SmallVector<ASTEdit, 1> Edits;
+    TextGenerator Explanation;
+  };
+  // We expect RewriteRules will most commonly include only one case.
+  SmallVector<Case, 1> Cases;
 
   // Id used as the default target of each match. The node described by the
   // matcher is should always be bound to this id.
   static constexpr llvm::StringLiteral RootId = "___root___";
 };
 
-/// Convenience function for constructing a \c RewriteRule. Takes care of
-/// binding the matcher to RootId.
+/// Convenience function for constructing a simple \c RewriteRule.
 RewriteRule makeRule(ast_matchers::internal::DynTypedMatcher M,
                      SmallVector<ASTEdit, 1> Edits);
 
@@ -191,12 +192,66 @@
   return makeRule(std::move(M), std::move(Edits));
 }
 
+/// Joins multiple rules into a single rule that applies the first rule in
+/// `Rules` whose pattern matches a given node.
+///
+/// N.B. Due to a technical restriction, all of the rules must use the same kind
+/// of matcher (that is, share a base class in the AST hierarchy). The plan is
+/// to remove this restriction in the future.
+//
+// For example, consider a type `T` with a deterministic serialization function,
+// `serialize()`. For performance reasons, we would like to make it
+// non-deterministic.  Therefore, we want to drop the expectation that
+// `a.serialize() = b.serialize() iff a = b` (although we'll maintain
+// `deserialize(a.serialize()) = a`).
+//
+// We have three cases to consider (for some equality function, `eq`):
+// ```
+// eq(a.serialize(), b.serialize()) --> eq(a,b)
+// eq(a, b.serialize())             --> eq(deserialize(a), b)
+// eq(a.serialize(), b)             --> eq(a, deserialize(b))
+// ```
+//
+// Ordered rules allow us to specify each independently:
+// ```
+// auto eq_fun = functionDecl(...);
+// auto method_call = cxxMemberCallExpr(...);
+//
+// auto two_calls = callExpr(callee(eq_fun), hasArgument(0, method_call),
+//                           hasArgument(1, method_call));
+// auto left_call =
+//     callExpr(callee(eq_fun), callExpr(hasArgument(0, method_call)));
+// auto right_call =
+//     callExpr(callee(eq_fun), callExpr(hasArgument(1, method_call)));
+//
+// RewriteRule R = applyFirst({makeRule(two_calls, two_calls_action),
+//                             makeRule(left_call, left_call_action),
+//                             makeRule(right_call, right_call_action)});
+// ```
+// More generally, anywhere you'd use anyOf(m1.bind("m1"), m2.bind("m2")) and
+// then dispatch on those tags in your code to decide what to do, we'll lift
+// that behavior to the rule level, so you can write
+// `applyFirst({makeRule(m1, action1), makeRule(m2, action2), ...});`
+RewriteRule applyFirst(ArrayRef<RewriteRule> Rules);
+
 // Define this overload of `change` here because RewriteRule::RootId is not in
 // scope at the declaration point above.
 template <typename T> ASTEdit change(TextGenerator Replacement) {
   return change<T>(RewriteRule::RootId, NodePart::Node, std::move(Replacement));
 }
 
+/// The following three functions are a low-level part of the API. We expose
+/// them for use in implementing the fixtures that interpret RewriteRule, like
+/// Transformer and TransformerTidy, or for more advanced users.
+
+/// Builds the matcher needed for registration.
+ast_matchers::internal::DynTypedMatcher buildMatcher(const RewriteRule &Rule);
+
+/// Returns the \c Case of \c Rule that was selected in the match result.
+const RewriteRule::Case &
+findSelectedCase(const ast_matchers::MatchFinder::MatchResult &Result,
+                 const RewriteRule &Rule);
+
 /// A source "transformation," represented by a character range in the source to
 /// be replaced and a corresponding replacement string.
 struct Transformation {
@@ -206,9 +261,7 @@
 
 /// Attempts to translate `Edits`, which are in terms of AST nodes bound in the
 /// match `Result`, into Transformations, which are in terms of the source code
-/// text.  This function is a low-level part of the API, provided to support
-/// interpretation of a \c RewriteRule in a tool, like \c Transformer, rather
-/// than direct use by end users.
+/// text.
 ///
 /// Returns an empty vector if any of the edits apply to portions of the source
 /// that are ineligible for rewriting (certain interactions with macros, for
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to