Author: MaheshRavishankar Date: 2021-01-11T09:22:35-08:00 New Revision: 9c0dc0b2c1cc973056237bdd80dbba749941ea63
URL: https://github.com/llvm/llvm-project/commit/9c0dc0b2c1cc973056237bdd80dbba749941ea63
DIFF: https://github.com/llvm/llvm-project/commit/9c0dc0b2c1cc973056237bdd80dbba749941ea63.diff

LOG:
[mlir][Linalg] Fold init_tensor -> linalg.tensor_reshape.

Reshaping an init_tensor can be folded to an init_tensor op of the final type.

Differential Revision: https://reviews.llvm.org/D93773

Added:

Modified:
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed:

################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8a97753e1a5c..8732065bb042 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -718,9 +718,141 @@ struct ReplaceDimOfInitTensorOp : public OpRewritePattern<DimOp> {
 };
 } // namespace
 
+/// Returns an `init_tensor` op with the shape of the result of the collapsing
+/// `reshapeOp`. Each collapsed dimension is the product of the source
+/// dimensions in its reassociation group; dynamic source sizes are read with
+/// `dim` and multiplied together (and with the static factor, if not 1).
+static Value getCollapsedInitTensor(OpBuilder &builder,
+                                    TensorReshapeOp reshapeOp) {
+  Location loc = reshapeOp.getLoc();
+  SmallVector<Value, 4> dynamicShapes;
+  SmallVector<int64_t, 4> staticShapes;
+  auto reassociation = reshapeOp.getReassociationMaps();
+  Value src = reshapeOp.src();
+  RankedTensorType srcType = reshapeOp.getSrcType();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  for (auto map : reassociation) {
+    Value linearizedDynamicDim = nullptr;
+    int64_t linearizedStaticDim = 1;
+    for (unsigned i : llvm::map_range(map.getResults(), [](AffineExpr e) {
+           return e.cast<AffineDimExpr>().getPosition();
+         })) {
+      if (ShapedType::isDynamic(srcShape[i])) {
+        Value shapeVal = builder.create<DimOp>(loc, src, i);
+        if (linearizedDynamicDim) {
+          linearizedDynamicDim =
+              builder.create<MulIOp>(loc, linearizedDynamicDim, shapeVal);
+        } else {
+          linearizedDynamicDim = shapeVal;
+        }
+      } else {
+        linearizedStaticDim *= srcShape[i];
+      }
+    }
+    if (linearizedDynamicDim) {
+      if (linearizedStaticDim != 1) {
+        linearizedDynamicDim = builder.create<MulIOp>(
+            loc, linearizedDynamicDim,
+            builder.create<ConstantIndexOp>(loc, linearizedStaticDim));
+      }
+      dynamicShapes.push_back(linearizedDynamicDim);
+      staticShapes.push_back(ShapedType::kDynamicSize);
+    } else {
+      staticShapes.push_back(linearizedStaticDim);
+    }
+  }
+  return builder.create<InitTensorOp>(loc, dynamicShapes, staticShapes,
+                                      srcType.getElementType());
+}
+
+/// Returns an `init_tensor` op with the shape of the result of the expanding
+/// `reshapeOp`, or a null Value if the expansion cannot be expressed (more
+/// than one dynamic expanded dimension in a group, or a dynamic expanded
+/// dimension whose source dimension is static). The dynamic expanded size is
+/// the source size divided by the product of the static sizes in its group.
+/// All checks are done before any IR is created so that a failed match leaves
+/// the IR untouched, as required of rewrite patterns.
+static Value getExpandedInitTensor(OpBuilder &builder,
+                                   TensorReshapeOp reshapeOp) {
+  SmallVector<Value, 4> dynamicShapes;
+  SmallVector<int64_t, 4> staticShapes;
+  auto reassociation = reshapeOp.getReassociationMaps();
+  Value src = reshapeOp.src();
+  RankedTensorType srcType = reshapeOp.getSrcType();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  ArrayRef<int64_t> dstShape = reshapeOp.getResultType().getShape();
+  Location loc = reshapeOp.getLoc();
+  // For each reassociation group with a dynamic expanded dimension, records
+  // the source dimension and the product of the group's static sizes.
+  SmallVector<std::pair<unsigned, int64_t>, 4> dynamicGroups;
+  for (auto map : llvm::enumerate(reassociation)) {
+    int64_t linearizedStaticDim = 1;
+    bool hasDynamic = false;
+    for (unsigned i :
+         llvm::map_range(map.value().getResults(), [](AffineExpr e) {
+           return e.cast<AffineDimExpr>().getPosition();
+         })) {
+      if (ShapedType::isDynamic(dstShape[i])) {
+        // Only one of the dimensions of the expanded shape should be dynamic.
+        if (hasDynamic)
+          return nullptr;
+        hasDynamic = true;
+        staticShapes.push_back(ShapedType::kDynamicSize);
+        continue;
+      }
+      staticShapes.push_back(dstShape[i]);
+      linearizedStaticDim *= dstShape[i];
+    }
+    if (hasDynamic) {
+      // If the expanded dimensions have a dynamic shape, the src shape must be
+      // dynamic as well.
+      if (!ShapedType::isDynamic(srcShape[map.index()]))
+        return nullptr;
+      dynamicGroups.emplace_back(map.index(), linearizedStaticDim);
+    }
+  }
+  // The expansion is valid; only now materialize the dynamic sizes.
+  for (const auto &group : dynamicGroups) {
+    Value dynamicDim = builder.create<DimOp>(loc, src, group.first);
+    if (group.second != 1) {
+      dynamicDim = builder.create<UnsignedDivIOp>(
+          loc, dynamicDim, builder.create<ConstantIndexOp>(loc, group.second));
+    }
+    dynamicShapes.push_back(dynamicDim);
+  }
+  return builder.create<InitTensorOp>(loc, dynamicShapes, staticShapes,
+                                      srcType.getElementType());
+}
+
+namespace {
+/// Folds a `linalg.tensor_reshape` whose source is a `linalg.init_tensor` into
+/// a single `linalg.init_tensor` of the reshaped type.
+struct FoldWithTensorReshapeOp : public OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    if (!reshapeOp.src().getDefiningOp<InitTensorOp>())
+      return failure();
+    // A result of lower rank than the source means the reshape collapses
+    // dimensions; otherwise it expands them.
+    bool isCollapsing =
+        reshapeOp.getResultType().getRank() < reshapeOp.getSrcType().getRank();
+    Value foldedInitTensor = isCollapsing
+                                 ? getCollapsedInitTensor(rewriter, reshapeOp)
+                                 : getExpandedInitTensor(rewriter, reshapeOp);
+    if (!foldedInitTensor)
+      return failure();
+    rewriter.replaceOp(reshapeOp, foldedInitTensor);
+    return success();
+  }
+};
+} // namespace
+
 void InitTensorOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context);
+  results.insert<FoldWithTensorReshapeOp, ReplaceDimOfInitTensorOp,
+                 ReplaceStaticShapeDims>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1043,23 +1175,23 @@ static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
   ArrayRef<int64_t> expandedShape = expandedType.getShape();
   unsigned expandedDimStart = 0;
   for (auto map : llvm::enumerate(op.getReassociationMaps())) {
-    Optional<int64_t> dynamicDims;
+    Optional<int64_t> dynamicShape;
     int64_t linearizedStaticShape = 1;
     for (auto dim : llvm::enumerate(expandedShape.slice(
              expandedDimStart, map.value().getNumResults()))) {
       if (ShapedType::isDynamic(dim.value())) {
-        if (isExpandingReshape && dynamicDims) {
+        if (isExpandingReshape && dynamicShape) {
           return op->emitOpError("invalid to have a single dimension (")
                  << map.index() << ") expanded into multiple dynamic dims ("
-                 << expandedDimStart + dynamicDims.getValue() << ","
+                 << expandedDimStart + dynamicShape.getValue() << ","
                  << expandedDimStart + dim.index() << ")";
         }
-        dynamicDims = dim.index();
+        dynamicShape = dim.index();
       } else {
         linearizedStaticShape *= dim.value();
       }
     }
-    if (dynamicDims) {
+    if (dynamicShape) {
       if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
         return op->emitOpError("expected dimension ")
                << map.index()
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 4102a1326b96..6b806c801341 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -413,3 +413,39 @@ func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
 // CHECK-SAME: [[ARG_0:%.*]]: tensor<?xf32>, [[ARG_1:%.*]]: tensor<?xf32>)
 // CHECK: dim [[ARG_0]]
 // CHECK: dim [[ARG_1]]
+
+// -----
+
+func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
+  %0 = linalg.init_tensor [6, 5, %arg0] : tensor<6x5x?xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
+    tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
+  return %1 : tensor<2x3x5x4x?x7xf32>
+}
+// CHECK: func @init_tensor_reshape_expansion
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[C28:.+]] = constant 28 : index
+// CHECK: %[[T0:.+]] = divi_unsigned %[[ARG0]], %[[C28]]
+// CHECK: %[[T1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
+// CHECK: return %[[T1]]
+
+// -----
+
+func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
+  %0 = linalg.init_tensor [2, 3, 5, 4, %arg0, 7] : tensor<2x3x5x4x?x7xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
+    tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
+  return %1 : tensor<6x5x?xf32>
+}
+// CHECK: func @init_tensor_reshape_collapse
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[C28:.+]] = constant 28 : index
+// CHECK: %[[T0:.+]] = muli %[[ARG0]], %[[C28]]
+// CHECK: %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
+// CHECK: return %[[T1]]

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits