https://github.com/rikhuijzer created https://github.com/llvm/llvm-project/pull/74200
This PR fixes https://github.com/llvm/llvm-project/issues/73383 and is another shot at the refactoring proposed in https://github.com/llvm/llvm-project/pull/72885. >From 22928e7e5da508d8d9dc8d4b7e54f84cccadef06 Mon Sep 17 00:00:00 2001 From: Rik Huijzer <git...@huijzer.xyz> Date: Mon, 20 Nov 2023 09:02:41 +0100 Subject: [PATCH 1/5] [mlir][tensor] Fix canon via `hasNegativeDimension` --- mlir/include/mlir/Dialect/Tensor/IR/Tensor.h | 6 ++++++ mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 15 +++++++++++++++ mlir/test/Dialect/Tensor/canonicalize.mlir | 10 ++++++++++ 3 files changed, 31 insertions(+) diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h index 06642adda42b3..0d027057b3a95 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h @@ -150,6 +150,12 @@ LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, /// Tests if types are the same when ignoring encoding on ranked tensors. bool isSameTypeWithoutEncoding(Type tp1, Type tp2); +/// Helper function to check whether the dimensions are non-negative. This +/// check also occurs in the verifier, but we need it at later stages too +/// because the verifier ignores dynamic dimensions, but later stages might +/// have constant folded those to (negative) constants. +bool hasNegativeDimension(SmallVector<int64_t> shape); + /// Function to control the folding of constant and extract slice. 
using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>; diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index e469815496e18..3297ef673ca2e 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -125,6 +125,12 @@ bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) { return tp1 == tp2; // default implementation } +bool tensor::hasNegativeDimension(SmallVector<int64_t> shape) { + return llvm::any_of(shape, [](int64_t dim) { + return !ShapedType::isDynamic(dim) && dim < 0; + }); +} + /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or /// rank-extending tensor.insert_slice op. static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape, @@ -1801,6 +1807,10 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType( dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); + if (hasNegativeDimension(staticOffsets)) + return {}; + if (hasNegativeDimension(staticSizes)) + return {}; return ExtractSliceOp::inferCanonicalRankReducedResultType( desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes, staticStrides); @@ -2370,6 +2380,8 @@ class InsertSliceOpConstantArgumentFolder final auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType( insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(), mixedOffsets, mixedSizes, mixedStrides); + if (!sourceType) + return failure(); Value toInsert = insertSliceOp.getSource(); if (sourceType != insertSliceOp.getSourceType()) { OpBuilder::InsertionGuard g(rewriter); @@ -2500,6 +2512,8 @@ struct InsertSliceOpSourceCastInserter final getConstantIntValue(insertSliceOp.getMixedSizes()[i])) newSrcShape[i] = *constInt; } + // if (hasNegativeDimension(newSrcShape)) + // return 
failure(); RankedTensorType newSrcType = RankedTensorType::get(newSrcShape, srcType.getElementType()); @@ -2521,6 +2535,7 @@ struct InsertSliceOpSourceCastInserter final rewriter.setInsertionPoint(insertSliceOp->getParentOp()); Value cast = rewriter.create<tensor::CastOp>( insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource()); + rewriter.replaceOpWithNewOp<InsertOpTy>( insertSliceOp, cast, insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index ea8c17640d7c1..88f27d3d36b04 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1102,6 +1102,16 @@ func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>) // ----- +func.func @no_fold_extract_slice_negative_offset(%arg0: tensor<8xf32>) -> tensor<?xf32> { + %c-1 = arith.constant -1 : index + %e = tensor.extract_slice %arg0[1] [%c-1] [1] : tensor<8xf32> to tensor<?xf32> + return %e : tensor<?xf32> +} +// CHECK-LABEL: func @no_fold_extract_slice_negative_offset +// CHECK: tensor.extract_slice + +// ----- + func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> { %c0 = arith.constant dense<42> : tensor<2x8xi32> %0 = tensor.expand_shape %c0 [[0], [1, 2]] >From ecef5428c160cb72103e06a160c450440ce1f416 Mon Sep 17 00:00:00 2001 From: Rik Huijzer <git...@huijzer.xyz> Date: Mon, 20 Nov 2023 16:27:53 +0100 Subject: [PATCH 2/5] Fix `insert_slice` cast inserter and refactor --- mlir/include/mlir/Dialect/Tensor/IR/Tensor.h | 6 ------ .../mlir/Dialect/Utils/StaticValueUtils.h | 6 ++++++ mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 15 ++++----------- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 18 +++--------------- mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 6 ++++++ mlir/test/Dialect/Tensor/canonicalize.mlir | 14 ++++++++++++++ 6 files changed, 33 insertions(+), 32 deletions(-) diff --git 
a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h index 0d027057b3a95..06642adda42b3 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h @@ -150,12 +150,6 @@ LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, /// Tests if types are the same when ignoring encoding on ranked tensors. bool isSameTypeWithoutEncoding(Type tp1, Type tp2); -/// Helper function to check whether the dimensions are non-negative. This -/// check also occurs in the verifier, but we need it at later stages too -/// because the verifier ignores dynamic dimensions, but later stages might -/// have constant folded those to (negative) constants. -bool hasNegativeDimension(SmallVector<int64_t> shape); - /// Function to control the folding of constant and extract slice. using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>; diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h index 23a366036b9dd..9e39d81e5c4f9 100644 --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -128,6 +128,12 @@ std::pair<ArrayAttr, SmallVector<Value>> decomposeMixedValues(Builder &b, const SmallVectorImpl<OpFoldResult> &mixedValues); +/// Helper function to check whether the dimensions are non-negative. +/// +/// This is used to re-check whether dimensions are still non-negative after +/// constant folding the dynamic dimensions. +bool hasNegativeDimension(SmallVector<int64_t> values); + /// Helper to sort `values` according to matching `keys`. 
SmallVector<Value> getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values, diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index a2fc954ad07fa..dd75ed2500306 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2621,17 +2621,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType, dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); - - // If one of the offsets or sizes is invalid, fail the canonicalization. - // These checks also occur in the verifier, but they are needed here - // because some dynamic dimensions may have been constant folded. - for (int64_t offset : staticOffsets) - if (offset < 0 && !ShapedType::isDynamic(offset)) - return {}; - for (int64_t size : staticSizes) - if (size < 0 && !ShapedType::isDynamic(size)) - return {}; - + if (hasNegativeDimension(staticOffsets)) + return {}; + if (hasNegativeDimension(staticSizes)) + return {}; return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, staticSizes, staticStrides); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 3297ef673ca2e..986e40a2e4eb3 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -125,12 +125,6 @@ bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) { return tp1 == tp2; // default implementation } -bool tensor::hasNegativeDimension(SmallVector<int64_t> shape) { - return llvm::any_of(shape, [](int64_t dim) { - return !ShapedType::isDynamic(dim) && dim < 0; - }); -} - /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or /// rank-extending tensor.insert_slice op. 
static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape, @@ -1265,13 +1259,8 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> { SmallVector<int64_t> newShape; operandsAndShape(resultType, dynamicExtents, newOperands, newShape); - for (int64_t newdim : newShape) { - // This check also occurs in the verifier, but we need it here too - // since intermediate passes may have replaced some dynamic dimensions - // by constants. - if (newdim < 0 && !ShapedType::isDynamic(newdim)) + if (hasNegativeDimension(newShape)) return failure(); - } if (newOperands.size() == tensorFromElements.getDynamicExtents().size()) return failure(); @@ -2512,8 +2501,8 @@ struct InsertSliceOpSourceCastInserter final getConstantIntValue(insertSliceOp.getMixedSizes()[i])) newSrcShape[i] = *constInt; } - // if (hasNegativeDimension(newSrcShape)) - // return failure(); + if (hasNegativeDimension(newSrcShape)) + return failure(); RankedTensorType newSrcType = RankedTensorType::get(newSrcShape, srcType.getElementType()); @@ -2535,7 +2524,6 @@ struct InsertSliceOpSourceCastInserter final rewriter.setInsertionPoint(insertSliceOp->getParentOp()); Value cast = rewriter.create<tensor::CastOp>( insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource()); - rewriter.replaceOpWithNewOp<InsertOpTy>( insertSliceOp, cast, insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 8a4ccc990331a..5d777ad74e9e8 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -200,6 +200,12 @@ decomposeMixedValues(Builder &b, return {b.getI64ArrayAttr(staticValues), dynamicValues}; } +bool hasNegativeDimension(SmallVector<int64_t> values) { + return llvm::any_of(values, [](int64_t value) { + return !ShapedType::isDynamic(value) && value < 0; + }); +} + /// Helper to sort `values` according to 
matching `keys`. template <typename K, typename V> static SmallVector<V> diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 88f27d3d36b04..1c0a2e868475f 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1112,6 +1112,20 @@ func.func @no_fold_extract_slice_negative_offset(%arg0: tensor<8xf32>) -> tensor // ----- +func.func @no_fold_insert_slice_cast_inserter_negative_offset() -> tensor<?xf32> { + %c = arith.constant 0 : index + %const = tensor.empty(%c) : tensor<?xf32> + %insert_val = tensor.empty(%c) : tensor<?xf32> + %c-1 = arith.constant -1 : index + %inserted = tensor.insert_slice %insert_val into %const[0][%c-1][1] : tensor<?xf32> into tensor<?xf32> + return %inserted : tensor<?xf32> +} +// CHECK-LABEL: func @no_fold_insert_slice_cast_inserter_negative_offset +// CHECK: %[[CAST:.*]] = tensor.cast +// CHECK: tensor.insert_slice %[[CAST:.+]] + +// ----- + func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> { %c0 = arith.constant dense<42> : tensor<2x8xi32> %0 = tensor.expand_shape %c0 [[0], [1, 2]] >From 69637ad2b8915f352c6dae3cab838a04b84c3e10 Mon Sep 17 00:00:00 2001 From: Rik Huijzer <git...@huijzer.xyz> Date: Mon, 20 Nov 2023 16:40:09 +0100 Subject: [PATCH 3/5] Apply `clang-format` --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 986e40a2e4eb3..04a8e43a639f4 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1260,7 +1260,7 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> { operandsAndShape(resultType, dynamicExtents, newOperands, newShape); if (hasNegativeDimension(newShape)) - return failure(); + return failure(); if (newOperands.size() == tensorFromElements.getDynamicExtents().size()) return failure(); >From 
ecd074dc485485ebf6b7ae7aa5ee52cb397994ca Mon Sep 17 00:00:00 2001 From: Rik Huijzer <git...@huijzer.xyz> Date: Sat, 2 Dec 2023 18:02:31 +0100 Subject: [PATCH 4/5] Refactor --- .../mlir/Dialect/Utils/StaticValueUtils.h | 36 ++++++++++++++----- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 7 ++-- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 16 +++------ mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 33 +++++++++++++---- mlir/test/Dialect/MemRef/canonicalize.mlir | 12 +++++++ 5 files changed, 75 insertions(+), 29 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h index 768f0ac1abe56..a1853438ccf7f 100644 --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -128,12 +128,6 @@ std::pair<ArrayAttr, SmallVector<Value>> decomposeMixedValues(Builder &b, const SmallVectorImpl<OpFoldResult> &mixedValues); -/// Helper function to check whether the dimensions are non-negative. -/// -/// This is used to re-check whether dimensions are still non-negative after -/// constant folding the dynamic dimensions. -bool hasNegativeDimension(SmallVector<int64_t> values); - /// Helper to sort `values` according to matching `keys`. SmallVector<Value> getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values, @@ -145,12 +139,36 @@ SmallVector<int64_t> getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values, llvm::function_ref<bool(Attribute, Attribute)> compare); +/// Helper function to check whether the passed in `sizes` or `offsets` are +/// valid. This can be used to re-check whether dimensions are still valid +/// after constant folding the dynamic dimensions. +bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets); + +/// Helper function to check whether the passed in `strides` are valid. This +/// can be used to re-check whether dimensions are still valid after constant +/// folding the dynamic dimensions. 
+bool hasValidStrides(SmallVector<int64_t> strides); + /// Returns "success" when any of the elements in `ofrs` is a constant value. In /// that case the value is replaced by an attribute. Returns "failure" when no -/// folding happened. If `onlyNonNegative` is set, only non-negative constant -/// values are folded. +/// folding happened. If `onlyNonNegative` and `onlyNonZero` are set, only +/// non-negative and non-zero constant values are folded respectively. LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs, - bool onlyNonNegative = false); + bool onlyNonNegative = false, + bool onlyNonZero = false); + +/// Returns "success" when any of the elements in `offsetsOrSizes` is a +/// constant value. In that case the value is replaced by an attribute. Returns +/// "failure" when no folding happened. Invalid values are not folded to avoid +/// canonicalization crashes. +LogicalResult +foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes); + +/// Returns "success" when any of the elements in `strides` is a constant +/// value. In that case the value is replaced by an attribute. Returns +/// "failure" when no folding happened. Invalid values are not folded to avoid +/// canonicalization crashes. +LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides); /// Return the number of iterations for a loop with a lower bound `lb`, upper /// bound `ub` and step `step`. 
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index f222011a2edf5..c6d947a2427db 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -22,6 +22,7 @@ #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" +#include <cstdint> using namespace mlir; using namespace mlir::memref; @@ -2581,9 +2582,11 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType, dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); - if (hasNegativeDimension(staticOffsets)) + if (!hasValidSizesOffsets(staticOffsets)) return {}; - if (hasNegativeDimension(staticSizes)) + if (!hasValidSizesOffsets(staticSizes)) + return {}; + if (!hasValidStrides(staticStrides)) return {}; return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, staticSizes, staticStrides); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index eab1d261b1064..94b7b734f88fe 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1446,7 +1446,7 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> { SmallVector<int64_t> newShape; operandsAndShape(resultType, dynamicExtents, newOperands, newShape); - if (hasNegativeDimension(newShape)) + if (!hasValidSizesOffsets(newShape)) return failure(); if (newOperands.size() == tensorFromElements.getDynamicExtents().size()) @@ -1983,10 +1983,6 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType( dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); - if (hasNegativeDimension(staticOffsets)) - return {}; 
- if (hasNegativeDimension(staticSizes)) - return {}; return ExtractSliceOp::inferCanonicalRankReducedResultType( desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes, staticStrides); @@ -2547,17 +2543,15 @@ class InsertSliceOpConstantArgumentFolder final SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides()); // No constant operands were folded, just return; - if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) && - failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) && - failed(foldDynamicIndexList(mixedStrides))) + if (failed(foldDynamicOffsetSizeList(mixedOffsets)) && + failed(foldDynamicOffsetSizeList(mixedSizes)) && + failed(foldDynamicStrideList(mixedStrides))) return failure(); // Create the new op in canonical form. auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType( insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(), mixedOffsets, mixedSizes, mixedStrides); - if (!sourceType) - return failure(); Value toInsert = insertSliceOp.getSource(); if (sourceType != insertSliceOp.getSourceType()) { OpBuilder::InsertionGuard g(rewriter); @@ -2692,7 +2686,7 @@ struct InsertSliceOpSourceCastInserter final newSrcShape[i] = *constInt; } } - if (hasNegativeDimension(newSrcShape)) + if (!hasValidSizesOffsets(newSrcShape)) return failure(); RankedTensorType newSrcType = diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 4f606e17a4d59..0c8a88da789e2 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -200,12 +200,6 @@ decomposeMixedValues(Builder &b, return {b.getI64ArrayAttr(staticValues), dynamicValues}; } -bool hasNegativeDimension(SmallVector<int64_t> values) { - return llvm::any_of(values, [](int64_t value) { - return !ShapedType::isDynamic(value) && value < 0; - }); -} - /// Helper to sort `values` according to matching `keys`. 
template <typename K, typename V> static SmallVector<V> @@ -262,8 +256,20 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub, return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant); } +bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) { + return llvm::none_of(sizesOrOffsets, [](int64_t value) { + return !ShapedType::isDynamic(value) && value < 0; + }); +} + +bool hasValidStrides(SmallVector<int64_t> strides) { + return llvm::none_of(strides, [](int64_t value) { + return !ShapedType::isDynamic(value) && value == 0; + }); +} + LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs, - bool onlyNonNegative) { + bool onlyNonNegative, bool onlyNonZero) { bool valuesChanged = false; for (OpFoldResult &ofr : ofrs) { if (ofr.is<Attribute>()) @@ -273,6 +279,8 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs, // Note: All ofrs have index type. if (onlyNonNegative && *getConstantIntValue(attr) < 0) continue; + if (onlyNonZero && *getConstantIntValue(attr) == 0) + continue; ofr = attr; valuesChanged = true; } @@ -280,4 +288,15 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs, return success(valuesChanged); } +LogicalResult +foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes) { + return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true, + /*onlyNonZero=*/false); +} + +LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides) { + return foldDynamicIndexList(strides, /*onlyNonNegative=*/false, + /*onlyNonZero=*/true); +} + } // namespace mlir diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index a1f8673638ff8..d3406c630f6dd 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -191,6 +191,18 @@ func.func @no_fold_subview_negative_size(%input: memref<4x1024xf32>) -> memref<? 
// ----- +// CHECK-LABEL: func @no_fold_subview_zero_stride +// CHECK: %[[SUBVIEW:.+]] = memref.subview +// CHECK: return %[[SUBVIEW]] +func.func @no_fold_subview_zero_stride(%arg0 : memref<10xf32>) -> memref<1xf32, strided<[?], offset: 1>> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %1 = memref.subview %arg0[1] [1] [%c0] : memref<10xf32> to memref<1xf32, strided<[?], offset: 1>> + return %1 : memref<1xf32, strided<[?], offset: 1>> +} + +// ----- + // CHECK-LABEL: func @no_fold_of_store // CHECK: %[[cst:.+]] = memref.cast %arg // CHECK: memref.store %[[cst]] >From 9a577af49dfc360587a4e45195a6a26b75eab083 Mon Sep 17 00:00:00 2001 From: Rik Huijzer <git...@huijzer.xyz> Date: Sat, 2 Dec 2023 18:06:05 +0100 Subject: [PATCH 5/5] Cleanup --- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 1 - mlir/test/Dialect/Tensor/canonicalize.mlir | 24 ---------------------- 2 files changed, 25 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index c6d947a2427db..b2d52e400e52d 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -22,7 +22,6 @@ #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" -#include <cstdint> using namespace mlir; using namespace mlir::memref; diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 77978e0896a28..84c44a09aa3dd 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1179,30 +1179,6 @@ func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>) // ----- -func.func @no_fold_extract_slice_negative_offset(%arg0: tensor<8xf32>) -> tensor<?xf32> { - %c-1 = arith.constant -1 : index - %e = tensor.extract_slice %arg0[1] [%c-1] [1] : tensor<8xf32> to tensor<?xf32> - return %e : tensor<?xf32> -} -// CHECK-LABEL: func 
@no_fold_extract_slice_negative_offset -// CHECK: tensor.extract_slice - -// ----- - -func.func @no_fold_insert_slice_cast_inserter_negative_offset() -> tensor<?xf32> { - %c = arith.constant 0 : index - %const = tensor.empty(%c) : tensor<?xf32> - %insert_val = tensor.empty(%c) : tensor<?xf32> - %c-1 = arith.constant -1 : index - %inserted = tensor.insert_slice %insert_val into %const[0][%c-1][1] : tensor<?xf32> into tensor<?xf32> - return %inserted : tensor<?xf32> -} -// CHECK-LABEL: func @no_fold_insert_slice_cast_inserter_negative_offset -// CHECK: %[[CAST:.*]] = tensor.cast -// CHECK: tensor.insert_slice %[[CAST:.+]] - -// ----- - func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> { %c0 = arith.constant dense<42> : tensor<2x8xi32> %0 = tensor.expand_shape %c0 [[0], [1, 2]] _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits