Author: Christian Sigg Date: 2020-12-08T16:44:51+01:00 New Revision: 02c9050155dff70497b3423ae95ed7d2ab7675a8
URL: https://github.com/llvm/llvm-project/commit/02c9050155dff70497b3423ae95ed7d2ab7675a8 DIFF: https://github.com/llvm/llvm-project/commit/02c9050155dff70497b3423ae95ed7d2ab7675a8.diff LOG: [mlir] Tighten access of RewritePattern methods. In RewritePattern, only expose `matchAndRewrite` as a public function. `match` can be protected (it cannot be private, because we want to call it from an override of `matchAndRewrite`). `rewrite` can be private. For classes deriving from RewritePattern, all three functions can be private. Side note: I didn't understand the need for the `using RewritePattern::matchAndRewrite` declarations in derived classes, and started poking around. Those `using` declarations are gone now, and I think the result is (only very slightly) cleaner. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D92670 Added: Modified: mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h mlir/include/mlir/IR/PatternMatch.h mlir/include/mlir/Transforms/DialectConversion.h Removed: ################################################################################ diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h index bf41f29749de..5b605c165be6 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -571,11 +571,9 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { &typeConverter.getContext(), typeConverter, benefit) {} - /// Wrappers around the RewritePattern methods that pass the derived op type. - void rewrite(Operation *op, ArrayRef<Value> operands, - ConversionPatternRewriter &rewriter) const final { - rewrite(cast<SourceOp>(op), operands, rewriter); - } +private: + /// Wrappers around the ConversionPattern methods that pass the derived op + /// type. 
LogicalResult match(Operation *op) const final { return match(cast<SourceOp>(op)); } @@ -584,6 +582,10 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { ConversionPatternRewriter &rewriter) const final { return matchAndRewrite(cast<SourceOp>(op), operands, rewriter); } + void rewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const final { + rewrite(cast<SourceOp>(op), operands, rewriter); + } /// Rewrite and Match methods that operate on the SourceOp type. These must be /// overridden by the derived pattern class. @@ -603,10 +605,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { } return failure(); } - -private: - using ConvertToLLVMPattern::match; - using ConvertToLLVMPattern::matchAndRewrite; }; namespace LLVM { @@ -636,6 +634,7 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> { using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; using Super = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>; +private: /// Converts the type of the result to an LLVM type, pass operands as is, /// preserve attributes. LogicalResult @@ -655,6 +654,7 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> { using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>; +private: LogicalResult matchAndRewrite(SourceOp op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 0bbb2216ee7b..1739cfa4a80c 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -156,17 +156,6 @@ class RewritePattern : public Pattern { public: virtual ~RewritePattern() {} - /// Rewrite the IR rooted at the specified operation with the result of - /// this pattern, generating any new operations with the specified - /// builder. 
If an unexpected error is encountered (an internal - /// compiler error), it is emitted through the normal MLIR diagnostic - /// hooks and the IR is left in a valid state. - virtual void rewrite(Operation *op, PatternRewriter &rewriter) const; - - /// Attempt to match against code rooted at the specified operation, - /// which is the same operation code as getRootKind(). - virtual LogicalResult match(Operation *op) const; - /// Attempt to match against code rooted at the specified operation, /// which is the same operation code as getRootKind(). If successful, this /// function will automatically perform the rewrite. @@ -183,6 +172,18 @@ class RewritePattern : public Pattern { /// Inherit the base constructors from `Pattern`. using Pattern::Pattern; + /// Attempt to match against code rooted at the specified operation, + /// which is the same operation code as getRootKind(). + virtual LogicalResult match(Operation *op) const; + +private: + /// Rewrite the IR rooted at the specified operation with the result of + /// this pattern, generating any new operations with the specified + /// builder. If an unexpected error is encountered (an internal + /// compiler error), it is emitted through the normal MLIR diagnostic + /// hooks and the IR is left in a valid state. + virtual void rewrite(Operation *op, PatternRewriter &rewriter) const; + /// An anchor for the virtual table. virtual void anchor(); }; @@ -190,12 +191,15 @@ class RewritePattern : public Pattern { /// OpRewritePattern is a wrapper around RewritePattern that allows for /// matching and rewriting against an instance of a derived operation class as /// opposed to a raw Operation. -template <typename SourceOp> struct OpRewritePattern : public RewritePattern { +template <typename SourceOp> +class OpRewritePattern : public RewritePattern { +public: /// Patterns must specify the root operation name they match against, and can /// also specify the benefit of the pattern matching. 
OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1) : RewritePattern(SourceOp::getOperationName(), benefit, context) {} +private: /// Wrappers around the RewritePattern methods that pass the derived op type. void rewrite(Operation *op, PatternRewriter &rewriter) const final { rewrite(cast<SourceOp>(op), rewriter); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index e02cf8fe4c0a..ecbb653f7ed9 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -313,6 +313,30 @@ class TypeConverter { /// patterns of this type can only be used with the 'apply*' methods below. class ConversionPattern : public RewritePattern { public: + /// Return the type converter held by this pattern, or nullptr if the pattern + /// does not require type conversion. + TypeConverter *getTypeConverter() const { return typeConverter; } + +protected: + /// See `RewritePattern::RewritePattern` for information on the other + /// available constructors. + using RewritePattern::RewritePattern; + /// Construct a conversion pattern that matches an operation with the given + /// root name. This constructor allows for providing a type converter to use + /// within the pattern. + ConversionPattern(StringRef rootName, PatternBenefit benefit, + TypeConverter &typeConverter, MLIRContext *ctx) + : RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {} + /// Construct a conversion pattern that matches any operation type. This + /// constructor allows for providing a type converter to use within the + /// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" + /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should + /// always be supplied here. 
+ ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter, + MatchAnyOpTypeTag tag) + : RewritePattern(benefit, tag), typeConverter(&typeConverter) {} + +private: /// Hook for derived classes to implement rewriting. `op` is the (first) /// operation matched by the pattern, `operands` is a list of the rewritten /// operand values that are passed to `op`, `rewriter` can be used to emit the @@ -323,6 +347,10 @@ class ConversionPattern : public RewritePattern { llvm_unreachable("unimplemented rewrite"); } + void rewrite(Operation *op, PatternRewriter &rewriter) const final { + llvm_unreachable("never called"); + } + /// Hook for derived classes to implement combined matching and rewriting. virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, @@ -337,42 +365,17 @@ class ConversionPattern : public RewritePattern { LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final; - /// Return the type converter held by this pattern, or nullptr if the pattern - /// does not require type conversion. - TypeConverter *getTypeConverter() const { return typeConverter; } - -protected: - /// See `RewritePattern::RewritePattern` for information on the other - /// available constructors. - using RewritePattern::RewritePattern; - /// Construct a conversion pattern that matches an operation with the given - /// root name. This constructor allows for providing a type converter to use - /// within the pattern. - ConversionPattern(StringRef rootName, PatternBenefit benefit, - TypeConverter &typeConverter, MLIRContext *ctx) - : RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {} - /// Construct a conversion pattern that matches any operation type. This - /// constructor allows for providing a type converter to use within the - /// pattern. 
`MatchAnyOpTypeTag` is just a tag to ensure that the "match any" - /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should - /// always be supplied here. - ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter, - MatchAnyOpTypeTag tag) - : RewritePattern(benefit, tag), typeConverter(&typeConverter) {} - protected: /// An optional type converter for use by this pattern. TypeConverter *typeConverter = nullptr; - -private: - using RewritePattern::rewrite; }; /// OpConversionPattern is a wrapper around ConversionPattern that allows for /// matching and rewriting against an instance of a derived operation class as /// opposed to a raw Operation. template <typename SourceOp> -struct OpConversionPattern : public ConversionPattern { +class OpConversionPattern : public ConversionPattern { +public: OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) : ConversionPattern(SourceOp::getOperationName(), benefit, context) {} OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, @@ -380,6 +383,7 @@ struct OpConversionPattern : public ConversionPattern { : ConversionPattern(SourceOp::getOperationName(), benefit, typeConverter, context) {} +private: /// Wrappers around the ConversionPattern methods that pass the derived op /// type. void rewrite(Operation *op, ArrayRef<Value> operands, @@ -409,9 +413,6 @@ struct OpConversionPattern : public ConversionPattern { rewrite(op, operands, rewriter); return success(); } - -private: - using ConversionPattern::matchAndRewrite; }; /// Add a pattern to the given pattern list to convert the signature of a FuncOp _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits