https://github.com/AviadCo updated https://github.com/llvm/llvm-project/pull/67808
>From a760c42f0c8b75361f822e1efcbdae30151b2180 Mon Sep 17 00:00:00 2001 From: Aviad Cohen <aviadcoh...@gmail.com> Date: Fri, 29 Sep 2023 15:32:18 +0300 Subject: [PATCH] [mlir][memref]: Add expand/collapse rewrite pattern to MemRef::CopyOp This pattern is useful to adjust the memref copy ranks. --- .../MemRef/Transforms/ExpandCollapseCopyOps.h | 45 ++++ .../Dialect/MemRef/Transforms/CMakeLists.txt | 1 + .../Transforms/ExpandCollapseCopyOps.cpp | 238 ++++++++++++++++++ .../Transforms/expand-collapse-copy-ops.mlir | 141 +++++++++++ mlir/test/lib/Dialect/MemRef/CMakeLists.txt | 1 + .../MemRef/TestExpandCollapseCopyOps.cpp | 66 +++++ mlir/tools/mlir-opt/mlir-opt.cpp | 2 + 7 files changed, 494 insertions(+) create mode 100644 mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h create mode 100644 mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp create mode 100644 mlir/test/Transforms/expand-collapse-copy-ops.mlir create mode 100644 mlir/test/lib/Dialect/MemRef/TestExpandCollapseCopyOps.cpp diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h b/mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h new file mode 100644 index 000000000000000..27a69ab93e42c74 --- /dev/null +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h @@ -0,0 +1,45 @@ +//===-- ExpandCollapseCopyOps.h - Expand/Collapse MemRef copy ranks --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Patterns for expand collapse MemRef copies. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_EXPAND_COLLAPSE_COPY_OPS_H_ +#define MLIR_DIALECT_MEMREF_TRANSFORMS_EXPAND_COLLAPSE_COPY_OPS_H_ + +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +#include <functional> + +namespace mlir { +class MLIRContext; +class RewritePatternSet; + +namespace memref { + +typedef std::function<bool(memref::CopyOp)> ExpandCollapseFuncCB; +inline bool expandCollapseAny([[maybe_unused]] memref::CopyOp copyOp) { + return true; +} + +/// ExpandCollapseCopyOpConverter is a rewrite pattern that checks +/// if a `memref::CopyOp` should be expanded/collapsed into the [`minRank`, +/// `maxRank`] rank range. A selective callback may be provided to distinguish +/// which operations should be expanded/collapsed. +/// In some cases (e.g. the source/target are strided in every dim), +/// it will not be possible to expand/collapse the `memref::CopyOp`. + +void populateExpandCollapseCopyOpsPatterns( + RewritePatternSet &patterns, unsigned minRank = 1, unsigned maxRank = 1, + ExpandCollapseFuncCB funcCB = expandCollapseAny); + +} // namespace memref +} // namespace mlir + +#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_EXPAND_COLLAPSE_COPY_OPS_H_ diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt index b16c281c93640ea..924feca4cad3012 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms AllocationOpInterfaceImpl.cpp BufferizableOpInterfaceImpl.cpp ComposeSubView.cpp + ExpandCollapseCopyOps.cpp ExpandOps.cpp ExpandRealloc.cpp ExpandStridedMetadata.cpp diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp new file mode 100644 index 000000000000000..7905254e71e19fc --- /dev/null +++ 
b/mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp @@ -0,0 +1,238 @@ +//===- ExpandCollapseCopyOps.cpp - Expand/Collapse rank of source/target copies +//-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------------===// +// +// This file contains rewrite patterns (transformations) to expand/collapse +// MemRef copies. This is useful in architecture which have limitations on +// dimensions of the copy operation. +// +//===--------------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include <numeric> + +#define DEBUG_TYPE "expand-collapse-copy-ops" + +using namespace mlir; + +#ifndef NDEBUG +static inline std::string shape_to_string(ArrayRef<int64_t> shape); +#endif // NDEBUG + +namespace { +/// ExpandCollapseCopyOpConverter is a rewrite pattern that checks +/// if a `memref::CopyOp` should be expanded/collapsed into `minRank` +/// `maxRank` ranks. A selective callback may be provided to distinguish +/// which operations should be expanded/collapsed. +/// In some cases (i.e. the source/target are strided in each dim), +/// it will not be possible to expand/collapse the `memref::CopyOp`. 
+ +struct ExpandCollapseCopyOpConverter : public OpRewritePattern<memref::CopyOp> { +public: + using OpRewritePattern::OpRewritePattern; + + ExpandCollapseCopyOpConverter(MLIRContext *context, unsigned minRank, + unsigned maxRank, + memref::ExpandCollapseFuncCB funcCB) + : OpRewritePattern<memref::CopyOp>(context, /*benefit=*/1), + minRank(minRank), maxRank(maxRank), funcCB(funcCB) { + assert(minRank <= maxRank && "invalid ranks range"); + } + + LogicalResult matchAndRewrite(memref::CopyOp copyOp, + PatternRewriter &rewriter) const final { + MemRefType memRefType = cast<MemRefType>(copyOp.getSource().getType()); + unsigned rank = memRefType.getRank(); + + if (!funcCB(copyOp)) { + LLVM_DEBUG(llvm::dbgs() + << "Skip rewriting " << copyOp << ", filtered by funcCB\n"); + return failure(); + } else if (rank >= minRank && rank <= maxRank) { + LLVM_DEBUG(llvm::dbgs() + << "Skip rewriting " << copyOp + << ", operation does not need to expand/collapse\n"); + return failure(); + } + + if (rank > maxRank) { + return collapseCopyOpRank(copyOp, maxRank, rewriter); + } else { + assert(rank < minRank); + expandCopyOpRank(copyOp, minRank, rewriter); + // Expand is always successful. + return success(); + } + } + +private: + unsigned minRank; + unsigned maxRank; + // Accept callback to select which `memref::CopyOp` to collapse/expand. + memref::ExpandCollapseFuncCB funcCB; + + // Expand the `copyOp` source/target dims to newRank by + // adding new dims in size of `1`. + void expandCopyOpRank(memref::CopyOp copyOp, unsigned newRank, + PatternRewriter &rewriter) const; + // Collapse the `copyOp` source/target dims to newRank. + // The function tries to collapse starting from the most inner dims + // to the most outer dims. + // This function return failure if there are no dims to collapse. + LogicalResult collapseCopyOpRank(memref::CopyOp copyOp, unsigned newRank, + PatternRewriter &rewriter) const; + // Fill `collapsedShape` with a shape in size of `newRank`. 
+ // The function tries to collapse starting from the innermost dims + // to the outermost dims of `memrefToCollapse`. + // This function returns failure if there are no dims to collapse. + LogicalResult getCollapsedShape(MemRefType memrefToCollapse, unsigned newRank, + SmallVector<int64_t> &collapsedShape) const; +}; + +} // namespace + +void ExpandCollapseCopyOpConverter::expandCopyOpRank( + memref::CopyOp copyOp, unsigned newRank, PatternRewriter &rewriter) const { + MemRefType memRefType = cast<MemRefType>(copyOp.getSource().getType()); + + // New outermost dims will be 1s; the remaining dims are the same as the original shape. + auto shape = memRefType.getShape(); + SmallVector<int64_t> newShape(newRank - memRefType.getRank(), 1); + newShape.insert(newShape.end(), shape.begin(), shape.end()); + +#ifndef NDEBUG + LLVM_DEBUG(llvm::dbgs() << "Expanding shape " << shape_to_string(shape) + << " to " << shape_to_string(newShape) << "\n"); +#endif // NDEBUG + + // Expand reassociation is the same as collapse with opposing source/target + // shapes. 
+ std::optional<SmallVector<ReassociationIndices>> reassociation = + getReassociationIndicesForCollapse(newShape, shape); + assert(reassociation && "expected reassociation to be valid for expand"); + + rewriter.setInsertionPoint(copyOp); + Value expandShapeSrc = rewriter.create<memref::ExpandShapeOp>( + copyOp.getLoc(), newShape, copyOp.getSource(), *reassociation); + Value expandShapeTarget = rewriter.create<memref::ExpandShapeOp>( + copyOp.getLoc(), newShape, copyOp.getTarget(), *reassociation); + + rewriter.replaceOpWithNewOp<memref::CopyOp>(copyOp, expandShapeSrc, + expandShapeTarget); +} + +LogicalResult ExpandCollapseCopyOpConverter::collapseCopyOpRank( + memref::CopyOp copyOp, unsigned newRank, PatternRewriter &rewriter) const { + MemRefType memRefType = cast<MemRefType>(copyOp.getSource().getType()); + + auto shape = memRefType.getShape(); + SmallVector<int64_t> collapsedShape; + if (failed(getCollapsedShape(memRefType, newRank, collapsedShape))) + return failure(); + + std::optional<SmallVector<ReassociationIndices>> reassociation = + getReassociationIndicesForCollapse(shape, collapsedShape); + assert(reassociation && "expected reassociation to be valid for collapse"); + + rewriter.setInsertionPoint(copyOp); + Value collapseShapeSrc = rewriter.create<memref::CollapseShapeOp>( + copyOp.getLoc(), copyOp.getSource(), *reassociation); + Value collapseShapeTarget = rewriter.create<memref::CollapseShapeOp>( + copyOp.getLoc(), copyOp.getTarget(), *reassociation); + + rewriter.replaceOpWithNewOp<memref::CopyOp>(copyOp, collapseShapeSrc, + collapseShapeTarget); + + return success(); +} + +LogicalResult ExpandCollapseCopyOpConverter::getCollapsedShape( + MemRefType memrefToCollapse, unsigned newRank, + SmallVector<int64_t> &collapsedShape) const { + auto shape = memrefToCollapse.getShape(); + auto rank = memrefToCollapse.getRank(); + int dimsToCollapse = rank - newRank; + assert(dimsToCollapse > 0); + + // Try to find `dimsToCollapse` dims we can collapse, starting 
with the innermost + // dim to collapse. + for (int firstDimToCollapse = rank - dimsToCollapse - 1; + firstDimToCollapse >= 0; --firstDimToCollapse) { + SmallVector<int64_t> newShape; + + unsigned collapsedDims = + std::accumulate(shape.begin() + firstDimToCollapse, + shape.begin() + firstDimToCollapse + dimsToCollapse + 1, + 1, std::multiplies<unsigned>()); + + // Generate new shape in `newRank` size. All collapsed dims are folded into + // `collapsedDims`. + for (int i = 0; i < rank; ++i) { + if (i == firstDimToCollapse) + newShape.push_back(collapsedDims); + else if (i < firstDimToCollapse || + i > firstDimToCollapse + dimsToCollapse) + newShape.push_back(shape[i]); + } + assert(newShape.size() == newRank); + assert(std::accumulate(shape.begin(), shape.end(), 1, + std::multiplies<unsigned>()) == + std::accumulate(newShape.begin(), newShape.end(), 1, + std::multiplies<unsigned>())); + +#ifndef NDEBUG + LLVM_DEBUG(llvm::dbgs() + << "trying to collapse shape " << shape_to_string(shape) + << " to " << shape_to_string(newShape) << "\n"); +#endif // NDEBUG + + std::optional<SmallVector<ReassociationIndices>> reassociation = + getReassociationIndicesForCollapse(shape, newShape); + assert(reassociation && "reassociation must be valid for collapse"); + if (memref::CollapseShapeOp::isGuaranteedCollapsible(memrefToCollapse, + *reassociation)) { + collapsedShape = std::move(newShape); + return success(); + } + } + + return failure(); +} + +#ifndef NDEBUG +static inline std::string shape_to_string(ArrayRef<int64_t> shape) { + std::ostringstream shapeStream; + + for (auto dim : shape) { + shapeStream << dim << 'x'; + } + + std::string shapeStr = shapeStream.str(); + + // Remove the trailing 'x' character. 
+ if (!shapeStr.empty()) { + shapeStr.pop_back(); + } + + return shapeStr; +} +#endif // NDEBUG + +void memref::populateExpandCollapseCopyOpsPatterns( + RewritePatternSet &patterns, unsigned minRank, unsigned maxRank, + memref::ExpandCollapseFuncCB funcCB) { + patterns.add<ExpandCollapseCopyOpConverter>(patterns.getContext(), minRank, + maxRank, funcCB); +} diff --git a/mlir/test/Transforms/expand-collapse-copy-ops.mlir b/mlir/test/Transforms/expand-collapse-copy-ops.mlir new file mode 100644 index 000000000000000..b3cd187424e084b --- /dev/null +++ b/mlir/test/Transforms/expand-collapse-copy-ops.mlir @@ -0,0 +1,141 @@ +// RUN: mlir-opt -test-expand-collapse-copy-ops="minRank=2 maxRank=3" %s -split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @empty() { +// CHECK: return +// CHECK: } +func.func @empty() -> () { + return +} + +// ----- + +// CHECK-LABEL: func.func @memref_copy_to_expand( +// CHECK-SAME: %[[VAL_0:.*]]: memref<6xi32>) { +// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<6xi32> +// CHECK: %[[VAL_2:.*]] = memref.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : memref<6xi32> into memref<1x6xi32> +// CHECK: %[[VAL_3:.*]] = memref.expand_shape %[[VAL_1]] {{\[\[}}0, 1]] : memref<6xi32> into memref<1x6xi32> +// CHECK: memref.copy %[[VAL_2]], %[[VAL_3]] : memref<1x6xi32> to memref<1x6xi32> +// CHECK: return +// CHECK: } +func.func @memref_copy_to_expand(%arg0: memref<6xi32>) { + %0 = memref.alloc() : memref<6xi32> + memref.copy %arg0, %0 : memref<6xi32> to memref<6xi32> + return +} + +// ----- + +// CHECK-LABEL: func.func @memref_copy_to_collapse( +// CHECK-SAME: %[[VAL_0:.*]]: memref<1x5x24x48xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<1x5x24x48xi32>) { +// CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x5x24x48xi32> into memref<1x5x1152xi32> +// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x5x24x48xi32> into memref<1x5x1152xi32> +// CHECK: memref.copy %[[VAL_2]], 
%[[VAL_3]] : memref<1x5x1152xi32> to memref<1x5x1152xi32> +// CHECK: return +// CHECK: } +func.func @memref_copy_to_collapse(%arg0: memref<1x5x24x48xi32>, %arg1: memref<1x5x24x48xi32>) { + memref.copy %arg0, %arg1 : memref<1x5x24x48xi32> to memref<1x5x24x48xi32> + return +} + +// ----- + +// CHECK-LABEL: func.func @memref_copy_collapse_expand_in_loop( +// CHECK-SAME: %[[VAL_0:.*]]: memref<1x5x24x48xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<1x5x24x48xf32>) -> memref<1x5x24x48xf32> { +// CHECK: %[[VAL_2:.*]] = arith.constant 5760 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 16 : index +// CHECK: %[[VAL_5:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x5x24x48xf32> +// CHECK: %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32> +// CHECK: %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32> +// CHECK: %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_5]] {{\[\[}}0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32> +// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] { +// CHECK: %[[VAL_10:.*]] = memref.subview %[[VAL_6]]{{\[}}%[[VAL_9]]] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>> +// CHECK: %[[VAL_11:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_9]]] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>> +// CHECK: %[[VAL_12:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_9]]] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>> +// CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<16xf32> +// CHECK: %[[VAL_14:.*]] = memref.expand_shape %[[VAL_10]] {{\[\[}}0, 1]] : memref<16xf32, strided<[1], offset: ?>> into memref<1x16xf32, strided<[16, 1], offset: ?>> +// CHECK: %[[VAL_15:.*]] = memref.expand_shape %[[VAL_13]] {{\[\[}}0, 1]] : memref<16xf32> into memref<1x16xf32> +// CHECK: memref.copy 
%[[VAL_14]], %[[VAL_15]] : memref<1x16xf32, strided<[16, 1], offset: ?>> to memref<1x16xf32> +// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<16xf32> +// CHECK: %[[VAL_17:.*]] = memref.expand_shape %[[VAL_11]] {{\[\[}}0, 1]] : memref<16xf32, strided<[1], offset: ?>> into memref<1x16xf32, strided<[16, 1], offset: ?>> +// CHECK: %[[VAL_18:.*]] = memref.expand_shape %[[VAL_16]] {{\[\[}}0, 1]] : memref<16xf32> into memref<1x16xf32> +// CHECK: memref.copy %[[VAL_17]], %[[VAL_18]] : memref<1x16xf32, strided<[16, 1], offset: ?>> to memref<1x16xf32> +// CHECK: %[[VAL_19:.*]] = memref.alloc() : memref<16xf32> +// CHECK: linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel"], library_call = ""} ins(%[[VAL_13]], %[[VAL_16]] : memref<16xf32>, memref<16xf32>) outs(%[[VAL_19]] : memref<16xf32>) { +// CHECK: ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32): +// CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32 +// CHECK: linalg.yield %[[VAL_23]] : f32 +// CHECK: } +// CHECK: %[[VAL_24:.*]] = memref.expand_shape %[[VAL_19]] {{\[\[}}0, 1]] : memref<16xf32> into memref<1x16xf32> +// CHECK: %[[VAL_25:.*]] = memref.expand_shape %[[VAL_12]] {{\[\[}}0, 1]] : memref<16xf32, strided<[1], offset: ?>> into memref<1x16xf32, strided<[16, 1], offset: ?>> +// CHECK: memref.copy %[[VAL_24]], %[[VAL_25]] : memref<1x16xf32> to memref<1x16xf32, strided<[16, 1], offset: ?>> +// CHECK: } +// CHECK: return %[[VAL_5]] : memref<1x5x24x48xf32> +// CHECK: } +#map = affine_map<(d0) -> (d0)> +module { + func.func @memref_copy_collapse_expand_in_loop(%arg0: memref<1x5x24x48xf32>, %arg1: memref<1x5x24x48xf32>) -> memref<1x5x24x48xf32> { + %c5760 = arith.constant 5760 : index + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x5x24x48xf32> + %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32> + 
%collapse_shape_0 = memref.collapse_shape %arg1 [[0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32> + %collapse_shape_1 = memref.collapse_shape %alloc [[0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32> + scf.for %arg2 = %c0 to %c5760 step %c16 { + %subview = memref.subview %collapse_shape[%arg2] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>> + %subview_2 = memref.subview %collapse_shape_0[%arg2] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>> + %subview_3 = memref.subview %collapse_shape_1[%arg2] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>> + %alloc_4 = memref.alloc() : memref<16xf32> + memref.copy %subview, %alloc_4 : memref<16xf32, strided<[1], offset: ?>> to memref<16xf32> + %alloc_5 = memref.alloc() : memref<16xf32> + memref.copy %subview_2, %alloc_5 : memref<16xf32, strided<[1], offset: ?>> to memref<16xf32> + %alloc_6 = memref.alloc() : memref<16xf32> + linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel"], library_call = ""} ins(%alloc_4, %alloc_5 : memref<16xf32>, memref<16xf32>) outs(%alloc_6 : memref<16xf32>) { + ^bb0(%in: f32, %in_7: f32, %out: f32): + %0 = arith.addf %in, %in_7 : f32 + linalg.yield %0 : f32 + } + memref.copy %alloc_6, %subview_3 : memref<16xf32> to memref<16xf32, strided<[1], offset: ?>> + } + return %alloc : memref<1x5x24x48xf32> + } +} + +// ----- + +// CHECK-LABEL: func.func @memref_copy_strided_to_collapse( +// CHECK-SAME: %[[VAL_0:.*]]: memref<1x5x24x48xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<1x5x24x48xi32>) { +// CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]][0, 0, 0, 0] [1, 5, 24, 24] [1, 1, 1, 1] : memref<1x5x24x48xi32> to memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1]>> +// CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_1]][0, 0, 0, 0] [1, 5, 24, 24] [1, 1, 1, 1] : memref<1x5x24x48xi32> to memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1]>> +// CHECK: %[[VAL_4:.*]] = memref.collapse_shape 
%[[VAL_2]] {{\[\[}}0], [1, 2], [3]] : memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1]>> into memref<1x120x24xi32, strided<[5760, 48, 1]>> +// CHECK: %[[VAL_5:.*]] = memref.collapse_shape %[[VAL_3]] {{\[\[}}0], [1, 2], [3]] : memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1]>> into memref<1x120x24xi32, strided<[5760, 48, 1]>> +// CHECK: memref.copy %[[VAL_4]], %[[VAL_5]] : memref<1x120x24xi32, strided<[5760, 48, 1]>> to memref<1x120x24xi32, strided<[5760, 48, 1]>> +// CHECK: return +// CHECK: } +func.func @memref_copy_strided_to_collapse(%arg0: memref<1x5x24x48xi32>, %arg1: memref<1x5x24x48xi32>) { + %subview = memref.subview %arg0[0, 0, 0, 0] [1, 5, 24, 24] [1, 1, 1, 1] : memref<1x5x24x48xi32> to memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1], offset: 0>> + %subview0 = memref.subview %arg1[0, 0, 0, 0] [1, 5, 24, 24] [1, 1, 1, 1] : memref<1x5x24x48xi32> to memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1], offset: 0>> + memref.copy %subview, %subview0 : memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1], offset: 0>> to memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1], offset: 0>> + return +} + +// ----- + +// CHECK-LABEL: func.func @memref_copy_strided_cant_collapse( +// CHECK-SAME: %[[VAL_0:.*]]: memref<2x6x24x48xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<2x6x24x48xi32>) { +// CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]][0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1]>> +// CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_1]][0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1]>> +// CHECK: memref.copy %[[VAL_2]], %[[VAL_3]] : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1]>> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1]>> +// CHECK: return +// CHECK: } +func.func @memref_copy_strided_cant_collapse(%arg0: memref<2x6x24x48xi32>, %arg1: memref<2x6x24x48xi32>) { + %subview = memref.subview %arg0[0, 0, 0, 0] [1, 3, 12, 24] 
[1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>> + %subview0 = memref.subview %arg1[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>> + memref.copy %subview, %subview0 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>> + return +} diff --git a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt index 0498de3eb93178b..d665620b42a57b8 100644 --- a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt +++ b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_library(MLIRMemRefTestPasses TestComposeSubView.cpp TestEmulateNarrowType.cpp TestMultiBuffer.cpp + TestExpandCollapseCopyOps.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Dialect/MemRef/TestExpandCollapseCopyOps.cpp b/mlir/test/lib/Dialect/MemRef/TestExpandCollapseCopyOps.cpp new file mode 100644 index 000000000000000..446a70b538cdc9d --- /dev/null +++ b/mlir/test/lib/Dialect/MemRef/TestExpandCollapseCopyOps.cpp @@ -0,0 +1,66 @@ +//===- TestExpandCollapseCopyOps.cpp.cpp - Test expand collapse copies ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to test the expand collapse copies patterns. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace { +struct TestExpandCollapseCopyOpsPass + : public PassWrapper<TestExpandCollapseCopyOpsPass, OperationPass<>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandCollapseCopyOpsPass) + + TestExpandCollapseCopyOpsPass() = default; + TestExpandCollapseCopyOpsPass(const TestExpandCollapseCopyOpsPass &pass) + : PassWrapper(pass) {} + + StringRef getArgument() const final { + return "test-expand-collapse-copy-ops"; + } + StringRef getDescription() const final { + return "Test expand collapse copies"; + } + void runOnOperation() override; + void getDependentDialects(DialectRegistry &registry) const override; + + Option<unsigned> minRank{ + *this, "minRank", + llvm::cl::desc("Minimum rank allowed for a MemRef Copy."), + llvm::cl::init(2)}; + Option<unsigned> maxRank{ + *this, "maxRank", + llvm::cl::desc("Maximum rank allowed for a MemRef Copy."), + llvm::cl::init(3)}; +}; + +void TestExpandCollapseCopyOpsPass::getDependentDialects( + DialectRegistry &registry) const { + registry.insert<memref::MemRefDialect>(); +} + +void TestExpandCollapseCopyOpsPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + memref::populateExpandCollapseCopyOpsPatterns(patterns, minRank, maxRank); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} +} // namespace + +namespace mlir { +namespace test { +void registerTestExpandCollapseCopyOps() { + PassRegistration<TestExpandCollapseCopyOpsPass>(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index b7647d7de78a10e..ec2ba8838fd68d2 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ 
b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -92,6 +92,7 @@ void registerTestEmulateNarrowTypePass(); void registerTestExpandMathPass(); void registerTestFooAnalysisPass(); void registerTestComposeSubView(); +void registerTestExpandCollapseCopyOps(); void registerTestMultiBuffering(); void registerTestIntRangeInference(); void registerTestIRVisitorsPass(); @@ -214,6 +215,7 @@ void registerTestPasses() { mlir::test::registerTestExpandMathPass(); mlir::test::registerTestFooAnalysisPass(); mlir::test::registerTestComposeSubView(); + mlir::test::registerTestExpandCollapseCopyOps(); mlir::test::registerTestMultiBuffering(); mlir::test::registerTestIntRangeInference(); mlir::test::registerTestIRVisitorsPass(); _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits