https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/75410
>From c07f7e1c5c6f8bbc7189e96096004d39a0a1aa3f Mon Sep 17 00:00:00 2001 From: hanhanW <hanhan0...@gmail.com> Date: Wed, 13 Dec 2023 15:59:48 -0800 Subject: [PATCH 1/3] [mlir][TilingInterface] Early return cloned ops if tile sizes are zeros. This is a trivial early-return case. If the cloned op is not returned, tiling generates `extract_slice` ops that extract the whole tensor, and these ops are not folded away. Return early to avoid that. E.g., ```mlir func.func @matmul_tensors( %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> { %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2: tensor<?x?xf32>) -> tensor<?x?xf32> return %0 : tensor<?x?xf32> } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.tile_using_for %0 [0, 0, 0] : (!transform.any_op) -> (!transform.any_op) transform.yield } } ``` Without this change, applying the transform and canonicalizing the IR: ``` mlir-opt --transform-interpreter -canonicalize input.mlir ``` produces the following, where the full-size `extract_slice` ops remain: ```mlir module { func.func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32> %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32> %dim_1 = tensor.dim %arg1, %c1 : tensor<?x?xf32> %extracted_slice = tensor.extract_slice %arg0[0, 0] [%dim, %dim_0] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32> %extracted_slice_2 = tensor.extract_slice %arg1[0, 0] [%dim_0, %dim_1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32> %extracted_slice_3 = tensor.extract_slice %arg2[0, 0] [%dim, %dim_1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32> %0 = linalg.matmul ins(%extracted_slice, %extracted_slice_2 : tensor<?x?xf32>, 
tensor<?x?xf32>) outs(%extracted_slice_3 : tensor<?x?xf32>) -> tensor<?x?xf32> return %0 : tensor<?x?xf32> } } ``` --- .../SCF/Transforms/TileUsingInterface.cpp | 11 ++++++-- mlir/test/Dialect/Linalg/tile-tensors.mlir | 27 +++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 8057b3898012d4..20413aba8730be 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -362,14 +362,21 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, auto clonedOp = cast<TilingInterface>( cloneOpAndUpdateDestinationArgs(rewriter, op, clonedOpDestination)); - // 5b. Tile the cloned operation. + // 5b. Early return cloned op if tiling is not happenning. + if (llvm::all_of(tileSizeVector, + [](OpFoldResult v) { return isZeroIndex(v); })) { + return scf::SCFTilingResult{/*tiledOps=*/{clonedOp}, /*loops=*/{}, + clonedOp->getResults()}; + } + + // 5c. Tile the cloned operation. FailureOr<TilingResult> tiledImplementation = clonedOp.getTiledImplementation(rewriter, offsets, sizes); if (failed(tiledImplementation)) { return rewriter.notifyMatchFailure(op, "failed to tile operation"); } - // 5c. Delete the cloned operation. + // 5d. Delete the cloned operation. 
rewriter.eraseOp(clonedOp); // If loops are empty, the tiled op is used as the replacement for the untiled diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir index e0429b1f873298..e8e63302286400 100644 --- a/mlir/test/Dialect/Linalg/tile-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir @@ -37,6 +37,33 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func @matmul_tensors_with_size_zeros( +// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xf32> +// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<?x?xf32> +// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> { +func.func @matmul_tensors_with_size_zeros( + %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) + -> tensor<?x?xf32> { + +// CHECK: %[[RES:.*]] = linalg.matmul ins(%[[TA]], %[[TB]] : tensor<?x?xf32>, tensor<?x?xf32>) +// CHECK-SAME: outs(%[[TC]] : tensor<?x?xf32>) -> tensor<?x?xf32> +// CHECK: return %[[RES]] + %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>) + outs(%arg2: tensor<?x?xf32>) + -> tensor<?x?xf32> + return %0 : tensor<?x?xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.tile_using_for %0 [0, 0, 0] : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// ----- + func.func @generic_op_tensors( %arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> { %c0 = arith.constant 0 : index >From 1d1a47f7d6ddfce0edbdd5728bb770d54d9c321f Mon Sep 17 00:00:00 2001 From: hanhanW <hanhan0...@gmail.com> Date: Fri, 15 Dec 2023 13:13:16 -0800 Subject: [PATCH 2/3] review comments --- mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git 
a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 20413aba8730be..51b402d5685e4a 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -362,9 +362,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, auto clonedOp = cast<TilingInterface>( cloneOpAndUpdateDestinationArgs(rewriter, op, clonedOpDestination)); - // 5b. Early return cloned op if tiling is not happenning. - if (llvm::all_of(tileSizeVector, - [](OpFoldResult v) { return isZeroIndex(v); })) { + // 5b. Early return cloned op if tiling is not happening. + if (llvm::all_of(tileSizeVector, isZeroIndex)) { return scf::SCFTilingResult{/*tiledOps=*/{clonedOp}, /*loops=*/{}, clonedOp->getResults()}; } >From 71d7c57b8b23b60478eda5650ba04bf36338c3db Mon Sep 17 00:00:00 2001 From: hanhanW <hanhan0...@gmail.com> Date: Mon, 18 Dec 2023 10:03:12 -0800 Subject: [PATCH 3/3] add comments about returning cloned op --- mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 51b402d5685e4a..1b6b4db9d20907 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -362,7 +362,9 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, auto clonedOp = cast<TilingInterface>( cloneOpAndUpdateDestinationArgs(rewriter, op, clonedOpDestination)); - // 5b. Early return cloned op if tiling is not happening. + // 5b. Early return cloned op if tiling is not happening. We can not return + // the original op because it could lead to + // `rewriter.replaceOp(op, op->getResults())` and user would get crash. 
if (llvm::all_of(tileSizeVector, isZeroIndex)) { return scf::SCFTilingResult{/*tiledOps=*/{clonedOp}, /*loops=*/{}, clonedOp->getResults()}; _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits