Author: Rob Suderman Date: 2021-03-17T17:28:18-07:00 New Revision: f4bb076a4419767cf35a17e3c08f392505a5acd2
URL: https://github.com/llvm/llvm-project/commit/f4bb076a4419767cf35a17e3c08f392505a5acd2 DIFF: https://github.com/llvm/llvm-project/commit/f4bb076a4419767cf35a17e3c08f392505a5acd2.diff LOG: [mlir][tosa] Add tosa.slice to std.subtensor lowering Lowering to subtensor is added for tosa.slice operator. Differential Revision: https://reviews.llvm.org/D98825 Added: Modified: mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir Removed: ################################################################################ diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp index 21a8da291aee..6e5411dd5ecb 100644 --- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp +++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp @@ -32,9 +32,28 @@ class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> { } }; +class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> { +public: + using OpRewritePattern<tosa::SliceOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, + PatternRewriter &rewriter) const final { + Value input = sliceOp.input(); + SmallVector<int64_t> strides; + strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1); + + rewriter.replaceOpWithNewOp<SubTensorOp>( + sliceOp, sliceOp.getType(), input, ValueRange({}), ValueRange({}), + ValueRange({}), sliceOp.start(), sliceOp.size(), + rewriter.getI64ArrayAttr(strides)); + + return success(); + } +}; + } // namespace void mlir::tosa::populateTosaToStandardConversionPatterns( MLIRContext *context, OwningRewritePatternList *patterns) { - patterns->insert<ConstOpConverter>(context); + patterns->insert<ConstOpConverter, SliceOpConverter>(context); } diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp index 225855e78bda..78a0e65da81b 100644 --- a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp +++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp @@ -32,7 +32,8 @@ struct TosaToStandard : public TosaToStandardBase<TosaToStandard> { OwningRewritePatternList patterns; ConversionTarget target(getContext()); target.addIllegalOp<tosa::ConstOp>(); - target.addLegalOp<ConstantOp>(); + target.addIllegalOp<tosa::SliceOp>(); + target.addLegalDialect<StandardOpsDialect>(); auto *op = getOperation(); mlir::tosa::populateTosaToStandardConversionPatterns(op->getContext(), diff --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir index 86304dcba862..94925aec15c7 100644 --- a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir +++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir @@ -8,3 +8,11 @@ func @const_test() -> (tensor<i32>) { // CHECK: return [[C3]] return %0 : tensor<i32> } + +// ---- + +func @slice(%arg0: tensor<6xf32>) ->() { + // CHECK: [[SLICE:%.+]] = subtensor %arg0[2] [1] [1] + %0 = "tosa.slice"(%arg0) {start = [2], size = [1]} : (tensor<6xf32>) -> (tensor<1xf32>) + return +} _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits