Author: MaheshRavishankar
Date: 2021-01-22T12:55:25-08:00
New Revision: 430d43e010bdd07d73c4d0d6536206d22d35a2cb


LOG: [mlir][Linalg] Disable fusion of tensor_reshape op by expansion when
it only adds/removes unit-dims

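Fusion of generic/indexed_generic operations with tensor_reshape by
expansion when the latter only adds/removes unit dimensions is
disabled, since it would only introduce loops with a unit trip count.
Such reshapes are instead folded by linearization as part of the
unit-extent-dims folding patterns.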

Differential Revision:




diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h 
index d041df86d169..5d68328acc7e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -72,6 +72,15 @@ void populateFoldReshapeOpsByExpansionPatterns(
 void populateFoldReshapeOpsByLinearizationPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns);
+/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
+/// producer (consumer) generic/indexed_generic operation by linearizing the
+/// indexing map used to access the source (target) of the reshape operation in
+/// the generic/indexed_generic operation. The patterns are applied only when
+/// the tensor reshape involved is collapsing (introducing) unit-extent
+/// dimensions.
+void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns);
 /// Patterns for fusing linalg operation on tensors.
 void populateLinalgTensorOpsFusionPatterns(MLIRContext *context,
                                            OwningRewritePatternList &patterns);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp 
index 8d09d58b9d7a..3c7b2223ee49 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -497,6 +497,7 @@ void mlir::populateLinalgFoldUnitExtentDimsPatterns(
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
+  populateFoldUnitDimsReshapeOpsByLinearizationPatterns(context, patterns);
 namespace {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp 
index 670d456ad2f2..0c5b8486824f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -302,9 +302,18 @@ static AffineMap linearizeCollapsedDims(AffineMap 
     unsigned startDim =
-    AffineExpr linearizedExpr = makeCanonicalStridedLayoutExpr(
-        sourceShape.slice(startDim, collapsedDims.size()),
-        sourceExprs.slice(startDim, collapsedDims.size()), context);
+    SmallVector<int64_t, 4> sizes;
+    SmallVector<AffineExpr, 4> dimExprs;
+    for (auto en :
+         llvm::zip(sourceShape.slice(startDim, collapsedDims.size()),
+                   sourceExprs.slice(startDim, collapsedDims.size()))) {
+      if (std::get<0>(en) == 1)
+        continue;
+      sizes.push_back(std::get<0>(en));
+      dimExprs.push_back(std::get<1>(en));
+    }
+    AffineExpr linearizedExpr =
+        makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
   return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
@@ -349,6 +358,23 @@ static LinalgOp createLinalgOpOfSameType(LinalgOp op, 
PatternRewriter &rewriter,
   return nullptr;
+/// Check if the reshape operation only expands into/collapses unit dimensions,
+/// i.e. each reassociation group has exactly one non-unit dimension.
+static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
+                                   ArrayRef<AffineMap> reassociation) {
+  for (auto &map : reassociation) {
+    unsigned numUnitDims = 0;
+    for (AffineExpr expr : map.getResults()) {
+      unsigned position = expr.cast<AffineDimExpr>().getPosition();
+      if (expandedShape[position] == 1)
+        numUnitDims++;
+    }
+    if (numUnitDims != map.getNumResults() - 1)
+      return false;
+  }
+  return true;
 /// Conditions for folding a generic/indexed-generic operation with a reshape 
 /// by expanding the iteration space dimensionality for tensor operations. 
 /// are preconditions assumed by `foldReshapeByDimExpansion` which implements
@@ -776,7 +802,7 @@ namespace {
 /// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
 ///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
 ///        -> tensor<?x?x4x?xf32>
-template <typename LinalgOpTy>
+template <typename LinalgOpTy, bool foldUnitDimReshapesOnly>
 struct FoldProducerReshapeOpByLinearization
     : public OpRewritePattern<LinalgOpTy> {
   using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
@@ -792,7 +818,10 @@ struct FoldProducerReshapeOpByLinearization
       if (!reshapeOp ||
               reshapeOp, linalgOp.getInputIndexingMap(operand.index()),
-              /*asProducer =*/true))
+              /*asProducer =*/true) ||
+          (foldUnitDimReshapesOnly &&
+           !isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
+                                   reshapeOp.getReassociationMaps())))
       // Compute the fused operands list,
@@ -858,7 +887,9 @@ struct FoldWithProducerReshapeOpByExpansion
       // - All constraints of fusing with reshape by expansion are met.
       if (reshapeOp.getSrcType().getRank() <
               reshapeOp.getResultType().getRank() ||
-          !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()))
+          !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
+          isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
+                                 reshapeOp.getReassociationMaps()))
       Optional<SmallVector<Value, 1>> replacementValues =
@@ -877,6 +908,7 @@ struct FoldWithProducerReshapeOpByExpansion
 /// Pattern to fold tensor_reshape op with its producer. The corresponding 
 /// map in the consumer needs to be modified to linearize the folded dimension.
+template <bool foldUnitDimReshapesOnly>
 struct FoldConsumerReshapeOpByLinearization
     : public OpRewritePattern<TensorReshapeOp> {
   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
@@ -888,7 +920,11 @@ struct FoldConsumerReshapeOpByLinearization
         !isa<GenericOp, IndexedGenericOp>(producer.getOperation()) ||
         !producer.hasTensorSemantics() || producer.getNumOutputs() != 1 ||
-            reshapeOp, producer.getOutputIndexingMap(0), /*asProducer 
+            reshapeOp, producer.getOutputIndexingMap(0),
+            /*asProducer =*/false) ||
+        (foldUnitDimReshapesOnly &&
+         !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
+                                 reshapeOp.getReassociationMaps())))
       return failure();
     // The indexing_maps for the operands of the fused operation are same as
     // those for the operands of the producer.
@@ -949,7 +985,10 @@ struct FoldReshapeWithGenericOpByExpansion
       return failure();
     LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
     if (!producer || producer.getNumOutputs() != 1 ||
-        !isFusableWithReshapeByDimExpansion(producer, producer.getNumInputs()))
+        !isFusableWithReshapeByDimExpansion(producer,
+                                            producer.getNumInputs()) ||
+        isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
+                               reshapeOp.getReassociationMaps()))
       return failure();
     Optional<SmallVector<Value, 1>> replacementValues =
         fuseWithReshapeByExpansion(producer, reshapeOp, 
@@ -1098,9 +1137,16 @@ struct FoldReshapeOpsByLinearizationPass
 void mlir::populateFoldReshapeOpsByLinearizationPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns) {
-  patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp>,
-                  FoldProducerReshapeOpByLinearization<IndexedGenericOp>,
-                  FoldConsumerReshapeOpByLinearization>(context);
+  patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, false>,
+                  FoldProducerReshapeOpByLinearization<IndexedGenericOp, 
+                  FoldConsumerReshapeOpByLinearization<false>>(context);
+void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
+  patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, true>,
+                  FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
+                  FoldConsumerReshapeOpByLinearization<true>>(context);
 void mlir::populateFoldReshapeOpsByExpansionPatterns(

diff  --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir 
index 17b8bda967b1..d40a91667500 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -331,3 +331,26 @@ func @fold_reshape(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
   ] : tensor<2x1x1xf32> into tensor<2x1xf32>
   return %1 : tensor<2x1xf32>
+// -----
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0, d1, d2) -> (d2)>
+func @fold_unit_dim_tensor_reshape_op(%arg0 : tensor<5xf32>) -> tensor<2x5xf32>
+  %1 = linalg.init_tensor [1, 2, 5] : tensor<1x2x5xf32>
+  %2 = linalg.generic {i64, indexing_maps = [#map1, #map0],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%arg0 : tensor<5xf32>) outs(%1 : tensor<1x2x5xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
+      linalg.yield %arg1 : f32
+    } -> tensor<1x2x5xf32>
+  %3 = linalg.tensor_reshape %2 [#map3, #map4]
+    : tensor<1x2x5xf32> into tensor<2x5xf32>
+  return %3 : tensor<2x5xf32>
+// CHECK-LABEL: func @fold_unit_dim_tensor_reshape_op
+//       CHECK:   %[[RESULT:.+]] = linalg.generic
+//       CHECK:   return %[[RESULT]]

diff  --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir 
index 447917548c5c..50269e36751b 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -188,42 +188,6 @@ func @generic_op_reshape_consumer_static(%arg0: 
 // -----
-func @scalar_reshape(
-  %arg0 : tensor<1x10xf32>, %arg1 : tensor<1xf32>) -> tensor<1x10xf32>
-  %0 = linalg.tensor_reshape %arg1 [] : tensor<1xf32> into tensor<f32>
-  %1 = linalg.init_tensor [10] : tensor<10xf32>
-  %2 = linalg.generic
-    {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
-     iterator_types = ["parallel"]}
-     ins(%0 : tensor<f32>)
-     outs(%1 : tensor<10xf32>) {
-  ^bb0(%arg2: f32, %s: f32):  // no predecessors
-    linalg.yield %arg2 : f32
-  } -> tensor<10xf32>
-  %3 = linalg.tensor_reshape %2 [affine_map<(d0, d1) -> (d0, d1)>]
-    : tensor<10xf32> into tensor<1x10xf32>
-  return %3 : tensor<1x10xf32>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> ()>
-//      CHECK: func @scalar_reshape
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10xf32>
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1xf32>
-//      CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG1]] []
-// CHECK-SAME:     tensor<1xf32> into tensor<f32>
-//      CHECK:   %[[T1:.+]] = linalg.init_tensor [10]
-//      CHECK:   %[[T2:.+]] = linalg.tensor_reshape %[[T1]] [#[[MAP0]]]
-//      CHECK:   %[[T3:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP1]], #[[MAP0]]]
-// CHECK-SAME:     iterator_types = ["parallel", "parallel"]
-// CHECK-SAME:     ins(%[[T0]] : tensor<f32>)
-// CHECK-SAME:     outs(%[[T2]] : tensor<1x10xf32>)
-//      CHECK:   return %[[T3]] : tensor<1x10xf32>
-// -----
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
 func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
@@ -336,7 +300,7 @@ func @reshape_as_consumer_permutation
          %5 = addi %3, %4 : i32
          %6 = index_cast %arg2 : index to i32
          %7 = addi %5, %6 : i32
-  linalg.yield %7 : i32
+        linalg.yield %7 : i32
        } -> tensor<6x4x210xi32>
   %d = linalg.tensor_reshape %c
          [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
@@ -493,3 +457,77 @@ func @generic_op_reshape_consumer_fusion_projected(%arg0 : 
 // CHECK-SAME:     ins(%[[T0]], %[[T1]] : tensor<?x4x5x?xf32>, 
 // CHECK-SAME:     outs(%[[T2]] : tensor<?x?x4x5xf32>)
 //      CHECK:   return %[[T3]] : tensor<?x?x4x5xf32>
+// -----
+func @unit_dim_reshape_expansion(%arg0 : tensor<1x5xf32>) -> tensor<5x5xf32> {
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1) -> (d0, d1)>] : tensor<1x5xf32> into tensor<5xf32>
+  %1 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
+  %2 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                      affine_map<(d0, d1) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%0 : tensor<5xf32>) outs(%1 : tensor<5x5xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
+    linalg.yield %arg2 : f32
+  } -> tensor<5x5xf32>
+  return %2 : tensor<5x5xf32>
+//      CHECK: func @unit_dim_reshape_expansion
+//  CHECK-DAG:   linalg.tensor_reshape
+//  CHECK-DAG:   linalg.init_tensor
+//      CHECK:   linalg.generic
+// -----
+func @unit_dim_reshape_collapse(%arg0 : tensor<5xf32>) -> tensor<5x1x5xf32> {
+  %0 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
+  %1 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                      affine_map<(d0, d1) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : tensor<5xf32>) outs(%0 : tensor<5x5xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
+    linalg.yield %arg2 : f32
+  } -> tensor<5x5xf32>
+  %2 = linalg.tensor_reshape %1
+    [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+    : tensor<5x5xf32> into tensor<5x1x5xf32>
+  return %2 : tensor<5x1x5xf32>
+// CHECK: func @unit_dim_reshape_collapse
+// CHECK:   linalg.init_tensor
+// CHECK:   linalg.generic
+// CHECK:   linalg.tensor_reshape
+// -----
+func @unit_dim_reshape_expansion_full
+  (%arg0 : tensor<1x?x1x2x1x4xf32>, %arg1 : tensor<?x2x4xf32>)
+  -> tensor<?x2x4xf32> {
+  %c1 = constant 1 : index
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>]
+    : tensor<1x?x1x2x1x4xf32> into tensor<?x2x4xf32>
+  %1 = dim %arg0, %c1 : tensor<1x?x1x2x1x4xf32>
+  %2 = linalg.init_tensor [%1, 2, 4] : tensor<?x2x4xf32>
+  %3 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+     iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%0, %arg1 : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
+    outs(%2 : tensor<?x2x4xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
+    %4 = mulf %arg2, %arg3 : f32
+    linalg.yield %4 : f32
+  } -> tensor<?x2x4xf32>
+  return %3 : tensor<?x2x4xf32>
+//      CHECK: func @unit_dim_reshape_expansion_full
+//  CHECK-DAG:   linalg.tensor_reshape
+//  CHECK-DAG:   linalg.init_tensor
+//      CHECK:   linalg.generic

