Author: Qinkun Bao Date: 2025-06-24T11:37:27-04:00 New Revision: c0aa1f007ad7d13e7e8c7949f4d7271f870c7f58
URL: https://github.com/llvm/llvm-project/commit/c0aa1f007ad7d13e7e8c7949f4d7271f870c7f58 DIFF: https://github.com/llvm/llvm-project/commit/c0aa1f007ad7d13e7e8c7949f4d7271f870c7f58.diff LOG: Revert "[mlir] Improve mlir-query by adding matcher combinators (#141423)" This reverts commit 12611a7fc71376e88aa01e3f0bbc74517f1a1703. Added: mlir/test/mlir-query/complex-test.mlir Modified: mlir/include/mlir/Query/Matcher/Marshallers.h mlir/include/mlir/Query/Matcher/MatchFinder.h mlir/include/mlir/Query/Matcher/MatchersInternal.h mlir/include/mlir/Query/Matcher/SliceMatchers.h mlir/include/mlir/Query/Matcher/VariantValue.h mlir/lib/Query/Matcher/CMakeLists.txt mlir/lib/Query/Matcher/RegistryManager.cpp mlir/lib/Query/Matcher/VariantValue.cpp mlir/lib/Query/Query.cpp mlir/tools/mlir-query/mlir-query.cpp Removed: mlir/lib/Query/Matcher/MatchersInternal.cpp mlir/test/mlir-query/backward-slice-union.mlir mlir/test/mlir-query/forward-slice-by-predicate.mlir mlir/test/mlir-query/logical-operator-test.mlir mlir/test/mlir-query/slice-function-extraction.mlir ################################################################################ diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h index 5fe6965f32efb..012bf7b9ec4a9 100644 --- a/mlir/include/mlir/Query/Matcher/Marshallers.h +++ b/mlir/include/mlir/Query/Matcher/Marshallers.h @@ -108,9 +108,6 @@ class MatcherDescriptor { const llvm::ArrayRef<ParserValue> args, Diagnostics *error) const = 0; - // If the matcher is variadic, it can take any number of arguments. - virtual bool isVariadic() const = 0; - // Returns the number of arguments accepted by the matcher. virtual unsigned getNumArgs() const = 0; @@ -143,8 +140,6 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor { return marshaller(matcherFunc, matcherName, nameRange, args, error); } - bool isVariadic() const override { return false; } - unsigned getNumArgs() const override { return argKinds.size(); } void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override { @@ -158,54 +153,6 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor { const std::vector<ArgKind> argKinds; }; -class VariadicOperatorMatcherDescriptor : public MatcherDescriptor { -public: - using VarOp = DynMatcher::VariadicOperator; - VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount, - VarOp varOp, StringRef matcherName) - : minCount(minCount), maxCount(maxCount), varOp(varOp), - matcherName(matcherName) {} - - VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args, - Diagnostics *error) const override { - if (args.size() < minCount || maxCount < args.size()) { - addError(error, nameRange, ErrorType::RegistryWrongArgCount, - {llvm::Twine("requires between "), llvm::Twine(minCount), - llvm::Twine(" and "), llvm::Twine(maxCount), - llvm::Twine(" args, got "), llvm::Twine(args.size())}); - return VariantMatcher(); - } - - std::vector<VariantMatcher> innerArgs; - for (int64_t i = 0, e = args.size(); i != e; ++i) { - const ParserValue &arg = args[i]; - const VariantValue &value = arg.value; - if (!value.isMatcher()) { - addError(error, arg.range, ErrorType::RegistryWrongArgType, - {llvm::Twine(i + 1), llvm::Twine("matcher: "), - llvm::Twine(value.getTypeAsString())}); - return VariantMatcher(); - } - innerArgs.push_back(value.getMatcher()); - } - return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs)); - } - - bool isVariadic() const override { return true; } - - unsigned getNumArgs() const override { return 0; } - - void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override { - kinds.push_back(ArgKind(ArgKind::Matcher)); - } - -private: - const unsigned minCount; - const unsigned maxCount; - const VarOp varOp; - const StringRef matcherName; -}; - // Helper function to check if argument count matches expected count inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount, llvm::ArrayRef<ParserValue> args, @@ -277,14 +224,6 @@ makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...), reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds); } -// Variadic operator overload. -template <unsigned MinCount, unsigned MaxCount> -std::unique_ptr<MatcherDescriptor> -makeMatcherAutoMarshall(VariadicOperatorMatcherFunc<MinCount, MaxCount> func, - StringRef matcherName) { - return std::make_unique<VariadicOperatorMatcherDescriptor>( - MinCount, MaxCount, func.varOp, matcherName); -} } // namespace mlir::query::matcher::internal #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h index 6d06ca13d1344..f8abf20ef60bb 100644 --- a/mlir/include/mlir/Query/Matcher/MatchFinder.h +++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h @@ -21,9 +21,7 @@ namespace mlir::query::matcher { -/// Finds and collects matches from the IR. After construction -/// `collectMatches` can be used to traverse the IR and apply -/// matchers. +/// A class that provides utilities to find operations in the IR. class MatchFinder { public: diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h index 88109430b6feb..183b2514e109f 100644 --- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h +++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h @@ -8,11 +8,11 @@ // // Implements the base layer of the matcher framework. // -// Matchers are methods that return a Matcher which provides a -// `match(...)` method whose parameters define the context of the match. -// Support includes simple (unary) matchers as well as matcher combinators -// (anyOf, allOf, etc.) +// Matchers are methods that return a Matcher which provides a method one of the +// following methods: match(Operation *op), match(Operation *op, +// SetVector<Operation *> &matchedOps) // +// The matcher functions are defined in include/mlir/IR/Matchers.h. // This file contains the wrapper classes needed to construct matchers for // mlir-query. // @@ -25,15 +25,6 @@ #include "llvm/ADT/IntrusiveRefCntPtr.h" namespace mlir::query::matcher { -class DynMatcher; -namespace internal { - -bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps, - ArrayRef<DynMatcher> innerMatchers); -bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps, - ArrayRef<DynMatcher> innerMatchers); - -} // namespace internal // Defaults to false if T has no match() method with the signature: // match(Operation* op). @@ -93,27 +84,6 @@ class MatcherFnImpl : public MatcherInterface { MatcherFn matcherFn; }; -// VariadicMatcher takes a vector of Matchers and returns true if any Matchers -// match the given operation. -using VariadicOperatorFunction = bool (*)(Operation *op, - SetVector<Operation *> *matchedOps, - ArrayRef<DynMatcher> innerMatchers); - -template <VariadicOperatorFunction Func> -class VariadicMatcher : public MatcherInterface { -public: - VariadicMatcher(std::vector<DynMatcher> matchers) - : matchers(std::move(matchers)) {} - - bool match(Operation *op) override { return Func(op, nullptr, matchers); } - bool match(Operation *op, SetVector<Operation *> &matchedOps) override { - return Func(op, &matchedOps, matchers); - } - -private: - std::vector<DynMatcher> matchers; -}; - // Matcher wraps a MatcherInterface implementation and provides match() // methods that redirect calls to the underlying implementation. class DynMatcher { @@ -122,31 +92,6 @@ class DynMatcher { DynMatcher(MatcherInterface *implementation) : implementation(implementation) {} - // Construct from a variadic function. - enum VariadicOperator { - // Matches operations for which all provided matchers match. - AllOf, - // Matches operations for which at least one of the provided matchers - // matches. - AnyOf - }; - - static std::unique_ptr<DynMatcher> - constructVariadic(VariadicOperator Op, - std::vector<DynMatcher> innerMatchers) { - switch (Op) { - case AllOf: - return std::make_unique<DynMatcher>( - new VariadicMatcher<internal::allOfVariadicOperator>( - std::move(innerMatchers))); - case AnyOf: - return std::make_unique<DynMatcher>( - new VariadicMatcher<internal::anyOfVariadicOperator>( - std::move(innerMatchers))); - } - llvm_unreachable("Invalid Op value."); - } - template <typename MatcherFn> static std::unique_ptr<DynMatcher> constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) { @@ -168,59 +113,6 @@ class DynMatcher { std::string functionName; }; -// VariadicOperatorMatcher related types. -template <typename... Ps> -class VariadicOperatorMatcher { -public: - VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params) - : varOp(varOp), params(std::forward<Ps>(params)...) {} - - operator std::unique_ptr<DynMatcher>() const & { - return DynMatcher::constructVariadic( - varOp, getMatchers(std::index_sequence_for<Ps...>())); - } - - operator std::unique_ptr<DynMatcher>() && { - return DynMatcher::constructVariadic( - varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>())); - } - -private: - // Helper method to unpack the tuple into a vector. - template <std::size_t... Is> - std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & { - return {DynMatcher(std::get<Is>(params))...}; - } - - template <std::size_t... Is> - std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && { - return {DynMatcher(std::get<Is>(std::move(params)))...}; - } - - const DynMatcher::VariadicOperator varOp; - std::tuple<Ps...> params; -}; - -// Overloaded function object to generate VariadicOperatorMatcher objects from -// arbitrary matchers. -template <unsigned MinCount, unsigned MaxCount> -struct VariadicOperatorMatcherFunc { - DynMatcher::VariadicOperator varOp; - - template <typename... Ms> - VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const { - static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount, - "invalid number of parameters for variadic matcher"); - return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...); - } -}; - -namespace internal { -const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()> - anyOf = {DynMatcher::AnyOf}; -const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()> - allOf = {DynMatcher::AllOf}; -} // namespace internal } // namespace mlir::query::matcher #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H diff --git a/mlir/include/mlir/Query/Matcher/SliceMatchers.h b/mlir/include/mlir/Query/Matcher/SliceMatchers.h index 7181648f06f89..441205b3a9615 100644 --- a/mlir/include/mlir/Query/Matcher/SliceMatchers.h +++ b/mlir/include/mlir/Query/Matcher/SliceMatchers.h @@ -6,8 +6,7 @@ // //===----------------------------------------------------------------------===// // -// This file defines slicing-analysis matchers that extend and abstract the -// core implementations from `SliceAnalysis.h`. +// This file provides matchers for MLIRQuery that peform slicing analysis // //===----------------------------------------------------------------------===// @@ -17,9 +16,9 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/IR/Operation.h" -/// Computes the backward-slice of all transitive defs reachable from `rootOp`, -/// if `innerMatcher` matches. The traversal stops once the desired depth level -/// is reached. +/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h. +/// Additionally, it limits the slice computation to a certain depth level using +/// a custom filter. /// /// Example: starting from node 9, assuming the matcher /// computes the slice for the first two depth levels: @@ -120,77 +119,6 @@ bool BackwardSliceMatcher<Matcher>::matches( : backwardSlice.size() >= 1; } -/// Computes the backward-slice of all transitive defs reachable from `rootOp`, -/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches. -template <typename BaseMatcher, typename Filter> -class PredicateBackwardSliceMatcher { -public: - PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher, - bool inclusive, bool omitBlockArguments, - bool omitUsesFromAbove) - : innerMatcher(std::move(innerMatcher)), - filterMatcher(std::move(filterMatcher)), inclusive(inclusive), - omitBlockArguments(omitBlockArguments), - omitUsesFromAbove(omitUsesFromAbove) {} - - bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) { - backwardSlice.clear(); - BackwardSliceOptions options; - options.inclusive = inclusive; - options.omitUsesFromAbove = omitUsesFromAbove; - options.omitBlockArguments = omitBlockArguments; - if (innerMatcher.match(rootOp)) { - options.filter = [&](Operation *subOp) { - return !filterMatcher.match(subOp); - }; - LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options); - assert(result.succeeded() && "expected backward slice to succeed"); - (void)result; - return options.inclusive ? backwardSlice.size() > 1 - : backwardSlice.size() >= 1; - } - return false; - } - -private: - BaseMatcher innerMatcher; - Filter filterMatcher; - bool inclusive; - bool omitBlockArguments; - bool omitUsesFromAbove; -}; - -/// Computes the forward-slice of all users reachable from `rootOp`, -/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches. -template <typename BaseMatcher, typename Filter> -class PredicateForwardSliceMatcher { -public: - PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher, - bool inclusive) - : innerMatcher(std::move(innerMatcher)), - filterMatcher(std::move(filterMatcher)), inclusive(inclusive) {} - - bool match(Operation *rootOp, SetVector<Operation *> &forwardSlice) { - forwardSlice.clear(); - ForwardSliceOptions options; - options.inclusive = inclusive; - if (innerMatcher.match(rootOp)) { - options.filter = [&](Operation *subOp) { - return !filterMatcher.match(subOp); - }; - getForwardSlice(rootOp, &forwardSlice, options); - return options.inclusive ? forwardSlice.size() > 1 - : forwardSlice.size() >= 1; - } - return false; - } - -private: - BaseMatcher innerMatcher; - Filter filterMatcher; - bool inclusive; -}; - /// Matches transitive defs of a top-level operation up to N levels. template <typename Matcher> inline BackwardSliceMatcher<Matcher> @@ -202,7 +130,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive, omitUsesFromAbove); } -/// Matches all transitive defs of a top-level operation up to N levels. +/// Matches all transitive defs of a top-level operation up to N levels template <typename Matcher> inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher, int64_t maxDepth) { @@ -211,28 +139,6 @@ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher, false, false); } -/// Matches all transitive defs of a top-level operation and stops where -/// `filterMatcher` rejects. -template <typename BaseMatcher, typename Filter> -inline PredicateBackwardSliceMatcher<BaseMatcher, Filter> -m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher, - bool inclusive, bool omitBlockArguments, - bool omitUsesFromAbove) { - return PredicateBackwardSliceMatcher<BaseMatcher, Filter>( - std::move(innerMatcher), std::move(filterMatcher), inclusive, - omitBlockArguments, omitUsesFromAbove); -} - -/// Matches all users of a top-level operation and stops where -/// `filterMatcher` rejects. -template <typename BaseMatcher, typename Filter> -inline PredicateForwardSliceMatcher<BaseMatcher, Filter> -m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher, - bool inclusive) { - return PredicateForwardSliceMatcher<BaseMatcher, Filter>( - std::move(innerMatcher), std::move(filterMatcher), inclusive); -} - } // namespace mlir::query::matcher #endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h index 1a47576de1841..98c0a18e25101 100644 --- a/mlir/include/mlir/Query/Matcher/VariantValue.h +++ b/mlir/include/mlir/Query/Matcher/VariantValue.h @@ -26,12 +26,7 @@ enum class ArgKind { Boolean, Matcher, Signed, String }; // A variant matcher object to abstract simple and complex matchers into a // single object type. class VariantMatcher { - class MatcherOps { - public: - std::optional<DynMatcher> - constructVariadicOperator(DynMatcher::VariadicOperator varOp, - ArrayRef<VariantMatcher> innerMatchers) const; - }; + class MatcherOps; // Payload interface to be specialized by each matcher type. It follows a // similar interface as VariantMatcher itself. @@ -48,9 +43,6 @@ class VariantMatcher { // Clones the provided matcher. static VariantMatcher SingleMatcher(DynMatcher matcher); - static VariantMatcher - VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, - ArrayRef<VariantMatcher> args); // Makes the matcher the "null" matcher. void reset(); @@ -69,7 +61,6 @@ class VariantMatcher { : value(std::move(value)) {} class SinglePayload; - class VariadicOpPayload; std::shared_ptr<const Payload> value; }; diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt index ba202762fdfbb..629479bf7adc1 100644 --- a/mlir/lib/Query/Matcher/CMakeLists.txt +++ b/mlir/lib/Query/Matcher/CMakeLists.txt @@ -1,6 +1,5 @@ add_mlir_library(MLIRQueryMatcher MatchFinder.cpp - MatchersInternal.cpp Parser.cpp RegistryManager.cpp VariantValue.cpp diff --git a/mlir/lib/Query/Matcher/MatchersInternal.cpp b/mlir/lib/Query/Matcher/MatchersInternal.cpp deleted file mode 100644 index 01f412ade846b..0000000000000 --- a/mlir/lib/Query/Matcher/MatchersInternal.cpp +++ /dev/null @@ -1,33 +0,0 @@ -//===--- MatchersInternal.cpp----------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Query/Matcher/MatchersInternal.h" -#include "llvm/ADT/SetVector.h" - -namespace mlir::query::matcher { - -namespace internal { - -bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps, - ArrayRef<DynMatcher> innerMatchers) { - return llvm::all_of(innerMatchers, [&](const DynMatcher &matcher) { - if (matchedOps) - return matcher.match(op, *matchedOps); - return matcher.match(op); - }); -} -bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps, - ArrayRef<DynMatcher> innerMatchers) { - return llvm::any_of(innerMatchers, [&](const DynMatcher &matcher) { - if (matchedOps) - return matcher.match(op, *matchedOps); - return matcher.match(op); - }); -} -} // namespace internal -} // namespace mlir::query::matcher diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp index 08b610453b11a..4b511c5f009e7 100644 --- a/mlir/lib/Query/Matcher/RegistryManager.cpp +++ b/mlir/lib/Query/Matcher/RegistryManager.cpp @@ -64,7 +64,7 @@ std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes( unsigned argNumber = ctxEntry.second; std::vector<ArgKind> nextTypeSet; - if (ctor->isVariadic() || argNumber < ctor->getNumArgs()) + if (argNumber < ctor->getNumArgs()) ctor->getArgKinds(argNumber, nextTypeSet); typeSet.insert(nextTypeSet.begin(), nextTypeSet.end()); @@ -83,7 +83,7 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes, const internal::MatcherDescriptor &matcher = *m.getValue(); llvm::StringRef name = m.getKey(); - unsigned numArgs = matcher.isVariadic() ? 1 : matcher.getNumArgs(); + unsigned numArgs = matcher.getNumArgs(); std::vector<std::vector<ArgKind>> argKinds(numArgs); for (const ArgKind &kind : acceptedTypes) { @@ -115,9 +115,6 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes, } } - if (matcher.isVariadic()) - os << ",..."; - os << ")"; typedText += "("; diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp index 7bf4774dba830..1cb2d48f9d56f 100644 --- a/mlir/lib/Query/Matcher/VariantValue.cpp +++ b/mlir/lib/Query/Matcher/VariantValue.cpp @@ -27,64 +27,12 @@ class VariantMatcher::SinglePayload : public VariantMatcher::Payload { DynMatcher matcher; }; -class VariantMatcher::VariadicOpPayload : public VariantMatcher::Payload { -public: - VariadicOpPayload(DynMatcher::VariadicOperator varOp, - std::vector<VariantMatcher> args) - : varOp(varOp), args(std::move(args)) {} - - std::optional<DynMatcher> getDynMatcher() const override { - std::vector<DynMatcher> dynMatchers; - for (auto variantMatcher : args) { - std::optional<DynMatcher> dynMatcher = variantMatcher.getDynMatcher(); - if (dynMatcher) - dynMatchers.push_back(dynMatcher.value()); - } - auto result = DynMatcher::constructVariadic(varOp, dynMatchers); - return *result; - } - - std::string getTypeAsString() const override { - std::string inner; - llvm::interleave( - args, [&](auto const &arg) { inner += arg.getTypeAsString(); }, - [&] { inner += " & "; }); - return inner; - } - -private: - const DynMatcher::VariadicOperator varOp; - const std::vector<VariantMatcher> args; -}; - VariantMatcher::VariantMatcher() = default; VariantMatcher VariantMatcher::SingleMatcher(DynMatcher matcher) { return VariantMatcher(std::make_shared<SinglePayload>(std::move(matcher))); } -VariantMatcher -VariantMatcher::VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, - ArrayRef<VariantMatcher> args) { - return VariantMatcher( - std::make_shared<VariadicOpPayload>(varOp, std::move(args))); -} - -std::optional<DynMatcher> VariantMatcher::MatcherOps::constructVariadicOperator( - DynMatcher::VariadicOperator varOp, - ArrayRef<VariantMatcher> innerMatchers) const { - std::vector<DynMatcher> dynMatchers; - for (const auto &innerMatcher : innerMatchers) { - if (!innerMatcher.value) - return std::nullopt; - std::optional<DynMatcher> inner = innerMatcher.value->getDynMatcher(); - if (!inner) - return std::nullopt; - dynMatchers.push_back(*inner); - } - return *DynMatcher::constructVariadic(varOp, dynMatchers); -} - std::optional<DynMatcher> VariantMatcher::getDynMatcher() const { return value ? value->getDynMatcher() : std::nullopt; } diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp index 637e1f3cdef87..803284d6df86a 100644 --- a/mlir/lib/Query/Query.cpp +++ b/mlir/lib/Query/Query.cpp @@ -10,7 +10,6 @@ #include "QueryParser.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/IRMapping.h" -#include "mlir/IR/Verifier.h" #include "mlir/Query/Matcher/MatchFinder.h" #include "mlir/Query/QuerySession.h" #include "llvm/ADT/SetVector.h" @@ -69,8 +68,6 @@ static Operation *extractFunction(std::vector<Operation *> &ops, // Clone operations and build function body std::vector<Operation *> clonedOps; std::vector<Value> clonedVals; - // TODO: Handle extraction of operations with compute payloads defined via - // regions. for (Operation *slicedOp : slice) { Operation *clonedOp = clonedOps.emplace_back(builder.clone(*slicedOp, mapper)); @@ -132,8 +129,6 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const { finder.flattenMatchedOps(matches); Operation *function = extractFunction(flattenedMatches, rootOp->getContext(), functionName); - if (failed(verify(function))) - return mlir::failure(); os << "\n" << *function << "\n\n"; function->erase(); return mlir::success(); diff --git a/mlir/test/mlir-query/backward-slice-union.mlir b/mlir/test/mlir-query/complex-test.mlir similarity index 71% rename from mlir/test/mlir-query/backward-slice-union.mlir rename to mlir/test/mlir-query/complex-test.mlir index f8f88c2043749..ad96f03747a43 100644 --- a/mlir/test/mlir-query/backward-slice-union.mlir +++ b/mlir/test/mlir-query/complex-test.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-query %s -c "m anyOf(getAllDefinitions(hasOpName(\"arith.addf\"),2),getAllDefinitions(hasOpName(\"tensor.extract\"),1))" | FileCheck %s +// RUN: mlir-query %s -c "m getAllDefinitions(hasOpName(\"arith.addf\"),2)" | FileCheck %s #map = affine_map<(d0, d1) -> (d0, d1)> func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) { @@ -19,23 +19,14 @@ func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) } // CHECK: Match #1: + // CHECK: %[[LINALG:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} // CHECK-SAME: ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) - -// CHECK: {{.*}}.mlir:7:10: note: "root" binds here // CHECK: %[[ADDF1:.*]] = arith.addf %in, %in : f32 // CHECK: Match #2: -// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[LINALG]] {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32> -// CHECK: %[[C2:.*]] = arith.constant {{.*}} : index -// CHECK: {{.*}}.mlir:14:18: note: "root" binds here -// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32> - -// CHECK: Match #3: // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[LINALG]] {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32> // CHECK: %[[C2:.*]] = arith.constant {{.*}} : index // CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32> - -// CHECK: {{.*}}.mlir:15:10: note: "root" binds here // CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32 diff --git a/mlir/test/mlir-query/forward-slice-by-predicate.mlir b/mlir/test/mlir-query/forward-slice-by-predicate.mlir deleted file mode 100644 index e11378da89d9f..0000000000000 --- a/mlir/test/mlir-query/forward-slice-by-predicate.mlir +++ /dev/null @@ -1,27 +0,0 @@ -// RUN: mlir-query %s -c "m getUsersByPredicate(anyOf(hasOpName(\"memref.alloc\"),isConstantOp()),anyOf(hasOpName(\"affine.load\"), hasOpName(\"memref.dealloc\")),true)" | FileCheck %s - -func.func @slice_depth1_loop_nest_with_offsets() { - %0 = memref.alloc() : memref<100xf32> - %cst = arith.constant 7.000000e+00 : f32 - affine.for %i0 = 0 to 16 { - %a0 = affine.apply affine_map<(d0) -> (d0 + 2)>(%i0) - affine.store %cst, %0[%a0] : memref<100xf32> - } - affine.for %i1 = 4 to 8 { - %a1 = affine.apply affine_map<(d0) -> (d0 - 1)>(%i1) - %1 = affine.load %0[%a1] : memref<100xf32> - } - return -} - -// CHECK: Match #1: -// CHECK: {{.*}}.mlir:4:8: note: "root" binds here -// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<100xf32> - -// CHECK: affine.store %cst, %0[%a0] : memref<100xf32> - -// CHECK: Match #2: -// CHECK: {{.*}}.mlir:5:10: note: "root" binds here -// CHECK: %[[CST:.*]] = arith.constant 7.000000e+00 : f32 - -// CHECK: affine.store %[[CST]], %0[%a0] : memref<100xf32> diff --git a/mlir/test/mlir-query/logical-operator-test.mlir b/mlir/test/mlir-query/logical-operator-test.mlir deleted file mode 100644 index ac05428287abd..0000000000000 --- a/mlir/test/mlir-query/logical-operator-test.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: mlir-query %s -c "m allOf(hasOpName(\"memref.alloca\"), hasOpAttrName(\"alignment\"))" | FileCheck %s - -func.func @dynamic_alloca(%arg0: index, %arg1: index) -> memref<?x?xf32> { - %0 = memref.alloca(%arg0, %arg1) : memref<?x?xf32> - memref.alloca(%arg0, %arg1) {alignment = 32} : memref<?x?xf32> - return %0 : memref<?x?xf32> -} - -// CHECK: Match #1: -// CHECK: {{.*}}.mlir:5:3: note: "root" binds here -// CHECK: memref.alloca(%arg0, %arg1) {alignment = 32} : memref<?x?xf32> diff --git a/mlir/test/mlir-query/slice-function-extraction.mlir b/mlir/test/mlir-query/slice-function-extraction.mlir deleted file mode 100644 index e55d5e77c5736..0000000000000 --- a/mlir/test/mlir-query/slice-function-extraction.mlir +++ /dev/null @@ -1,29 +0,0 @@ -// RUN: mlir-query %s -c "m getDefinitionsByPredicate(hasOpName(\"memref.store\"),hasOpName(\"memref.alloc\"),true,false,false).extract(\"backward_slice\")" | FileCheck %s - -// CHECK: func.func @backward_slice(%{{.*}}: memref<10xf32>) -> (f32, index, index, f32, index, index, f32) { -// CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index -// CHECK-NEXT: %[[I0:.*]] = affine.apply affine_map<()[s0] -> (s0)>()[%[[C0]]] -// CHECK-NEXT: memref.store %[[CST0]], %{{.*}}[%[[I0]]] : memref<10xf32> -// CHECK-NEXT: %[[CST2:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-NEXT: %[[I1:.*]] = affine.apply affine_map<() -> (0)>() -// CHECK-NEXT: memref.store %[[CST2]], %{{.*}}[%[[I1]]] : memref<10xf32> -// CHECK-NEXT: %[[C1:.*]] = arith.constant 0 : index -// CHECK-NEXT: %[[LOAD:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<10xf32> -// CHECK-NEXT: memref.store %[[LOAD]], %{{.*}}[%[[C1]]] : memref<10xf32> -// CHECK-NEXT: return %[[CST0]], %[[C0]], %[[I0]], %[[CST2]], %[[I1]], %[[C1]], %[[LOAD]] : f32, index, index, f32, index, index, f32 - -func.func @slicing_memref_store_trivial() { - %0 = memref.alloc() : memref<10xf32> - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f32 - affine.for %i1 = 0 to 10 { - %1 = affine.apply affine_map<()[s0] -> (s0)>()[%c0] - memref.store %cst, %0[%1] : memref<10xf32> - %2 = memref.load %0[%c0] : memref<10xf32> - %3 = affine.apply affine_map<()[] -> (0)>()[] - memref.store %cst, %0[%3] : memref<10xf32> - memref.store %2, %0[%c0] : memref<10xf32> - } - return -} diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp index 8a17a33c61838..78c0ec97c0cdf 100644 --- a/mlir/tools/mlir-query/mlir-query.cpp +++ b/mlir/tools/mlir-query/mlir-query.cpp @@ -40,22 +40,12 @@ int main(int argc, char **argv) { query::matcher::Registry matcherRegistry; // Matchers registered in alphabetical order for consistency: - matcherRegistry.registerMatcher("allOf", query::matcher::internal::allOf); - matcherRegistry.registerMatcher("anyOf", query::matcher::internal::anyOf); - matcherRegistry.registerMatcher( - "getAllDefinitions", - query::matcher::m_GetAllDefinitions<query::matcher::DynMatcher>); matcherRegistry.registerMatcher( "getDefinitions", query::matcher::m_GetDefinitions<query::matcher::DynMatcher>); matcherRegistry.registerMatcher( - "getDefinitionsByPredicate", - query::matcher::m_GetDefinitionsByPredicate<query::matcher::DynMatcher, - query::matcher::DynMatcher>); - matcherRegistry.registerMatcher( - "getUsersByPredicate", - query::matcher::m_GetUsersByPredicate<query::matcher::DynMatcher, - query::matcher::DynMatcher>); + "getAllDefinitions", + query::matcher::m_GetAllDefinitions<query::matcher::DynMatcher>); matcherRegistry.registerMatcher("hasOpAttrName", static_cast<HasOpAttrName *>(m_Attr)); matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op)); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits