llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir-tensor Author: Rik Huijzer (rikhuijzer) <details> <summary>Changes</summary> This PR fixes https://github.com/llvm/llvm-project/issues/73383 and is another shot at the refactoring proposed in https://github.com/llvm/llvm-project/pull/72885. --- Full diff: https://github.com/llvm/llvm-project/pull/74200.diff 5 Files Affected: - (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+27-3) - (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+6-11) - (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+7-10) - (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+26-1) - (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+12) ``````````diff diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h index 502ab93ddbfa7..a1853438ccf7f 100644 --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -139,12 +139,36 @@ SmallVector<int64_t> getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values, llvm::function_ref<bool(Attribute, Attribute)> compare); +/// Helper function to check whether the passed in `sizes` or `offsets` are +/// valid. This can be used to re-check whether dimensions are still valid +/// after constant folding the dynamic dimensions. +bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets); + +/// Helper function to check whether the passed in `strides` are valid. This +/// can be used to re-check whether dimensions are still valid after constant +/// folding the dynamic dimensions. +bool hasValidStrides(SmallVector<int64_t> strides); + /// Returns "success" when any of the elements in `ofrs` is a constant value. In /// that case the value is replaced by an attribute. Returns "failure" when no -/// folding happened. If `onlyNonNegative` is set, only non-negative constant -/// values are folded. 
+/// folding happened. If `onlyNonNegative` and `onlyNonZero` are set, only +/// non-negative and non-zero constant values are folded respectively. LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs, - bool onlyNonNegative = false); + bool onlyNonNegative = false, + bool onlyNonZero = false); + +/// Returns "success" when any of the elements in `offsetsOrSizes` is a +/// constant value. In that case the value is replaced by an attribute. Returns +/// "failure" when no folding happened. Invalid values are not folded to avoid +/// canonicalization crashes. +LogicalResult +foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes); + +/// Returns "success" when any of the elements in `strides` is a constant +/// value. In that case the value is replaced by an attribute. Returns +/// "failure" when no folding happened. Invalid values are not folded to avoid +/// canonicalization crashes. +LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides); /// Return the number of iterations for a loop with a lower bound `lb`, upper /// bound `ub` and step `step`. diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index dce96cca016ff..b2d52e400e52d 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2581,17 +2581,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType, dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); - - // If one of the offsets or sizes is invalid, fail the canonicalization. - // These checks also occur in the verifier, but they are needed here - // because some dynamic dimensions may have been constant folded. 
- for (int64_t offset : staticOffsets) - if (offset < 0 && !ShapedType::isDynamic(offset)) - return {}; - for (int64_t size : staticSizes) - if (size < 0 && !ShapedType::isDynamic(size)) - return {}; - + if (!hasValidSizesOffsets(staticOffsets)) + return {}; + if (!hasValidSizesOffsets(staticSizes)) + return {}; + if (!hasValidStrides(staticStrides)) + return {}; return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, staticSizes, staticStrides); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 8970ea1c73b40..94b7b734f88fe 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1446,13 +1446,8 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> { SmallVector<int64_t> newShape; operandsAndShape(resultType, dynamicExtents, newOperands, newShape); - for (int64_t newdim : newShape) { - // This check also occurs in the verifier, but we need it here too - // since intermediate passes may have replaced some dynamic dimensions - // by constants. - if (newdim < 0 && !ShapedType::isDynamic(newdim)) - return failure(); - } + if (!hasValidSizesOffsets(newShape)) + return failure(); if (newOperands.size() == tensorFromElements.getDynamicExtents().size()) return failure(); @@ -2548,9 +2543,9 @@ class InsertSliceOpConstantArgumentFolder final SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides()); // No constant operands were folded, just return; - if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) && - failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) && - failed(foldDynamicIndexList(mixedStrides))) + if (failed(foldDynamicOffsetSizeList(mixedOffsets)) && + failed(foldDynamicOffsetSizeList(mixedSizes)) && + failed(foldDynamicStrideList(mixedStrides))) return failure(); // Create the new op in canonical form. 
@@ -2691,6 +2686,8 @@ struct InsertSliceOpSourceCastInserter final newSrcShape[i] = *constInt; } } + if (!hasValidSizesOffsets(newSrcShape)) + return failure(); RankedTensorType newSrcType = RankedTensorType::get(newSrcShape, srcType.getElementType()); diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index c7a3d8fc8eb28..0c8a88da789e2 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -256,8 +256,20 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub, return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant); } +bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) { + return llvm::none_of(sizesOrOffsets, [](int64_t value) { + return !ShapedType::isDynamic(value) && value < 0; + }); +} + +bool hasValidStrides(SmallVector<int64_t> strides) { + return llvm::none_of(strides, [](int64_t value) { + return !ShapedType::isDynamic(value) && value == 0; + }); +} + LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs, - bool onlyNonNegative) { + bool onlyNonNegative, bool onlyNonZero) { bool valuesChanged = false; for (OpFoldResult &ofr : ofrs) { if (ofr.is<Attribute>()) @@ -267,6 +279,8 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs, // Note: All ofrs have index type. 
if (onlyNonNegative && *getConstantIntValue(attr) < 0) continue; + if (onlyNonZero && *getConstantIntValue(attr) == 0) + continue; ofr = attr; valuesChanged = true; } @@ -274,4 +288,15 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs, return success(valuesChanged); } +LogicalResult +foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes) { + return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true, + /*onlyNonZero=*/false); +} + +LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides) { + return foldDynamicIndexList(strides, /*onlyNonNegative=*/false, + /*onlyNonZero=*/true); +} + } // namespace mlir diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index a1f8673638ff8..d3406c630f6dd 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -191,6 +191,18 @@ func.func @no_fold_subview_negative_size(%input: memref<4x1024xf32>) -> memref<? // ----- +// CHECK-LABEL: func @no_fold_subview_zero_stride +// CHECK: %[[SUBVIEW:.+]] = memref.subview +// CHECK: return %[[SUBVIEW]] +func.func @no_fold_subview_zero_stride(%arg0 : memref<10xf32>) -> memref<1xf32, strided<[?], offset: 1>> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %1 = memref.subview %arg0[1] [1] [%c0] : memref<10xf32> to memref<1xf32, strided<[?], offset: 1>> + return %1 : memref<1xf32, strided<[?], offset: 1>> +} + +// ----- + // CHECK-LABEL: func @no_fold_of_store // CHECK: %[[cst:.+]] = memref.cast %arg // CHECK: memref.store %[[cst]] `````````` </details> https://github.com/llvm/llvm-project/pull/74200 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits