https://github.com/rikhuijzer updated https://github.com/llvm/llvm-project/pull/76292
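For readers skimming the thread: the bug fixed here shows up when `TransferOpConversion` unpacks a masked `vector.transfer_read` whose permutation map contains broadcast dimensions; as far as I can tell from the patch, the conversion then created a `memref.load` on the mask buffer with indices that were not valid for that buffer. A minimal sketch, lifted from the regression test added in patch 1/3 (the `mlir-opt` flag below is an assumption; the test file's own RUN line may differ):

  // Reproducer sketch; lower with something like: mlir-opt -convert-vector-to-scf repro.mlir
  #map1 = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
  func.func @does_not_crash_on_unpack_one_dim(%subview: memref<1x1x1x1xi32>, %mask: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
    %c0 = arith.constant 0 : index
    %c0_i32 = arith.constant 0 : i32
    %3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {permutation_map = #map1}
          : memref<1x1x1x1xi32>, vector<1x1x1x1xi32>
    return %3 : vector<1x1x1x1xi32>
  }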
>From 0ff5a0ec09f7c26824bd90e6c7656222ee2448ae Mon Sep 17 00:00:00 2001
From: Rik Huijzer <git...@huijzer.xyz>
Date: Sat, 23 Dec 2023 16:32:27 +0100
Subject: [PATCH 1/3] [mlir][vector] Fix invalid `LoadOp` indices being created

---
 .../Conversion/VectorToSCF/VectorToSCF.cpp    | 48 +++++++++++++------
 .../Conversion/VectorToSCF/vector-to-scf.mlir | 37 ++++++++++++++
 2 files changed, 71 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 2ee314e9fedfe3..13d2513a88804c 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -866,6 +866,31 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     this->setHasBoundedRewriteRecursion();
   }
 
+  static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
+                                       SmallVector<Value, 8> &loadIndices,
+                                       Value iv) {
+    assert(xferOp.getMask() && "Expected transfer op to have mask");
+
+    // Add load indices from the previous iteration.
+    // The mask buffer depends on the permutation map, which makes determining
+    // the indices quite complex, so this is why we need to "look back" to the
+    // previous iteration to find the right indices.
+    Value maskBuffer = getMaskBuffer(xferOp);
+    for (OpOperand &use : maskBuffer.getUses()) {
+      // If there is no previous load op, then the indices are empty.
+      if (auto loadOp = dyn_cast<memref::LoadOp>(use.getOwner())) {
+        Operation::operand_range prevIndices = loadOp.getIndices();
+        loadIndices.append(prevIndices.begin(), prevIndices.end());
+        break;
+      }
+    }
+
+    // In case of broadcast: Use same indices to load from memref
+    // as before.
+    if (!xferOp.isBroadcastDim(0))
+      loadIndices.push_back(iv);
+  }
+
   LogicalResult matchAndRewrite(OpTy xferOp,
                                 PatternRewriter &rewriter) const override {
     if (!xferOp->hasAttr(kPassLabel))
@@ -873,9 +898,9 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
 
     // Find and cast data buffer. How the buffer can be found depends on OpTy.
     ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
-    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
+    Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
     auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
-    auto castedDataType = unpackOneDim(dataBufferType);
+    FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
     if (failed(castedDataType))
       return failure();
 
@@ -885,8 +910,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     // If the xferOp has a mask: Find and cast mask buffer.
     Value castedMaskBuffer;
    if (xferOp.getMask()) {
-      auto maskBuffer = getMaskBuffer(xferOp);
-      auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+      Value maskBuffer = getMaskBuffer(xferOp);
       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
         // Do not unpack a dimension of the mask, if:
         // * To-be-unpacked transfer op dimension is a broadcast.
@@ -897,7 +921,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
       } else {
         // It's safe to assume the mask buffer can be unpacked if the data
         // buffer was unpacked.
-        auto castedMaskType = *unpackOneDim(maskBufferType);
+        auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+        MemRefType castedMaskType = *unpackOneDim(maskBufferType);
         castedMaskBuffer =
             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
       }
@@ -929,21 +954,16 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
 
           // If old transfer op has a mask: Set mask on new transfer op.
           // Special case: If the mask of the old transfer op is 1D and
-          //               the
-          //               unpacked dim is not a broadcast, no mask is
-          //               needed on the new transfer op.
+          // the unpacked dim is not a broadcast, no mask is needed on
+          // the new transfer op.
           if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
                                    xferOp.getMaskType().getRank() > 1)) {
             OpBuilder::InsertionGuard guard(b);
             b.setInsertionPoint(newXfer); // Insert load before newXfer.
 
             SmallVector<Value, 8> loadIndices;
-            Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
-            // In case of broadcast: Use same indices to load from memref
-            // as before.
-            if (!xferOp.isBroadcastDim(0))
-              loadIndices.push_back(iv);
-
+            getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
+                                     loadIndices, iv);
             auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                  loadIndices);
             rewriter.updateRootInPlace(newXfer, [&]() {
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index ad78f0c945b24d..8316b4005cc168 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -740,6 +740,43 @@ func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf3
 
 // -----
 
+// Check that the `TransferOpConversion` generates valid indices for the LoadOp.
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
+func.func @does_not_crash_on_unpack_one_dim(%subview: memref<1x1x1x1xi32>, %mask: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {permutation_map = #map1}
+        : memref<1x1x1x1xi32>, vector<1x1x1x1xi32>
+  return %3 : vector<1x1x1x1xi32>
+}
+// CHECK-LABEL: func.func @does_not_crash_on_unpack_one_dim
+// CHECK: %[[ALLOCA_0:.*]] = memref.alloca() : memref<vector<1x1xi1>>
+// CHECK: %[[MASK:.*]] = vector.type_cast %[[ALLOCA_0]] : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
+// CHECK: memref.load %[[MASK]][%{{.*}}] : memref<1xvector<1xi1>>
+
+// -----
+
+// Check that the `TransferOpConversion` generates valid indices for the StoreOp.
+// This test is pulled from an integration test for ArmSVE.
+
+func.func @add_arrays_of_scalable_vectors(%a: memref<1x2x?xf32>, %b: memref<1x2x?xf32>) -> vector<1x2x[4]xf32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 2 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %dim_a = memref.dim %a, %c2 : memref<1x2x?xf32>
+  %mask_a = vector.create_mask %c2, %c3, %dim_a : vector<1x2x[4]xi1>
+  %vector_a = vector.transfer_read %a[%c0, %c0, %c0], %cst, %mask_a {in_bounds = [true, true, true]} : memref<1x2x?xf32>, vector<1x2x[4]xf32>
+  return %vector_a : vector<1x2x[4]xf32>
+}
+// CHECK-LABEL: func.func @add_arrays_of_scalable_vectors
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: memref.load
+
+// -----
+
 // FULL-UNROLL-LABEL: @cannot_fully_unroll_transfer_write_of_nd_scalable_vector
 func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector<[4]x[4]xf32>, %memref: memref<?x?xf32>) {
   // FULL-UNROLL-NOT: vector.extract

>From 677a56546428c02d55d9afcf453682f1029404ff Mon Sep 17 00:00:00 2001
From: Rik Huijzer <git...@huijzer.xyz>
Date: Tue, 2 Jan 2024 18:42:04 +0100
Subject: [PATCH 2/3] Apply suggestions from code review

Co-authored-by: Mehdi Amini <joker....@gmail.com>
---
 mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 13d2513a88804c..e7abe57db984b4 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -867,7 +867,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
   }
 
   static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
-                                       SmallVector<Value, 8> &loadIndices,
+                                       SmallVectorImpl<Value> &loadIndices,
                                        Value iv) {
     assert(xferOp.getMask() && "Expected transfer op to have mask");
 
@@ -876,9 +876,9 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     // the indices quite complex, so this is why we need to "look back" to the
     // previous iteration to find the right indices.
     Value maskBuffer = getMaskBuffer(xferOp);
-    for (OpOperand &use : maskBuffer.getUses()) {
+    for (Operation *user : maskBuffer.getUsers()) {
       // If there is no previous load op, then the indices are empty.
-      if (auto loadOp = dyn_cast<memref::LoadOp>(use.getOwner())) {
+      if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
         Operation::operand_range prevIndices = loadOp.getIndices();
         loadIndices.append(prevIndices.begin(), prevIndices.end());
         break;

>From ec9d8d75077b26c2efa92063ec659ba2dd89d8b7 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <git...@huijzer.xyz>
Date: Wed, 3 Jan 2024 07:33:49 +0100
Subject: [PATCH 3/3] Use `cast` instead of `dyn_cast`

---
 mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index e7abe57db984b4..a1aff1ab36a52b 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -921,7 +921,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
       } else {
         // It's safe to assume the mask buffer can be unpacked if the data
         // buffer was unpacked.
-        auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+        auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
         MemRefType castedMaskType = *unpackOneDim(maskBufferType);
         castedMaskBuffer =
             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits