https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/65441
>From 7b71da55fca8fe2a7dbe4982b1959be6a6175fa1 Mon Sep 17 00:00:00 2001 From: Guray Ozen <guray.o...@gmail.com> Date: Thu, 7 Sep 2023 11:52:38 +0200 Subject: [PATCH 1/6] [MLIR][NVGPU] Introduce `nvgpu.warpgroup.mma.store` Op for Hopper GPUs This work introduces a new operation called `warpgroup.mma.store` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate storing the fragmented results of WGMMA to a given memref. An example of fragmentation is given here: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d The `warpgroup.mma.store` op does the following: 1) Takes one or more fragmented result matrices. 2) Calculates per-thread indexes within the warp group and stores the data into the given memref. Here's an example usage of the `nvgpu.warpgroup.mma.store` operation: ``` // Performs matmul, results are fragmented and in registers %res1, %res2 = nvgpu.warpgroup.mma ... // Stores the fragmented results to the given memref nvgpu.warpgroup.mma.store [%res1, %res2], %matrixD : !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>, !nvgpu.warpgroup.result<tensor = !llvm.struct<...>> to memref<128x128xf32,3> ``` Depends on #65440 --- mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td | 19 +++++ .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 83 ++++++++++++++++++- mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 29 +++++++ 3 files changed, 129 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td index 90381648dac6acc..e102ae0dc581013 100644 --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td @@ -721,4 +721,23 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> { let hasVerifier = 1; } +def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> { + let description = [{ + The `nvgpu.warpgroup.mma.store` op performs the store of the fragmented + results in $matrixD to the given memref. + + [See the details of register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d) + + Note that the op must be run by a whole warp group (128 threads).
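+
+     A usage sketch is shown below; it mirrors the example in the commit
+     message above, with the payload of the `!llvm.struct` elided in the
+     same way:
+
+     ```
+     nvgpu.warpgroup.mma.store [%res1, %res2], %matrixD :
+       !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>,
+       !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>
+       to memref<128x128xf32,3>
+     ```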
+ }]; + + let arguments = (ins Variadic<NVGPU_WarpgroupResult>:$matrixD, + Arg<AnyMemRef, "", [MemWrite]>:$dstMemref); + + let assemblyFormat = [{ + `[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref) + }]; + let hasVerifier = 1; +} + #endif // NVGPU diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index f74aa05c0c4c4ff..4f1a0bc651e81b7 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -11,6 +11,7 @@ #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" @@ -409,8 +410,8 @@ struct ConvertNVGPUToNVVMPass using Base::Base; void getDependentDialects(DialectRegistry &registry) const override { - registry - .insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect>(); + registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect, + arith::ArithDialect>(); } void runOnOperation() override { @@ -451,6 +452,7 @@ struct ConvertNVGPUToNVVMPass populateNVGPUToNVVMConversionPatterns(converter, patterns); LLVMConversionTarget target(getContext()); target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); + target.addLegalDialect<::mlir::arith::ArithDialect>(); target.addLegalDialect<::mlir::memref::MemRefDialect>(); target.addLegalDialect<::mlir::NVVM::NVVMDialect>(); mlir::scf::populateSCFStructuralTypeConversionsAndLegality( @@ -1299,11 +1301,88 @@ struct NVGPUWarpgroupMmaOpLowering } }; +struct NVGPUWarpgroupMmaStoreOpLowering + : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> { + using ConvertOpToLLVMPattern< + nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern; + + void storeFragmentedMatrix(Value wgmmaResult, nvgpu::WarpgroupMmaStoreOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + int offset) const { + Location loc = op->getLoc(); + Type i32 = rewriter.getI32Type(); + + auto makeConst = [&](int32_t index) -> Value { + return rewriter.create<LLVM::ConstantOp>( + loc, i32, rewriter.getI32IntegerAttr(index)); + }; + Value c4 = makeConst(4); + Value c32 = makeConst(kWarpSize); + Value c8 = makeConst(8); + Value c2 = makeConst(2); + Value c1 = makeConst(1); + Value c16 = makeConst(16); + + auto makeMul = [&](Value lhs, Value rhs) -> Value { + return rewriter.create<LLVM::MulOp>(loc, lhs.getType(), lhs, rhs); + }; + auto makeAdd = [&](Value lhs, Value rhs) -> Value { + return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs); + }; + + Value tidx = rewriter.create<NVVM::ThreadIdXOp>(loc, i32); + Value laneId = rewriter.create<LLVM::URemOp>(loc, i32, tidx, c32); + Value warpId = rewriter.create<LLVM::UDivOp>(loc, i32, tidx, c32); + Value lane4Id = rewriter.create<LLVM::UDivOp>(loc, i32, laneId, c4); + Value lane4modId = rewriter.create<LLVM::URemOp>(loc, i32, laneId, c4); + + auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y, + TypedValue<::mlir::MemRefType> memref) { + Type it = rewriter.getIndexType(); + Value idx = rewriter.create<arith::IndexCastOp>(loc, it, x); + Value idy0 = rewriter.create<arith::IndexCastOp>(loc, it, y); + Value idy1 = rewriter.create<arith::IndexCastOp>(loc, it, makeAdd(y, c1)); + Value d0 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i); + Value d1 =
rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i + 1); + rewriter.create<memref::StoreOp>(loc, d0, memref, ValueRange{idx, idy0}); + rewriter.create<memref::StoreOp>(loc, d1, memref, ValueRange{idx, idy1}); + }; + + Value tj = makeMul(lane4modId, c2); + Value ti = makeAdd(lane4Id, makeMul(warpId, c16)); + if (offset) + ti = makeAdd(ti, makeConst(offset)); + for (int i = 0; i < 2; ++i) { + Value idx = makeAdd(ti, makeMul(makeConst(i), c8)); + for (int j = 0; j < 16; ++j) { + Value idy = makeAdd(tj, makeMul(makeConst(j), c8)); + int sIndex = i * 2 + j * 4; + makeExtractAndStore(sIndex, wgmmaResult, idx, idy, op.getDstMemref()); + } + } + } + + LogicalResult + matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + int offset = 0; + for (auto result : adaptor.getMatrixD()) { + auto stype = result.getType().cast<LLVM::LLVMStructType>(); + storeFragmentedMatrix(result, op, adaptor, rewriter, offset); + offset += stype.getBody().size(); + } + rewriter.eraseOp(op); + return success(); + } +}; + } // namespace void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { patterns.add< + NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index d96ed69982870b4..fc85df1654198d5 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -529,6 +530,34 @@ LogicalResult WarpgroupMmaOp::verify() { return success(); } +LogicalResult WarpgroupMmaStoreOp::verify() { + Type stype = + getMatrixD().front().getType().cast<WarpgroupResultType>().getTensor(); + + for (auto result : getMatrixD()) { + auto resultStype = result.getType() + .cast<WarpgroupResultType>() + .getTensor() + .dyn_cast<LLVM::LLVMStructType>(); + if (!resultStype) + return emitOpError() << "result is " << result.getType() + << " but must be an LLVM struct type"; + if (stype != resultStype) + return emitOpError() << "all results must be the same type"; + + // TODO: improve this limitation + if (!resultStype.getBody().front().isF32()) { + return emitOpError() << "supports only f32 results for the time being"; + } + } + + if (!llvm::all_equal(stype.cast<LLVM::LLVMStructType>().getBody())) { + return emitOpError() << "all element types must be equal"; + } + + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd dialect, type, and op definitions //===----------------------------------------------------------------------===// >From 4a1824f3e6ae955b78f7262178fa1b8e4608e3da Mon Sep 17 00:00:00 2001 From: Guray Ozen <guray.o...@gmail.com> Date: Fri, 22 Sep 2023 16:53:21 +0200 Subject: [PATCH 2/6] use new type `WarpgroupAccumulator` --- mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td | 5 +++-- mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 2 +- mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 11 +++++++---- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td index e102ae0dc581013..4e80c33aec6043d 100644 --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td @@ -726,12 +726,13 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> { The `nvgpu.warpgroup.mma.store` op performs the store of the fragmented results in $matrixD to the given memref. - [See the details of register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d) + [See the details of register fragment layout for accumulator matrix D] + (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d) Note that the op must be run by a whole warp group (128 threads). }]; - let arguments = (ins Variadic<NVGPU_WarpgroupResult>:$matrixD, + let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD, Arg<AnyMemRef, "", [MemWrite]>:$dstMemref); let assemblyFormat = [{ diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 4f1a0bc651e81b7..006ecbef2546e3e 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -1382,7 +1382,6 @@ struct NVGPUWarpgroupMmaStoreOpLowering void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { patterns.add< - NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive @@ -1394,6 +1393,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx NVGPUGenerateGmmaDescriptorLowering, // nvgpu.wgmma.generate.descriptor NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma + NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering, NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering, NVGPUMmaSparseSyncLowering>(converter); diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index fc85df1654198d5..1486bba5d3e57f6 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -531,13 +531,16 @@ LogicalResult WarpgroupMmaOp::verify() { } LogicalResult WarpgroupMmaStoreOp::verify() { - Type stype = - getMatrixD().front().getType().cast<WarpgroupResultType>().getTensor(); + Type stype = getMatrixD() + .front() + .getType() + .cast<WarpgroupAccumulatorType>() + .getFragmented(); for (auto result : getMatrixD()) { auto resultStype = result.getType() - .cast<WarpgroupResultType>() - .getTensor() + .cast<WarpgroupAccumulatorType>() + .getFragmented() .dyn_cast<LLVM::LLVMStructType>(); if (!resultStype) return emitOpError() << "result is " << result.getType() >From e60310d10c8e43669402e432cd130383cdf7a837 Mon Sep 17 00:00:00 2001 From: Guray Ozen <guray.o...@gmail.com> Date: Wed, 27 Sep 2023 09:41:45 +0200 Subject: [PATCH 3/6] add test --- .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 127 ++++++++++++++++++ 1 file changed, 127 insertions(+) diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index f011007e040ce9c..93123cecbc38f94 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++
b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -732,6 +732,133 @@ func.func @warpgroup_mma_128_128_64( return } +// CHECK-LABEL: @warpgroup_mma_store( +// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>) +func.func @warpgroup_mma_store( + %result1 : !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>, + %result2 : !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>, + %matrixD: memref<128x128xf32,3>) { +// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> +// CHECK: %[[DB:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> +// CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : +// CHECK: %[[S2:.+]] = llvm.mlir.constant(4 : i32) : i32 +// CHECK: %[[S3:.+]] = llvm.mlir.constant(32 : i32) : i32 +// CHECK: %[[S4:.+]] = llvm.mlir.constant(8 : i32) : i32 +// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32 +// CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[S7:.+]] = llvm.mlir.constant(16 : i32) : i32 + +// ### Store {d0, d1} of each thread ### + +// CHECK: %[[S8:.+]] = nvvm.read.ptx.sreg.tid.x : i32 +// CHECK: %[[S9:.+]] = llvm.urem %[[S8]], %[[S3]] : i32 +// CHECK: %[[S10:.+]] = llvm.udiv %[[S8]], %[[S3]] : i32 +// CHECK: %[[S11:.+]] = llvm.udiv %[[S9]], %[[S2]] : i32 +// CHECK: %[[S12:.+]] = llvm.urem %[[S9]], %[[S2]] : i32 +// CHECK: %[[S13:.+]] = llvm.mul %[[S12]], %[[S5]] : i32 +// CHECK: %[[S14:.+]] = llvm.mul %[[S10]], %[[S7]] : i32 +// CHECK: %[[S15:.+]] = llvm.add %[[S11]], %[[S14]] : i32 +// CHECK: %[[S16:.+]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: %[[S17:.+]] = llvm.mul %[[S16]], %[[S4]] : i32 +// CHECK: %[[S18:.+]] = llvm.add %[[S15]], %[[S17]] : i32 +// CHECK: %[[S19:.+]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: %[[S20:.+]] = llvm.mul %[[S19]], %[[S4]] : i32 +// CHECK: %[[S21:.+]] = llvm.add %[[S13]], %[[S20]] : i32 +// CHECK: %[[S22:.+]] = arith.index_cast %[[S18]] : i32 to index +// CHECK: %[[S23:.+]] = arith.index_cast %[[S21]] : i32 to index +// CHECK: %[[S24:.+]] = llvm.add %[[S21]], %[[S6]] : i32 +// CHECK: %[[S25:.+]] = arith.index_cast %[[S24]] : i32 to index +// CHECK: %[[S26:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct +// CHECK: %[[S27:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct +// CHECK: memref.store %[[S26]], %[[arg2]][%[[S22]], %[[S23]]] : memref<128x128xf32, 3> +// CHECK: memref.store %[[S27]], %[[arg2]][%[[S22]], %[[S25]]] : memref<128x128xf32, 3> + +// ### Store {d2, d3} of each thread ### + +// CHECK: %[[S28:.+]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[S29:.+]] = 
llvm.mul %[[S28]], %[[S4]] : i32 +// CHECK: %[[S30:.+]] = llvm.add %[[S13]], %[[S29]] : i32 +// CHECK: %[[S31:.+]] = arith.index_cast %[[S18]] : i32 to index +// CHECK: %[[S32:.+]] = arith.index_cast %[[S30]] : i32 to index +// CHECK: %[[S33:.+]] = llvm.add %[[S30]], %[[S6]] : i32 +// CHECK: %[[S34:.+]] = arith.index_cast %[[S33]] : i32 to index +// CHECK: %[[S35:.+]] = llvm.extractvalue %[[S0]][4] : !llvm.struct< +// CHECK: %[[S36:.+]] = llvm.extractvalue %[[S0]][5] : !llvm.struct< +// CHECK: memref.store %[[S35]], %[[arg2]][%[[S31]], %[[S32]]] : memref<128x128xf32, 3> +// CHECK: memref.store %[[S36]], %[[arg2]][%[[S31]], %[[S34]]] : memref<128x128xf32, 3> + +// ### Store {d4, d5} of each thread ### + +// CHECK: %[[S37:.+]] = llvm.mlir.constant(2 : i32) : i32 +// CHECK: %[[S38:.+]] = llvm.mul %[[S37]], %[[S4]] : i32 +// CHECK: %[[S39:.+]] = llvm.add %[[S13]], %[[S38]] : i32 +// CHECK: %[[S40:.+]] = arith.index_cast %[[S18]] : i32 to index +// CHECK: %[[S41:.+]] = arith.index_cast %[[S39]] : i32 to index +// CHECK: %[[S42:.+]] = llvm.add %[[S39]], %[[S6]] : i32 +// CHECK: %[[S43:.+]] = arith.index_cast %[[S42]] : i32 to index +// CHECK: %[[S44:.+]] = llvm.extractvalue %[[S0]][8] : !llvm.struct< +// CHECK: %[[S45:.+]] = llvm.extractvalue %[[S0]][9] : !llvm.struct< +// CHECK: memref.store %[[S44]], %[[arg2]][%[[S40]], %[[S41]]] : memref<128x128xf32, 3> +// CHECK: memref.store %[[S45]], %[[arg2]][%[[S40]], %[[S43]]] : memref<128x128xf32, 3> + +// ### Store {d6, d7} of each thread ### + +// CHECK: %[[S46:.+]] = llvm.mlir.constant(3 : i32) : i32 +// CHECK: %[[S47:.+]] = llvm.mul %[[S46]], %[[S4]] : i32 +// CHECK: %[[S48:.+]] = llvm.add %[[S13]], %[[S47]] : i32 +// CHECK: %[[S49:.+]] = arith.index_cast %[[S18]] : i32 to index +// CHECK: %[[S50:.+]] = arith.index_cast %[[S48]] : i32 to index +// CHECK: %[[S51:.+]] = llvm.add %[[S48]], %[[S6]] : i32 +// CHECK: %[[S52:.+]] = arith.index_cast %[[S51]] : i32 to index +// CHECK: %[[S53:.+]] = llvm.extractvalue %[[S0]][12] : !llvm.struct< +// CHECK: %[[S54:.+]] = llvm.extractvalue %[[S0]][13] : !llvm.struct< +// CHECK: memref.store %[[S53]], %[[arg2]][%[[S49]], %[[S50]]] : memref<128x128xf32, 3> +// CHECK: memref.store %[[S54]], %[[arg2]][%[[S49]], %[[S52]]] : memref<128x128xf32, 3> + +// Pattern continues similarly 28x times until {... 
d62, d63} + +// ### Store {d64, d65} of each thread ### + +// CHECK: %[[S311:.+]] = llvm.mlir.constant(4 : i32) : i32 +// CHECK: %[[S312:.+]] = llvm.mlir.constant(32 : i32) : i32 +// CHECK: %[[S313:.+]] = llvm.mlir.constant(8 : i32) : i32 +// CHECK: %[[S314:.+]] = llvm.mlir.constant(2 : i32) : i32 +// CHECK: %[[S315:.+]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[S316:.+]] = llvm.mlir.constant(16 : i32) : i32 +// CHECK: %[[S317:.+]] = nvvm.read.ptx.sreg.tid.x : i32 +// CHECK: %[[S318:.+]] = llvm.urem %[[S317]], %[[S312]] : i32 +// CHECK: %[[S319:.+]] = llvm.udiv %[[S317]], %[[S312]] : i32 +// CHECK: %[[S320:.+]] = llvm.udiv %[[S318]] +// CHECK: %[[S321:.+]] = llvm.urem %[[S318]] +// CHECK: %[[S322:.+]] = llvm.mul %[[S321]], %[[S314]] : i32 +// CHECK: %[[S323:.+]] = llvm.mul %[[S319]], %[[S316]] : i32 +// CHECK: %[[S324:.+]] = llvm.add %[[S320]], %[[S323]] : i32 +// CHECK: %[[S325:.+]] = llvm.mlir.constant(64 : i32) : i32 +// CHECK: %[[S326:.+]] = llvm.add %[[S324]], %[[S325]] : i32 +// CHECK: %[[S327:.+]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: %[[S328:.+]] = llvm.mul %[[S327]], %[[S313]] : i32 +// CHECK: %[[S329:.+]] = llvm.add %[[S326]], %[[S328]] : i32 +// CHECK: %[[S330:.+]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: %[[S331:.+]] = llvm.mul %[[S330]], %[[S313]] : i32 +// CHECK: %[[S332:.+]] = llvm.add %[[S322]], %[[S331]] : i32 +// CHECK: %[[S333:.+]] = arith.index_cast %[[S329]] : i32 to index +// CHECK: %[[S334:.+]] = arith.index_cast %[[S332]] : i32 to index +// CHECK: %[[S335:.+]] = llvm.add %[[S332]], %[[S315]] : i32 +// CHECK: %[[S336:.+]] = arith.index_cast %[[S335]] : i32 to index +// CHECK: %[[S337:.+]] = llvm.extractvalue %[[S1]][0] +// CHECK: %[[S338:.+]] = llvm.extractvalue %[[S1]][1] +// CHECK: memref.store %[[S337]], %[[arg2]][%[[S333]], %[[S334]]] : memref<128x128xf32, 3> +// CHECK: memref.store %[[S338]], %[[arg2]][%[[S333]], %[[S336]]] : memref<128x128xf32, 3> + +// Pattern continues similarly 31x times until {... 
d126, d127} + + nvgpu.warpgroup.mma.store [%result1, %result2], %matrixD : + !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>, + !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>> + to memref<128x128xf32,3> + return +} + transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 >From fe03a52c573c287efba3e9c77837a5d91a1e3ad1 Mon Sep 17 00:00:00 2001 From: Guray Ozen <guray.o...@gmail.com> Date: Wed, 27 Sep 2023 09:41:53 +0200 Subject: [PATCH 4/6] better verification --- mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 48 +++++++++++----------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index 1486bba5d3e57f6..b9994aced0be7f4 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -531,33 +531,35 @@ LogicalResult WarpgroupMmaOp::verify() { } LogicalResult WarpgroupMmaStoreOp::verify() { - Type stype = getMatrixD() - .front() - .getType() - .cast<WarpgroupAccumulatorType>() - .getFragmented(); - + MemRefType dstMemrefType = getDstMemref().getType(); + VectorType firstVtype = getMatrixD() + .front() + .getType() + .cast<WarpgroupAccumulatorType>() + .getFragmented(); + + int64_t totalFirstDimension = 0; for (auto result : getMatrixD()) { - auto resultStype = result.getType() - .cast<WarpgroupAccumulatorType>() - .getFragmented() - .dyn_cast<LLVM::LLVMStructType>(); - if (!resultStype) - return emitOpError() << "result is " << result.getType() - << " but must be an LLVM struct type"; - if (stype != resultStype) - return emitOpError() << "all results must be the same type"; - - // TODO: improve this limitation - if (!resultStype.getBody().front().isF32()) { - return emitOpError() << "supports only f32 results for the time being"; + VectorType vtype = + result.getType().cast<WarpgroupAccumulatorType>().getFragmented(); + if (vtype != firstVtype) + return emitOpError() << "all fragmented types must be the same"; + // Limitation + if (!vtype.getElementType().isF32()) { + return emitOpError() + << "hit a limitation: only f32 results for the time being"; } + totalFirstDimension += vtype.getDimSize(0); } - - if (!llvm::all_equal(stype.cast<LLVM::LLVMStructType>().getBody())) { - return emitOpError() << "all element types must be equal"; + if (totalFirstDimension != dstMemrefType.getDimSize(0) || + firstVtype.getDimSize(1) != dstMemrefType.getDimSize(1)) { + return emitOpError() << "results hold [" << totalFirstDimension << "][" + << firstVtype.getDimSize(1) + << "] values, but the destination memref[" + << dstMemrefType.getDimSize(0) << "][" + << dstMemrefType.getDimSize(1) + << "] does not have the same size"; } - return success(); } >From 2ba9de5e479a6a680b4cacf2b8a90d8da87115c1 Mon Sep 17 00:00:00 2001 From: Guray Ozen <guray.o...@gmail.com> Date: Wed, 27 Sep 2023 09:52:01 +0200 Subject: [PATCH 5/6] fix test --- mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index 93123cecbc38f94..cd8222c1c0ce585 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -739,8 +739,7 @@ func.func @warpgroup_mma_store( %result2 : !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>, %matrixD: memref<128x128xf32,3>) { // CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -// CHECK: %[[DB:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -// CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : +// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: %[[S2:.+]] = llvm.mlir.constant(4 : i32) : i32 // CHECK: %[[S3:.+]] = llvm.mlir.constant(32 : i32) : i32 // CHECK: %[[S4:.+]] = llvm.mlir.constant(8 : i32) : i32 >From c894a472ba7f703a02e2da19fcd09a3e71e15517 Mon Sep 17 00:00:00 2001 From: Guray Ozen <guray.o...@gmail.com> Date: Mon, 2 Oct 2023 11:09:23 +0200 Subject: [PATCH 6/6] address @qcolombet comments --- .../lib/Conversion/NVGPUToNVVM/CMakeLists.txt | 1 + .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 105 +++++++++++++----- .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 28 ++--- .../llvm-project-overlay/mlir/BUILD.bazel | 1 + 4 files changed, 92 insertions(+), 43 deletions(-) diff --git a/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt index 00e775ce7dd22bc..a050749eb7da87e 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt +++ b/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt @@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRNVGPUToNVVM MLIRLLVMDialect MLIRNVGPUDialect MLIRNVVMDialect + MLIRArithDialect MLIRPass MLIRSCFTransforms MLIRTransforms diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 006ecbef2546e3e..c02026977922e81 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" @@ -52,6 +53,16 @@ static Value truncToI32(ConversionPatternRewriter &rewriter, Location loc, return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), value); } +/// Returns the warp size as a constant i32 value. +static Value getWarpSizeValue(ImplicitLocOpBuilder &b) { + static std::optional<Value> warpSize = std::nullopt; + if (!warpSize.has_value()) { + warpSize = b.create<LLVM::ConstantOp>(IntegerType::get(b.getContext(), 32), + b.getI32IntegerAttr(kWarpSize)); + } + return warpSize.value(); +} + /// Returns the type for the intrinsic given the vectorResultType of the /// `gpu.mma.sync` operation. static Type inferIntrinsicResultType(Type vectorResultType) { @@ -1306,47 +1317,80 @@ struct NVGPUWarpgroupMmaStoreOpLowering using ConvertOpToLLVMPattern< nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern; - void storeFragmentedMatrix(Value wgmmaResult, nvgpu::WarpgroupMmaStoreOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, + /// This function stores a fragmented register matrix owned by a warp group + /// (128 threads) into a memref. Each thread holds 64 registers, packed + /// into a single struct value. + /// Here is what each thread (T) holds; each `d` is a struct element, + /// numbered by its index. + /// + /// Threads in the warp group (128 threads) and what they own in matrix D: + /// 0-31 Warp-0 -> MatrixD[0:15 ][0:N] + /// 32-63 Warp-1 -> MatrixD[16:31][0:N] + /// 64-95 Warp-2 -> MatrixD[32:47][0:N] + /// 96-127 Warp-3 -> MatrixD[48:63][0:N] + /// + /// Matrix-D: + /// +______________________________________________________________________+ + /// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 | + /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY| + /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY| + /// ..| .........|.........|.........|.........|........|...........|........| + /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW| + /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW| + /// ..| .........|.........|.........|.........|........|...........|........| + /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........| + /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........| + /// ..| .........|.........|.........|.........|........|...........|........| + /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........| + /// ..| .........|.........|.........|.........|........|...........|........| + /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........| + /// ..| .........|.........|.........|.........|........|...........|........| + /// +______________________________________________________________________+ + /// + /// \param b: The ImplicitLocOpBuilder used to create the new ops. + /// \param matrixD: Result of the warp-group MMA operation (fragmented + /// matrix). Each thread holds it as a struct with 64 elements. + /// \param dstMemref: The memref where the registers will be stored.
+ /// \param offset: the offset within the memref where the registers will be + /// stored. + void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD, + TypedValue<MemRefType> dstMemref, int offset) const { - Location loc = op->getLoc(); - Type i32 = rewriter.getI32Type(); + Type i32 = b.getI32Type(); auto makeConst = [&](int32_t index) -> Value { - return rewriter.create<LLVM::ConstantOp>( - loc, i32, rewriter.getI32IntegerAttr(index)); + return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index)); }; + Value c1 = makeConst(1); + Value c2 = makeConst(2); Value c4 = makeConst(4); - Value c32 = makeConst(kWarpSize); Value c8 = makeConst(8); - Value c2 = makeConst(2); - Value c1 = makeConst(1); Value c16 = makeConst(16); + Value warpSize = getWarpSizeValue(b); auto makeMul = [&](Value lhs, Value rhs) -> Value { - return rewriter.create<LLVM::MulOp>(loc, lhs.getType(), lhs, rhs); + return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs); }; auto makeAdd = [&](Value lhs, Value rhs) -> Value { - return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs); + return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs); }; - Value tidx = rewriter.create<NVVM::ThreadIdXOp>(loc, i32); - Value laneId = rewriter.create<LLVM::URemOp>(loc, i32, tidx, c32); - Value warpId = rewriter.create<LLVM::UDivOp>(loc, i32, tidx, c32); - Value lane4Id = rewriter.create<LLVM::UDivOp>(loc, i32, laneId, c4); - Value lane4modId = rewriter.create<LLVM::URemOp>(loc, i32, laneId, c4); + Value tidx = b.create<NVVM::ThreadIdXOp>(i32); + Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize); + Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize); + Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4); + Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4); auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y, TypedValue<::mlir::MemRefType> memref) { - Type it = rewriter.getIndexType(); - Value idx = rewriter.create<arith::IndexCastOp>(loc, it, x); - Value idy0 = rewriter.create<arith::IndexCastOp>(loc, it, y); - Value idy1 = rewriter.create<arith::IndexCastOp>(loc, it, makeAdd(y, c1)); - Value d0 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i); - Value d1 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i + 1); - rewriter.create<memref::StoreOp>(loc, d0, memref, ValueRange{idx, idy0}); - rewriter.create<memref::StoreOp>(loc, d1, memref, ValueRange{idx, idy1}); + Type it = b.getIndexType(); + Value idx = b.create<arith::IndexCastOp>(it, x); + Value idy0 = b.create<arith::IndexCastOp>(it, y); + Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1)); + Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i); + Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1); + b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0}); + b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1}); }; Value tj = makeMul(lane4modId, c2); @@ -1358,7 +1402,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering for (int j = 0; j < 16; ++j) { Value idy = makeAdd(tj, makeMul(makeConst(j), c8)); int sIndex = i * 2 + j * 4; - makeExtractAndStore(sIndex, wgmmaResult, idx, idy, op.getDstMemref()); + makeExtractAndStore(sIndex, matrixD, idx, idy, dstMemref); } } } @@ -1367,10 +1411,11 @@ struct NVGPUWarpgroupMmaStoreOpLowering matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { int offset = 0; - for (auto result : adaptor.getMatrixD()) { - auto stype = 
result.getType().cast<LLVM::LLVMStructType>(); - storeFragmentedMatrix(result, op, adaptor, rewriter, offset); - offset += stype.getBody().size(); + ImplicitLocOpBuilder lb(op->getLoc(), rewriter); + for (Value matrixD : adaptor.getMatrixD()) { + auto structType = matrixD.getType().cast<LLVM::LLVMStructType>(); + storeFragmentedMatrix(lb, matrixD, op.getDstMemref(), offset); + offset += structType.getBody().size(); } rewriter.eraseOp(op); return success(); diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index cd8222c1c0ce585..7e9b2f3ed01c862 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -740,18 +740,18 @@ func.func @warpgroup_mma_store( %matrixD: memref<128x128xf32,3>) { // CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> +// CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32 // CHECK: %[[S2:.+]] = llvm.mlir.constant(4 : i32) : i32 -// CHECK: %[[S3:.+]] = llvm.mlir.constant(32 : i32) : i32 // CHECK: %[[S4:.+]] = llvm.mlir.constant(8 : i32) : i32 -// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32 -// CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32 // CHECK: %[[S7:.+]] = llvm.mlir.constant(16 : i32) : i32 +// CHECK: %[[WarpSize:.+]] = llvm.mlir.constant(32 : i32) : i32 // ### Store {d0, d1} of each thread ### // CHECK: %[[S8:.+]] = nvvm.read.ptx.sreg.tid.x : i32 -// CHECK: %[[S9:.+]] = llvm.urem %[[S8]], %[[S3]] : i32 -// CHECK: %[[S10:.+]] = llvm.udiv %[[S8]], %[[S3]] : i32 +// CHECK: %[[S9:.+]] = llvm.urem %[[S8]], %[[WarpSize]] : i32 +// CHECK: %[[S10:.+]] = llvm.udiv %[[S8]], %[[WarpSize]] : i32 // CHECK: %[[S11:.+]] = llvm.udiv %[[S9]], %[[S2]] : i32 // CHECK: %[[S12:.+]] = llvm.urem %[[S9]], %[[S2]] : i32 // CHECK: %[[S13:.+]] = llvm.mul %[[S12]], %[[S5]] : i32 @@ -816,20 +816,22 @@ func.func @warpgroup_mma_store( // Pattern continues similarly 28x times until {... 
d62, d63} +// CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32 + // ### Store {d64, d65} of each thread ### +// CHECK: %[[S315:.+]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[S312:.+]] = llvm.mlir.constant(2 : i32) : i32 // CHECK: %[[S311:.+]] = llvm.mlir.constant(4 : i32) : i32 -// CHECK: %[[S312:.+]] = llvm.mlir.constant(32 : i32) : i32 // CHECK: %[[S313:.+]] = llvm.mlir.constant(8 : i32) : i32 -// CHECK: %[[S314:.+]] = llvm.mlir.constant(2 : i32) : i32 -// CHECK: %[[S315:.+]] = llvm.mlir.constant(1 : i32) : i32 // CHECK: %[[S316:.+]] = llvm.mlir.constant(16 : i32) : i32 // CHECK: %[[S317:.+]] = nvvm.read.ptx.sreg.tid.x : i32 -// CHECK: %[[S318:.+]] = llvm.urem %[[S317]], %[[S312]] : i32 -// CHECK: %[[S319:.+]] = llvm.udiv %[[S317]], %[[S312]] : i32 -// CHECK: %[[S320:.+]] = llvm.udiv %[[S318]] -// CHECK: %[[S321:.+]] = llvm.urem %[[S318]] -// CHECK: %[[S322:.+]] = llvm.mul %[[S321]], %[[S314]] : i32 +// CHECK: %[[S318:.+]] = llvm.urem %[[S317]], %[[WarpSize]] : i32 +// CHECK: %[[S319:.+]] = llvm.udiv %[[S317]], %[[WarpSize]] : i32 +// CHECK: %[[S320:.+]] = llvm.udiv %[[S318]], %[[S311]] : i32 +// CHECK: %[[S321:.+]] = llvm.urem %[[S318]], %[[S311]] : i32 +// CHECK: %[[S322:.+]] = llvm.mul %[[S321]], %[[S312]] : i32 // CHECK: %[[S323:.+]] = llvm.mul %[[S319]], %[[S316]] : i32 // CHECK: %[[S324:.+]] = llvm.add %[[S320]], %[[S323]] : i32 // CHECK: %[[S325:.+]] = llvm.mlir.constant(64 : i32) : i32 diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 3c30afabfbb204b..8e9d22016526ceb 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -5237,6 +5237,7 @@ cc_library( ":LLVMCommonConversion", ":LLVMDialect", ":MemRefDialect", + ":MLIRArithDialect", ":NVGPUDialect", ":NVVMDialect", ":Pass",
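For reference, here is a sketch of the op in its final form after the `WarpgroupAccumulator` rename of patch 2; the types mirror the test above, and the operands of `nvgpu.warpgroup.mma` remain elided exactly as in the commit message:

```
// Matmul produces fragmented results held in registers (operands elided
// here, as in the commit message).
%res1, %res2 = nvgpu.warpgroup.mma ...
// Store both fragments into the destination memref.
nvgpu.warpgroup.mma.store [%res1, %res2], %matrixD :
  !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
  !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
  to memref<128x128xf32,3>
```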