https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/77217
>From e016ccb680d84257fe44e4e408bad6e510eb703d Mon Sep 17 00:00:00 2001 From: Sam <srcarroll...@gmail.com> Date: Sat, 6 Jan 2024 18:51:24 -0600 Subject: [PATCH 1/2] Include output size in determining UB for `tensor.pack` --- .../mlir/Dialect/Tensor/IR/TensorOps.td | 7 +++-- .../Dialect/Linalg/Transforms/Transforms.cpp | 6 ++-- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 30 ++++++++++++++----- mlir/test/Dialect/Tensor/invalid.mlir | 10 ++++++- 4 files changed, 40 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index eb0c79c01bee1..1c61ece2676a9 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1943,11 +1943,12 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [ // Returns true if we have enough static information to catch undefined // behavior when the tile size does not divide perfectly the dimension of - // the input tensor. If a given dimension or a tile associated with it is - // dynamic, the dimension is not considered as we don't have enough static - // information to understand if the tile perfectly divides that dimension. + // the input tensor. Detecting UB requires that the input size and either + // corresponding tile or output size are static. static bool requirePaddingValue(ArrayRef<int64_t> inputShape, ArrayRef<int64_t> innerDimsPos, + ArrayRef<int64_t> outputShape, + ArrayRef<int64_t> outerDimsPerm, ArrayRef<OpFoldResult> innerTiles); static Value createDestinationTensor(OpBuilder &b, Location loc, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 9d230e2c2e574..c7fed41d234fd 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -582,8 +582,10 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter, return getConstantIntValue(tile).has_value(); }); if (areConstantTiles && operandType.hasStaticShape() && - !tensor::PackOp::requirePaddingValue(operandType.getShape(), innerPos, - innerPackSizes)) { + !tensor::PackOp::requirePaddingValue( + operandType.getShape(), innerPos, + dest.getType().cast<ShapedType>().getShape(), {}, + innerPackSizes)) { packOps.push_back(rewriter.create<tensor::PackOp>( loc, operand, dest, innerPos, innerPackSizes)); } else { diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 816e6ba8fed94..4318e55fd213e 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -3742,14 +3742,27 @@ SmallVector<int64_t> PackOp::getStaticTiles() { bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape, ArrayRef<int64_t> innerDimsPos, + ArrayRef<int64_t> outputShape, + ArrayRef<int64_t> outerDimsPerm, ArrayRef<OpFoldResult> innerTiles) { + SmallVector<int64_t> outputTileSizes( + outputShape.take_front(inputShape.size())); + if (!outerDimsPerm.empty()) { + assert(outerDimsPerm.size() == outputTileSizes.size() && + "expected output and outer_dims_perm to have same size"); + applyPermutationToVector(outputTileSizes, + invertPermutationVector(outerDimsPerm)); + } for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) { if (ShapedType::isDynamic(inputShape[pos])) continue; std::optional<int64_t> constantTile = getConstantIntValue(tileSize); - if (!constantTile) - continue; - if (inputShape[pos] % (*constantTile) != 0) + + if (!constantTile) { + if (!ShapedType::isDynamic(outputTileSizes[pos]) && + (inputShape[pos] % outputTileSizes[pos] != 0)) + return true; + } else if (inputShape[pos] % (*constantTile) != 0) return true; } return false; @@ -3772,9 +3785,11 @@ LogicalResult PackOp::verify() { if (!paddingValue && requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(), + getDestType().getShape(), getOuterDimsPerm(), getMixedTiles())) { - return emitOpError("invalid tile factor provided. Only full tiles are " - "supported when padding_value is not set"); + return emitOpError( + "invalid tile factor or output size provided. Only full tiles are " + "supported when padding_value is not set"); } return success(); } @@ -3975,8 +3990,9 @@ static bool paddingIsNotNeeded(PackOp op) { return false; if (ShapedType::isDynamicShape(op.getStaticInnerTiles())) return false; - return !PackOp::requirePaddingValue(srcType.getShape(), op.getInnerDimsPos(), - op.getMixedTiles()); + return !PackOp::requirePaddingValue( + srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(), + op.getOuterDimsPerm(), op.getMixedTiles()); } LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index 735e5146e9dbc..0eb8672d4d41b 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -597,13 +597,21 @@ func.func @empty_wrong_number_of_operands(%sz : index) { // ----- func.func @pack_invalid_no_padding_no_full_tiles(%input: tensor<256x128xf32>, %output: tensor<8x8x16x33xf32>) -> tensor<8x8x16x33xf32> { - // expected-error@+1 {{invalid tile factor provided. Only full tiles are supported when padding_value is not set}} + // expected-error@+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}} %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 33] into %output : tensor<256x128xf32> -> tensor<8x8x16x33xf32> return %0 : tensor<8x8x16x33xf32> } // ----- +func.func @pack_invalid_no_padding_no_full_tiles_dyn_tiles(%input: tensor<256x128xf32>, %output: tensor<10x8x?x?xf32>, %tile_size_0: index, %tile_size_1: index) -> tensor<10x8x?x?xf32> { + // expected-error@+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}} + %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [%tile_size_0, %tile_size_1] into %output : tensor<256x128xf32> -> tensor<10x8x?x?xf32> + return %0 : tensor<10x8x?x?xf32> +} + +// ----- + func.func @pad_and_pack_invalid_type(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: i32) -> tensor<2x8x8x2xf32> { // expected-error@+1 {{expected padding_value has 'f32' but got: 'i32'}} %0 = tensor.pack %input padding_value(%pad: i32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<13x15xf32> -> tensor<2x8x8x2xf32> >From d841037a854d363ea9c1b61cf1e4bfa90e4dc5b0 Mon Sep 17 00:00:00 2001 From: Sam <srcarroll...@gmail.com> Date: Sat, 6 Jan 2024 22:25:34 -0600 Subject: [PATCH 2/2] Add test with `outer_dims_perm` --- mlir/test/Dialect/Tensor/invalid.mlir | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index 0eb8672d4d41b..4c534fe936e3d 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -612,6 +612,14 @@ func.func @pack_invalid_no_padding_no_full_tiles_dyn_tiles(%input: tensor<256x12 // ----- +func.func @pack_invalid_no_padding_no_full_tiles_dyn_tiles_outperm(%input: tensor<256x128xf32>, %output: tensor<8x10x?x?xf32>, %tile_size_0: index, %tile_size_1: index) -> tensor<8x10x?x?xf32> { + // expected-error@+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}} + %0 = tensor.pack %input outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [%tile_size_0, %tile_size_1] into %output : tensor<256x128xf32> -> tensor<8x10x?x?xf32> + return %0 : tensor<8x10x?x?xf32> +} + +// ----- + func.func @pad_and_pack_invalid_type(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: i32) -> tensor<2x8x8x2xf32> { // expected-error@+1 {{expected padding_value has 'f32' but got: 'i32'}} %0 = tensor.pack %input padding_value(%pad: i32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<13x15xf32> -> tensor<2x8x8x2xf32> _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits