Author: Nirvedh
Date: 2022-03-15T18:42:43Z
New Revision: b8d211fc317ffefaed1d65b226cda6c464f7d216
URL: https://github.com/llvm/llvm-project/commit/b8d211fc317ffefaed1d65b226cda6c464f7d216
DIFF: https://github.com/llvm/llvm-project/commit/b8d211fc317ffefaed1d65b226cda6c464f7d216.diff

LOG: [MLIR][Linalg] Canonicalization patterns for linalg.generic.

Fold linalg.fill into linalg.generic.
Remove dead (unused) arguments of linalg.generic.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D121535

Added: 


Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/fusion-indexed.mlir

Removed: 


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 331a8b91bd330..2c9b1e0c53553 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -105,6 +105,33 @@ static LogicalResult foldMemRefCast(Operation *op) {
   return success(folded);
 }
 
+/// Helper function to find if the AffineMap at `testMapLocation` of `maps`
+/// uses at least one dimension that does not appear in any other map in
+/// `maps`.
+static bool hasaUniqueDim(ArrayRef<AffineMap> maps, unsigned testMapLocation) {
+  AffineMap testMap = maps[testMapLocation];
+  llvm::SmallDenseSet<unsigned> dimsToCheck;
+  for (auto result : testMap.getResults()) {
+    auto expr = result.dyn_cast<AffineDimExpr>();
+    if (expr != nullptr)
+      dimsToCheck.insert(expr.getPosition());
+  }
+  for (auto It : llvm::enumerate(maps)) {
+    if (It.index() == testMapLocation)
+      continue;
+    auto map = It.value();
+    for (auto result : map.getResults()) {
+      auto expr = result.dyn_cast<AffineDimExpr>();
+      if (expr != nullptr) {
+        dimsToCheck.erase(expr.getPosition());
+      }
+      if (dimsToCheck.empty())
+        return false;
+    }
+  }
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // Region builder helper.
 // TODO: Move this to a utility library.
@@ -826,11 +853,95 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
     return success();
   }
 };
+
+/// Drop dead args of a linalg.generic op.
+/// An arg is dead if it has zero uses in the op region.
+struct DeadArgsGenericOpInputs : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<AffineMap> oldIndexingMaps = genericOp.getIndexingMaps();
+    // Maps must be projected permutations.
+    if (llvm::any_of(genericOp.getIndexingMaps(), [](AffineMap map) {
+          return !map.isProjectedPermutation();
+        }))
+      return failure();
+    Block &payload = genericOp.region().front();
+    SmallVector<Value> newInputOperands;
+    SmallVector<AffineMap> newIndexingMaps;
+    bool deadArgFound = false;
+    int inputSize = genericOp.getInputOperands().size();
+    for (int i = inputSize - 1; i >= 0; i--) {
+      OpOperand *opOperand = genericOp.getInputOperand(i);
+      // Iterate in reverse, so that we erase later args first, preventing the
+      // argument list from shifting unexpectedly and invalidating all our
+      // indices.
+      if (payload.getArgument(i).use_empty() &&
+          !hasaUniqueDim(oldIndexingMaps, i)) {
+        payload.eraseArgument(i);
+        deadArgFound = true;
+        // Remove this indexing map from consideration for the hasaUniqueDim
+        // check.
+        oldIndexingMaps.erase(oldIndexingMaps.begin() + i);
+      } else {
+        newInputOperands.insert(newInputOperands.begin(), opOperand->get());
+        newIndexingMaps.insert(newIndexingMaps.begin(),
+                               genericOp.getTiedIndexingMap(opOperand));
+      }
+    }
+    // Bail out if there are no dead args.
+    if (!deadArgFound)
+      return failure();
+    for (OpOperand *opOperand : genericOp.getOutputOperands())
+      newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
+    SmallVector<Value> outputOperands = genericOp.getOutputOperands();
+
+    auto newOp = rewriter.create<GenericOp>(
+        genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
+        outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
+        genericOp.iterator_types(), genericOp.docAttr(),
+        genericOp.library_callAttr());
+    // Copy over unknown attributes. They might be load-bearing for some flows.
+    ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
+    for (NamedAttribute kv : genericOp->getAttrs()) {
+      if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) {
+        newOp->setAttr(kv.getName(), kv.getValue());
+      }
+    }
+    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
+                                newOp.region().begin());
+    rewriter.replaceOp(genericOp, newOp->getResults());
+    return success();
+  }
+};
+
+/// Fold linalg.fill into linalg.generic.
+struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+    bool fillFound = false;
+    Block &payload = genericOp.region().front();
+    for (OpOperand *opOperand : genericOp.getInputOperands()) {
+      FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
+      if (fillOp) {
+        fillFound = true;
+        payload.getArgument(opOperand->getOperandNumber())
+            .replaceAllUsesWith(fillOp.value());
+      }
+    }
+    // Fail if there are no FillOps to fold.
+    return success(fillFound);
+  }
+};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
+  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp,
+              DeadArgsGenericOpInputs, FoldFillWithGenericOp>(context);
 }
 
 LogicalResult GenericOp::fold(ArrayRef<Attribute>,
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 28655b96e5df6..0e0faab56f6c9 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -325,6 +325,106 @@ func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @fold_fill_generic_basic
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK-NOT: linalg.fill
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
+#map0 = affine_map<(d0) -> (d0)>
+func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 7.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %1 = linalg.init_tensor [%0] : tensor<?xf32>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf32>) -> tensor<?xf32>
+  %3 = linalg.init_tensor [%0] : tensor<?xf32>
+  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf32>, tensor<?xf32>) outs (%3:tensor<?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %5 = arith.addf %arg1, %arg2 : f32
+    linalg.yield %5 : f32
+  } -> tensor<?xf32>
+  return %4 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_fill_generic_mixedaccess
+// CHECK-NOT: linalg.fill
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-NOT: ins
+// CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+func @fold_fill_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst1 = arith.constant 7.0 : f32
+  %cst2 = arith.constant 6.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %3 = linalg.fill ins(%cst1 : f32) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %4 = linalg.init_tensor [%1, %0] : tensor<?x?xf32>
+  %5 = linalg.fill ins(%cst2 : f32) outs(%4 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %6 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %7 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%3, %5 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%6:tensor<?x?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %8 = arith.divf %arg1, %arg2 : f32
+    linalg.yield %8 : f32
+  } -> tensor<?x?xf32>
+  return %7 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @remove_deadargs_generic_basic
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
+#map0 = affine_map<(d0) -> (d0)>
+func @remove_deadargs_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 7.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %1 = linalg.init_tensor [%0] : tensor<?xf32>
+  %2 = linalg.init_tensor [%0] : tensor<?xf32>
+  %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %1 : tensor<?xf32>, tensor<?xf32>) outs (%2:tensor<?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %4 = arith.addf %arg1, %cst : f32
+    linalg.yield %4 : f32
+  } -> tensor<?xf32>
+  return %3 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @remove_deadargs_generic_mixedaccess
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-NOT: ins
+// CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+func @remove_deadargs_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst1 = arith.constant 7.0 : f32
+  %cst2 = arith.constant 6.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %3 = linalg.init_tensor [%1, %0] : tensor<?x?xf32>
+  %4 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %5 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%2, %3 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%4:tensor<?x?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %6 = arith.divf %cst1, %cst2 : f32
+    linalg.yield %6 : f32
+  } -> tensor<?x?xf32>
+  return %5 : tensor<?x?xf32>
+}
+
+// -----
 // CHECK-LABEL: func @fold_fill_reshape()
 func @fold_fill_reshape() -> tensor<6x4xf32> {
   %zero = arith.constant 0.0 : f32
diff --git a/mlir/test/Dialect/Linalg/fusion-indexed.mlir b/mlir/test/Dialect/Linalg/fusion-indexed.mlir
index 1b075cc5ac483..03ac767136f00 100644
--- a/mlir/test/Dialect/Linalg/fusion-indexed.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-indexed.mlir
@@ -46,7 +46,8 @@ func @fuse_indexed_consumer(%A: memref<?x?xf32>,
       %10 = arith.index_cast %7 : index to i32
       %11 = arith.sitofp %10 : i32 to f32
       %12 = arith.addf %9, %11 : f32
-      linalg.yield %12 : f32
+      %13 = arith.addf %12, %arg4 : f32
+      linalg.yield %13 : f32
     }
   }
 }
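For readers who want to exercise the new patterns programmatically rather than through the canonicalize pass, the sketch below (not part of the commit) shows one way to do it: collect GenericOp's canonicalization patterns, which now include DeadArgsGenericOpInputs and FoldFillWithGenericOp, and apply them greedily over a module. It assumes a tree of roughly this commit's vintage, where the Linalg dialect header lives at mlir/Dialect/Linalg/IR/Linalg.h; the helper name canonicalizeGenericOps is illustrative, not an upstream API.

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Illustrative helper (not upstream): canonicalize the linalg.generic ops in
// `module` by running GenericOp's canonicalization patterns to a fixpoint.
static LogicalResult canonicalizeGenericOps(ModuleOp module) {
  MLIRContext *context = module.getContext();
  RewritePatternSet patterns(context);
  // Registers DeduplicateGenericOpInputs, EraseIdentityGenericOp, and the two
  // patterns introduced by this commit.
  linalg::GenericOp::getCanonicalizationPatterns(patterns, context);
  // Apply the patterns greedily until none of them matches anymore.
  return applyPatternsAndFoldGreedily(module, std::move(patterns));
}

Running the canonicalize pass over the test cases added above has the same effect: the linalg.fill producers are folded into the generic ops and the unused block arguments are dropped, as the CHECK-NOT lines verify.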