Author: MaheshRavishankar
Date: 2021-01-22T12:55:25-08:00
New Revision: 430d43e010bdd07d73c4d0d6536206d22d35a2cb

URL: https://github.com/llvm/llvm-project/commit/430d43e010bdd07d73c4d0d6536206d22d35a2cb
DIFF: https://github.com/llvm/llvm-project/commit/430d43e010bdd07d73c4d0d6536206d22d35a2cb.diff

LOG: [mlir][Linalg] Disable fusion of tensor_reshape op by expansion when unit-dims are involved

Fusion of generic/indexed_generic operations with tensor_reshape by
expansion, when the latter just adds/removes unit-dimensions, is
disabled since it only adds unit-trip-count loops.

Differential Revision: https://reviews.llvm.org/D94626

Added:


Modified:
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
    mlir/test/Dialect/Linalg/reshape_fusion.mlir

Removed:


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index d041df86d169..5d68328acc7e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -72,6 +72,15 @@ void populateFoldReshapeOpsByExpansionPatterns(
 void populateFoldReshapeOpsByLinearizationPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns);
 
+/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
+/// producer (consumer) generic/indexed_generic operation by linearizing the
+/// indexing map used to access the source (target) of the reshape operation in
+/// the generic/indexed_generic operation. The patterns are applied only when
+/// the tensor reshape involved is collapsing (introducing) unit-extent
+/// dimensions.
+void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns);
+
 /// Patterns for fusing linalg operation on tensors.
 void populateLinalgTensorOpsFusionPatterns(MLIRContext *context,
                                            OwningRewritePatternList &patterns);

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 8d09d58b9d7a..3c7b2223ee49 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -497,6 +497,7 @@ void mlir::populateLinalgFoldUnitExtentDimsPatterns(
                   ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
   patterns.insert<FoldReshapeOpWithUnitExtent>(context);
+  populateFoldUnitDimsReshapeOpsByLinearizationPatterns(context, patterns);
 }
 
 namespace {

diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 670d456ad2f2..0c5b8486824f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -302,9 +302,18 @@ static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
     assert(!collapsedDims.empty());
     unsigned startDim =
         collapsedDims.front().cast<AffineDimExpr>().getPosition();
-    AffineExpr linearizedExpr = makeCanonicalStridedLayoutExpr(
-        sourceShape.slice(startDim, collapsedDims.size()),
-        sourceExprs.slice(startDim, collapsedDims.size()), context);
+    SmallVector<int64_t, 4> sizes;
+    SmallVector<AffineExpr, 4> dimExprs;
+    for (auto en :
+         llvm::zip(sourceShape.slice(startDim, collapsedDims.size()),
+                   sourceExprs.slice(startDim, collapsedDims.size()))) {
+      if (std::get<0>(en) == 1)
+        continue;
+      sizes.push_back(std::get<0>(en));
+      dimExprs.push_back(std::get<1>(en));
+    }
+    AffineExpr linearizedExpr =
+        makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
     resultExprs.push_back(linearizedExpr);
   }
   return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
@@ -349,6 +358,23 @@ static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
   return nullptr;
 }
 
+/// Check if the reshape operation is only expansion into/collapsing of
+/// unit-dimension.
+static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
+                                   ArrayRef<AffineMap> reassociation) {
+  for (auto &map : reassociation) {
+    unsigned numUnitDims = 0;
+    for (AffineExpr expr : map.getResults()) {
+      unsigned position = expr.cast<AffineDimExpr>().getPosition();
+      if (expandedShape[position] == 1)
+        numUnitDims++;
+    }
+    if (numUnitDims != map.getNumResults() - 1)
+      return false;
+  }
+  return true;
+}
+
 /// Conditions for folding a generic/indexed-generic operation with a reshape op
 /// by expanding the iteration space dimensionality for tensor operations. These
 /// are preconditions assumed by `foldReshapeByDimExpansion` which implements
@@ -776,7 +802,7 @@ namespace {
 ///   %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
 ///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
 ///        -> tensor<?x?x4x?xf32>
-template <typename LinalgOpTy>
+template <typename LinalgOpTy, bool foldUnitDimReshapesOnly>
 struct FoldProducerReshapeOpByLinearization
     : public OpRewritePattern<LinalgOpTy> {
   using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
 
@@ -792,7 +818,10 @@ struct FoldProducerReshapeOpByLinearization
       if (!reshapeOp ||
           !isTensorReshapeOpFoldableByLinearization(
               reshapeOp, linalgOp.getInputIndexingMap(operand.index()),
-              /*asProducer =*/true))
+              /*asProducer =*/true) ||
+          (foldUnitDimReshapesOnly &&
+           !isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
+                                   reshapeOp.getReassociationMaps())))
         continue;
 
       // Compute the fused operands list,
@@ -858,7 +887,9 @@ struct FoldWithProducerReshapeOpByExpansion
       //   - All constraints of fusing with reshape by expansion are met.
       if (reshapeOp.getSrcType().getRank() <
               reshapeOp.getResultType().getRank() ||
-          !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()))
+          !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
+          isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
+                                 reshapeOp.getReassociationMaps()))
         continue;
 
       Optional<SmallVector<Value, 1>> replacementValues =
@@ -877,6 +908,7 @@
 
 /// Pattern to fold tensor_reshape op with its producer. The corresponding index
 /// map in the consumer needs to be modified to linearize the folded dimension.
+template <bool foldUnitDimReshapesOnly>
 struct FoldConsumerReshapeOpByLinearization
     : public OpRewritePattern<TensorReshapeOp> {
   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
@@ -888,7 +920,11 @@
         !isa<GenericOp, IndexedGenericOp>(producer.getOperation()) ||
         !producer.hasTensorSemantics() || producer.getNumOutputs() != 1 ||
         !isTensorReshapeOpFoldableByLinearization(
-            reshapeOp, producer.getOutputIndexingMap(0), /*asProducer =*/false))
+            reshapeOp, producer.getOutputIndexingMap(0),
+            /*asProducer =*/false) ||
+        (foldUnitDimReshapesOnly &&
+         !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
+                                 reshapeOp.getReassociationMaps())))
       return failure();
     // The indexing_maps for the operands of the fused operation are same as
     // those for the operands of the producer.
@@ -949,7 +985,10 @@ struct FoldReshapeWithGenericOpByExpansion
       return failure();
     LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
     if (!producer || producer.getNumOutputs() != 1 ||
-        !isFusableWithReshapeByDimExpansion(producer, producer.getNumInputs()))
+        !isFusableWithReshapeByDimExpansion(producer,
+                                            producer.getNumInputs()) ||
+        isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
+                               reshapeOp.getReassociationMaps()))
       return failure();
     Optional<SmallVector<Value, 1>> replacementValues =
         fuseWithReshapeByExpansion(producer, reshapeOp, producer.getNumInputs(),
@@ -1098,9 +1137,16 @@ struct FoldReshapeOpsByLinearizationPass
 
 void mlir::populateFoldReshapeOpsByLinearizationPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns) {
-  patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp>,
-                  FoldProducerReshapeOpByLinearization<IndexedGenericOp>,
-                  FoldConsumerReshapeOpByLinearization>(context);
+  patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, false>,
+                  FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>,
+                  FoldConsumerReshapeOpByLinearization<false>>(context);
+}
+
+void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
+  patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, true>,
+                  FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
+                  FoldConsumerReshapeOpByLinearization<true>>(context);
 }
 
 void mlir::populateFoldReshapeOpsByExpansionPatterns(

diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 17b8bda967b1..d40a91667500 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -331,3 +331,26 @@ func @fold_reshape(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
      ] : tensor<2x1x1xf32> into tensor<2x1xf32>
   return %1 : tensor<2x1xf32>
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0, d1, d2) -> (d2)>
+func @fold_unit_dim_tensor_reshape_op(%arg0 : tensor<5xf32>) -> tensor<2x5xf32>
+{
+  %1 = linalg.init_tensor [1, 2, 5] : tensor<1x2x5xf32>
+  %2 = linalg.generic {indexing_maps = [#map1, #map0],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%arg0 : tensor<5xf32>) outs(%1 : tensor<1x2x5xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
+      linalg.yield %arg1 : f32
+    } -> tensor<1x2x5xf32>
+  %3 = linalg.tensor_reshape %2 [#map3, #map4]
+    : tensor<1x2x5xf32> into tensor<2x5xf32>
+  return %3 : tensor<2x5xf32>
+}
+// CHECK-LABEL: func @fold_unit_dim_tensor_reshape_op
+//       CHECK:   %[[RESULT:.+]] = linalg.generic
+//       CHECK:   return %[[RESULT]]

diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 447917548c5c..50269e36751b 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -188,42 +188,6 @@ func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
 
 // -----
 
-func @scalar_reshape(
-  %arg0 : tensor<1x10xf32>, %arg1 : tensor<1xf32>) -> tensor<1x10xf32>
-{
-  %0 = linalg.tensor_reshape %arg1 [] : tensor<1xf32> into tensor<f32>
-  %1 = linalg.init_tensor [10] : tensor<10xf32>
-  %2 = linalg.generic
-    {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
-     iterator_types = ["parallel"]}
-    ins(%0 : tensor<f32>)
-    outs(%1 : tensor<10xf32>) {
-  ^bb0(%arg2: f32, %s: f32):  // no predecessors
-    linalg.yield %arg2 : f32
-  } -> tensor<10xf32>
-  %3 = linalg.tensor_reshape %2 [affine_map<(d0, d1) -> (d0, d1)>]
-    : tensor<10xf32> into tensor<1x10xf32>
-  return %3 : tensor<1x10xf32>
-}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> ()>
-//     CHECK: func @scalar_reshape
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10xf32>
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1xf32>
-//      CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG1]] []
-// CHECK-SAME:   tensor<1xf32> into tensor<f32>
-//      CHECK: %[[T1:.+]] = linalg.init_tensor [10]
-//      CHECK: %[[T2:.+]] = linalg.tensor_reshape %[[T1]] [#[[MAP0]]]
-//      CHECK: %[[T3:.+]] = linalg.generic
-// CHECK-SAME:   indexing_maps = [#[[MAP1]], #[[MAP0]]]
-// CHECK-SAME:   iterator_types = ["parallel", "parallel"]
-// CHECK-SAME:   ins(%[[T0]] : tensor<f32>)
-// CHECK-SAME:   outs(%[[T2]] : tensor<1x10xf32>)
-//      CHECK: return %[[T3]] : tensor<1x10xf32>
-
-// -----
-
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
 func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
@@ -336,7 +300,7 @@ func @reshape_as_consumer_permutation
       %5 = addi %3, %4 : i32
       %6 = index_cast %arg2 : index to i32
       %7 = addi %5, %6 : i32
-      linalg.yield %7 : i32
+      linalg.yield %7 : i32
   } -> tensor<6x4x210xi32>
   %d = linalg.tensor_reshape %c
          [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
@@ -493,3 +457,77 @@ func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
 //  CHECK-SAME:   ins(%[[T0]], %[[T1]] : tensor<?x4x5x?xf32>, tensor<?x4x5x?xf32>)
 //  CHECK-SAME:   outs(%[[T2]] : tensor<?x?x4x5xf32>)
 //       CHECK: return %[[T3]] : tensor<?x?x4x5xf32>
+
+// -----
+
+func @unit_dim_reshape_expansion(%arg0 : tensor<1x5xf32>) -> tensor<5x5xf32> {
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1) -> (d0, d1)>] : tensor<1x5xf32> into tensor<5xf32>
+  %1 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
+  %2 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                      affine_map<(d0, d1) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%0 : tensor<5xf32>) outs(%1 : tensor<5x5xf32>) {
+    ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+    } -> tensor<5x5xf32>
+  return %2 : tensor<5x5xf32>
+}
+//      CHECK: func @unit_dim_reshape_expansion
+//  CHECK-DAG:   linalg.tensor_reshape
+//  CHECK-DAG:   linalg.init_tensor
+//      CHECK:   linalg.generic
+
+// -----
+
+func @unit_dim_reshape_collapse(%arg0 : tensor<5xf32>) -> tensor<5x1x5xf32> {
+  %0 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
+  %1 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                      affine_map<(d0, d1) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : tensor<5xf32>) outs(%0 : tensor<5x5xf32>) {
+    ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+    } -> tensor<5x5xf32>
+  %2 = linalg.tensor_reshape %1
+    [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+    : tensor<5x5xf32> into tensor<5x1x5xf32>
+  return %2 : tensor<5x1x5xf32>
+}
+//      CHECK: func @unit_dim_reshape_collapse
+//      CHECK:   linalg.init_tensor
+//      CHECK:   linalg.generic
+//      CHECK:   linalg.tensor_reshape
+
+// -----
+
+func @unit_dim_reshape_expansion_full
+  (%arg0 : tensor<1x?x1x2x1x4xf32>, %arg1 : tensor<?x2x4xf32>)
+  -> tensor<?x2x4xf32> {
+  %c1 = constant 1 : index
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>]
+    : tensor<1x?x1x2x1x4xf32> into tensor<?x2x4xf32>
+  %1 = dim %arg0, %c1 : tensor<1x?x1x2x1x4xf32>
+  %2 = linalg.init_tensor [%1, 2, 4] : tensor<?x2x4xf32>
+  %3 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+     iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%0, %arg1 : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
+    outs(%2 : tensor<?x2x4xf32>) {
+    ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
+      %4 = mulf %arg2, %arg3 : f32
+      linalg.yield %4 : f32
+    } -> tensor<?x2x4xf32>
+  return %3 : tensor<?x2x4xf32>
+}
+//      CHECK: func @unit_dim_reshape_expansion_full
+//  CHECK-DAG:   linalg.tensor_reshape
+//  CHECK-DAG:   linalg.init_tensor
+//      CHECK:   linalg.generic

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
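
Editor's note: for readers who want to exercise the new entry point in isolation, below is a minimal sketch (not part of the commit) of a function pass that applies only the unit-dim reshape linearization patterns declared in Passes.h above. The pass name FoldUnitDimReshapesPass is hypothetical; the populate function and driver calls follow the MLIR API of this period, mirroring how populateLinalgFoldUnitExtentDimsPatterns consumes the same patterns in DropUnitDims.cpp.

// Hypothetical test pass, for illustration only.
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
struct FoldUnitDimReshapesPass
    : public mlir::PassWrapper<FoldUnitDimReshapesPass, mlir::FunctionPass> {
  void runOnFunction() override {
    mlir::OwningRewritePatternList patterns;
    // Instantiates the linearization patterns with
    // foldUnitDimReshapesOnly = true, so only tensor_reshape ops that
    // merely add/remove unit-extent dimensions are folded away.
    mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(&getContext(),
                                                                patterns);
    (void)mlir::applyPatternsAndFoldGreedily(getFunction(),
                                             std::move(patterns));
  }
};
} // namespace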