llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-backend-amdgpu @llvm/pr-subscribers-llvm-ir Author: Changpeng Fang (changpeng) <details> <summary>Changes</summary> --- Patch is 132.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149684.diff 31 Files Affected: - (modified) clang/include/clang/Basic/BuiltinsAMDGPU.def (+1) - (modified) clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp (+5) - (modified) clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl (+12) - (modified) clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl (+7) - (modified) llvm/include/llvm/IR/IntrinsicsAMDGPU.td (+15) - (modified) llvm/lib/IR/Verifier.cpp (+48) - (modified) llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp (+41) - (modified) llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp (+1) - (modified) llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp (+79) - (modified) llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp (+46-1) - (modified) llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.h (+1) - (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp (+42) - (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h (+6) - (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp (+2) - (modified) llvm/lib/Target/AMDGPU/SIDefines.h (+10) - (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.td (+6) - (modified) llvm/lib/Target/AMDGPU/SIRegisterInfo.td (+1) - (modified) llvm/lib/Target/AMDGPU/SISchedule.td (+15) - (modified) llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp (+23) - (modified) llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h (+8) - (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+89-33) - (modified) llvm/lib/Target/AMDGPU/VOPInstructions.td (+2) - (modified) llvm/test/Analysis/UniformityAnalysis/AMDGPU/intrinsics.ll (+9) - (modified) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.wmma.gfx1250.w32.ll (+552) - (modified) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.wmma.imm.gfx1250.w32.ll (+105) - (modified) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.wmma.imod.gfx1250.w32.ll (+67) - (modified) llvm/test/MC/AMDGPU/gfx1250_asm_wmma_w32.s (+65) - (modified) llvm/test/MC/AMDGPU/gfx1250_asm_wmma_w32_err.s (+76) - (modified) llvm/test/MC/Disassembler/AMDGPU/gfx1250_dasm_wmma_w32.txt (+39) - (added) llvm/test/Transforms/InstCombine/AMDGPU/wmma-f8f6f4.ll (+158) - (added) llvm/test/Verifier/AMDGPU/wmma-f8f6f4.ll (+165) ``````````diff diff --git a/clang/include/clang/Basic/BuiltinsAMDGPU.def b/clang/include/clang/Basic/BuiltinsAMDGPU.def index d4fef5d46af73..878543566f0e3 100644 --- a/clang/include/clang/Basic/BuiltinsAMDGPU.def +++ b/clang/include/clang/Basic/BuiltinsAMDGPU.def @@ -705,6 +705,7 @@ TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_fp8_fp8, "V8hV16iV16iIsV8hIbI TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_fp8_bf8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32") TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_bf8_fp8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32") TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32") +TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4, "V8fIiV16iIiV16iIsV8f", "nc", "gfx1250-insts,wavefrontsize32") TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_fp8_fp8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32") TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_fp8_bf8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32") TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_bf8_fp8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32") diff --git a/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp b/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp index ee736a2816218..7dccf82b1a7a3 100644 --- a/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp +++ b/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp @@ -855,6 +855,7 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID, case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_bf8_fp8: case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_bf8_bf8: case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x64_iu8: + case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4: case AMDGPU::BI__builtin_amdgcn_wmma_f32_32x16x128_f4: case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x64_f16: case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x64_bf16: @@ -1118,6 +1119,10 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID, ArgsForMatchingMatrixTypes = {4, 1}; BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x64_iu8; break; + case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4: + ArgsForMatchingMatrixTypes = {5, 1, 3}; + BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4; + break; case AMDGPU::BI__builtin_amdgcn_wmma_f32_32x16x128_f4: ArgsForMatchingMatrixTypes = {3, 0, 1}; BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_32x16x128_f4; diff --git a/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl b/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl index e4ef3defdb341..86c27d48ab0d4 100644 --- a/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl +++ b/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl @@ -157,6 +157,18 @@ void test_amdgcn_wmma_i32_16x16x64_iu8(global v8i* out, v8i a, v8i b, v8i c) *out = __builtin_amdgcn_wmma_i32_16x16x64_iu8(0, a, 0, b, c, false, true); } +// CHECK-GFX1250-LABEL: @test_amdgcn_wmma_f32_16x16x128_f8f6f4( +// CHECK-GFX1250-NEXT: entry: +// CHECK-GFX1250-NEXT: [[TMP0:%.*]] = shufflevector <16 x i32> [[B:%.*]], <16 x i32> poison, <12 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11> +// CHECK-GFX1250-NEXT: [[TMP1:%.*]] = tail call <8 x float> @llvm.amdgcn.wmma.f32.16x16x128.f8f6f4.v8f32.v16i32.v12i32(i32 1, <16 x i32> [[A:%.*]], i32 2, <12 x i32> [[TMP0]], i16 0, <8 x float> [[C:%.*]]) +// CHECK-GFX1250-NEXT: store <8 x float> [[TMP1]], ptr addrspace(1) [[OUT:%.*]], align 32, !tbaa [[TBAA4]] +// CHECK-GFX1250-NEXT: ret void +// +void test_amdgcn_wmma_f32_16x16x128_f8f6f4(global v8f* out, v16i a, v16i b, v8f c) +{ + *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, 2, b, 0, c); +} + // CHECK-GFX1250-LABEL: @test_amdgcn_wmma_f32_16x16x32_f16( // CHECK-GFX1250-NEXT: entry: // CHECK-GFX1250-NEXT: [[TMP0:%.*]] = tail call <8 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v8f32.v16f16(i1 false, <16 x half> [[A:%.*]], i1 false, <16 x half> [[B:%.*]], i16 0, <8 x float> [[C:%.*]], i1 false, i1 true) diff --git a/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl b/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl index 55d705e6ad238..8aa7c34672783 100644 --- a/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl +++ b/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl @@ -114,6 +114,13 @@ void test_amdgcn_wmma_i32_16x16x64_iu8(global v8i* out, v8i a, v8i b, v8i c, int *out = __builtin_amdgcn_wmma_i32_16x16x64_iu8(0, a, 0, b, c, false, mod); // expected-error {{'__builtin_amdgcn_wmma_i32_16x16x64_iu8' must be a constant integer}} } +void test_amdgcn_wmma_f32_16x16x128_f8f6f4(global v8f* out, v16i a, v16i b, v8f c, int mod) +{ + *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(mod, a, 2, b, 0, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}} + *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, mod, b, 0, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}} + *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, 2, b, mod, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}} +} + void test_amdgcn_wmma_f32_16x16x32_f16(global v8f* out, v16h a, v16h b, v8f c, int mod) { *out = __builtin_amdgcn_wmma_f32_16x16x32_f16(mod, a, 0, b, 0, c, false, false); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x32_f16' must be a constant integer}} diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td index ecda6c4efefe3..8bfa34584c3a4 100644 --- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td +++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td @@ -3717,6 +3717,20 @@ class AMDGPUWmmaIntrinsicModsAllDiff<LLVMType DstTy, LLVMType AB, LLVMType C> : IntrWillReturn, IntrNoCallback, IntrNoFree] >; +class AMDGPUWmmaIntrinsicModsC_MatrixFMT : + Intrinsic< + [llvm_anyfloat_ty], // %D + [ + llvm_i32_ty, // matrix_a_fmt + llvm_anyint_ty, // %A + llvm_i32_ty, // matrix_b_fmt + llvm_anyint_ty, // %B + llvm_i16_ty, // %C_mod: 0 - none, 1 - neg, 2 - abs, 3 - neg(abs) + LLVMMatchType<0>, // %C + ], + [IntrNoMem, IntrConvergent, ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<4>>, IntrWillReturn, IntrNoCallback, IntrNoFree] +>; + defset list<Intrinsic> AMDGPUWMMAIntrinsicsGFX1250 = { def int_amdgcn_wmma_f32_16x16x4_f32 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>; def int_amdgcn_wmma_f32_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>; @@ -3741,6 +3755,7 @@ def int_amdgcn_wmma_f32_16x16x128_fp8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint def int_amdgcn_wmma_f32_16x16x128_bf8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>; def int_amdgcn_wmma_f32_16x16x128_bf8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>; def int_amdgcn_wmma_i32_16x16x64_iu8 : AMDGPUWmmaIntrinsicModsAB<llvm_anyint_ty, llvm_anyint_ty>; +def int_amdgcn_wmma_f32_16x16x128_f8f6f4 : AMDGPUWmmaIntrinsicModsC_MatrixFMT; def int_amdgcn_wmma_f32_32x16x128_f4 : AMDGPUWmmaIntrinsicF4ModsC<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty>; } diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp index 8c8ed3c5e47ba..40d20ee92e4af 100644 --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -6627,6 +6627,54 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) { "invalid vector type for format", &Call, Src1, Call.getArgOperand(5)); break; } + case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: { + Value *Src0 = Call.getArgOperand(1); + Value *Src1 = Call.getArgOperand(3); + + unsigned FmtA = cast<ConstantInt>(Call.getArgOperand(0))->getZExtValue(); + unsigned FmtB = cast<ConstantInt>(Call.getArgOperand(2))->getZExtValue(); + Check(FmtA <= 4, "invalid value for matrix format", Call, + Call.getArgOperand(0)); + Check(FmtB <= 4, "invalid value for matrix format", Call, + Call.getArgOperand(2)); + + // AMDGPU::MatrixFMT values + auto getFormatNumRegs = [](unsigned FormatVal) { + switch (FormatVal) { + case 0: + case 1: + return 16u; + case 2: + case 3: + return 12u; + case 4: + return 8u; + default: + llvm_unreachable("invalid format value"); + } + }; + + auto isValidSrcASrcBVector = [](FixedVectorType *Ty) { + if (!Ty || !Ty->getElementType()->isIntegerTy(32)) + return false; + unsigned NumElts = Ty->getNumElements(); + return NumElts == 16 || NumElts == 12 || NumElts == 8; + }; + + auto *Src0Ty = dyn_cast<FixedVectorType>(Src0->getType()); + auto *Src1Ty = dyn_cast<FixedVectorType>(Src1->getType()); + Check(isValidSrcASrcBVector(Src0Ty), + "operand 1 must be 8, 12 or 16 element i32 vector", &Call, Src0); + Check(isValidSrcASrcBVector(Src1Ty), + "operand 3 must be 8, 12 or 16 element i32 vector", &Call, Src1); + + // Permit excess registers for the format. + Check(Src0Ty->getNumElements() >= getFormatNumRegs(FmtA), + "invalid vector type for format", &Call, Src0, Call.getArgOperand(0)); + Check(Src1Ty->getNumElements() >= getFormatNumRegs(FmtB), + "invalid vector type for format", &Call, Src1, Call.getArgOperand(2)); + break; + } case Intrinsic::nvvm_setmaxnreg_inc_sync_aligned_u32: case Intrinsic::nvvm_setmaxnreg_dec_sync_aligned_u32: { Value *V = Call.getArgOperand(0); diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp index e2c2e8912c715..f2207ff4cb1c4 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp @@ -1694,6 +1694,47 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const { NewII->takeName(&II); return IC.replaceInstUsesWith(II, NewII); } + case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: { + Value *Src0 = II.getArgOperand(1); + Value *Src1 = II.getArgOperand(3); + unsigned FmtA = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue(); + uint64_t FmtB = cast<ConstantInt>(II.getArgOperand(2))->getZExtValue(); + auto *Src0Ty = cast<FixedVectorType>(Src0->getType()); + auto *Src1Ty = cast<FixedVectorType>(Src1->getType()); + + bool MadeChange = false; + unsigned Src0NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtA); + unsigned Src1NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtB); + + // Depending on the used format, fewer registers are required so shrink the + // vector type. + if (Src0Ty->getNumElements() > Src0NumElts) { + Src0 = IC.Builder.CreateExtractVector( + FixedVectorType::get(Src0Ty->getElementType(), Src0NumElts), Src0, + IC.Builder.getInt64(0)); + MadeChange = true; + } + + if (Src1Ty->getNumElements() > Src1NumElts) { + Src1 = IC.Builder.CreateExtractVector( + FixedVectorType::get(Src1Ty->getElementType(), Src1NumElts), Src1, + IC.Builder.getInt64(0)); + MadeChange = true; + } + + if (!MadeChange) + return std::nullopt; + + SmallVector<Value *, 13> Args(II.args()); + Args[1] = Src0; + Args[3] = Src1; + + CallInst *NewII = IC.Builder.CreateIntrinsic( + IID, {II.getArgOperand(5)->getType(), Src0->getType(), Src1->getType()}, + Args, &II); + NewII->takeName(&II); + return IC.replaceInstUsesWith(II, NewII); + } } if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr = AMDGPU::getImageDimIntrinsicInfo(II.getIntrinsicID())) { diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp index bf2f37bddb9ed..8ef4745bd0b6b 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp @@ -4714,6 +4714,7 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_fp8: case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_bf8: case Intrinsic::amdgcn_wmma_i32_16x16x64_iu8: + case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: case Intrinsic::amdgcn_wmma_f32_32x16x128_f4: case Intrinsic::amdgcn_swmmac_f16_16x16x64_f16: case Intrinsic::amdgcn_swmmac_bf16_16x16x64_bf16: diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp index 43d4e8db791b0..b83d88b55e72a 100644 --- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp +++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp @@ -176,6 +176,8 @@ class AMDGPUOperand : public MCParsedAsmOperand { ImmTyWaitVAVDst, ImmTyWaitVMVSrc, ImmTyBitOp3, + ImmTyMatrixAFMT, + ImmTyMatrixBFMT, ImmTyMatrixAReuse, ImmTyMatrixBReuse, ImmTyByteSel, @@ -423,6 +425,8 @@ class AMDGPUOperand : public MCParsedAsmOperand { bool isIndexKey8bit() const { return isImmTy(ImmTyIndexKey8bit); } bool isIndexKey16bit() const { return isImmTy(ImmTyIndexKey16bit); } bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); } + bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); } + bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); } bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); } bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); } bool isTFE() const { return isImmTy(ImmTyTFE); } @@ -1174,6 +1178,8 @@ class AMDGPUOperand : public MCParsedAsmOperand { case ImmTyWaitVAVDst: OS << "WaitVAVDst"; break; case ImmTyWaitVMVSrc: OS << "WaitVMVSrc"; break; case ImmTyBitOp3: OS << "BitOp3"; break; + case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break; + case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break; case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break; case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break; case ImmTyByteSel: OS << "ByteSel" ; break; @@ -1714,6 +1720,10 @@ class AMDGPUAsmParser : public MCTargetAsmParser { ParseStatus parseIndexKey8bit(OperandVector &Operands); ParseStatus parseIndexKey16bit(OperandVector &Operands); ParseStatus parseIndexKey32bit(OperandVector &Operands); + ParseStatus tryParseMatrixFMT(OperandVector &Operands, StringRef Name, + AMDGPUOperand::ImmTy Type); + ParseStatus parseMatrixAFMT(OperandVector &Operands); + ParseStatus parseMatrixBFMT(OperandVector &Operands); ParseStatus parseDfmtNfmt(int64_t &Format); ParseStatus parseUfmt(int64_t &Format); @@ -1849,6 +1859,7 @@ class AMDGPUAsmParser : public MCTargetAsmParser { const unsigned CPol); bool validateTFE(const MCInst &Inst, const OperandVector &Operands); std::optional<StringRef> validateLdsDirect(const MCInst &Inst); + bool validateWMMA(const MCInst &Inst, const OperandVector &Operands); unsigned getConstantBusLimit(unsigned Opcode) const; bool usesConstantBus(const MCInst &Inst, unsigned OpIdx); bool isInlineConstant(const MCInst &Inst, unsigned OpIdx) const; @@ -5400,6 +5411,37 @@ bool AMDGPUAsmParser::validateTFE(const MCInst &Inst, return true; } +bool AMDGPUAsmParser::validateWMMA(const MCInst &Inst, + const OperandVector &Operands) { + unsigned Opc = Inst.getOpcode(); + const MCRegisterInfo *TRI = getContext().getRegisterInfo(); + const MCInstrDesc &Desc = MII.get(Opc); + + auto validateFmt = [&](AMDGPU::OpName FmtOp, AMDGPU::OpName SrcOp) -> bool { + int FmtIdx = AMDGPU::getNamedOperandIdx(Opc, FmtOp); + if (FmtIdx == -1) + return true; + unsigned Fmt = Inst.getOperand(FmtIdx).getImm(); + int SrcIdx = AMDGPU::getNamedOperandIdx(Opc, SrcOp); + unsigned RegSize = + TRI->getRegClass(Desc.operands()[SrcIdx].RegClass).getSizeInBits(); + + if (RegSize == AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(Fmt) * 32) + return true; + + static const char *FmtNames[] = {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8", + "MATRIX_FMT_FP6", "MATRIX_FMT_BF6", + "MATRIX_FMT_FP4"}; + + Error(getRegLoc(mc2PseudoReg(Inst.getOperand(SrcIdx).getReg()), Operands), + "wrong register tuple size for " + Twine(FmtNames[Fmt])); + return false; + }; + + return validateFmt(AMDGPU::OpName::matrix_a_fmt, AMDGPU::OpName::src0) && + validateFmt(AMDGPU::OpName::matrix_b_fmt, AMDGPU::OpName::src1); +} + bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst, const SMLoc &IDLoc, const OperandVector &Operands) { @@ -5533,6 +5575,9 @@ bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst, if (!validateTFE(Inst, Operands)) { return false; } + if (!validateWMMA(Inst, Operands)) { + return false; + } return true; } @@ -7191,6 +7236,26 @@ ParseStatus AMDGPUAsmParser::parseIndexKey32bit(OperandVector &Operands) { return tryParseIndexKey(Operands, AMDGPUOperand::ImmTyIndexKey32bit); } +ParseStatus AMDGPUAsmParser::tryParseMatrixFMT(OperandVector &Operands, + StringRef Name, + AMDGPUOperand::ImmTy Type) { + return parseStringOrIntWithPrefix(Operands, Name, + {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8", + "MATRIX_FMT_FP6", "MATRIX_FMT_BF6", + "MATRIX_FMT_FP4"}, + Type); +} + +ParseStatus AMDGPUAsmParser::parseMatrixAFMT(OperandVector &Operands) { + return tryParseMatrixFMT(Operands, "matrix_a_fmt", + AMDGPUOperand::ImmTyMatrixAFMT); +} + +ParseStatus AMDGPUAsmParser::parseMatrixBFMT(OperandVector &Operands) { + return tryParseMatrixFMT(Operands, "matrix_b_fmt", + AMDGPUOperand::ImmTyMatrixBFMT); +} + // dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their // values to live in a joint format operand in the MCInst encoding. ParseStatus AMDGPUAsmParser::parseDfmtNfmt(int64_t &Format) { @@ -9292,6 +9357,20 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands, DefaultVal); } + int MatrixAFMTIdx = + AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_fmt); + if (MatrixAFMTIdx != -1) { + addOptionalImmOperand(Inst, Operands, OptIdx, + AMDGPUOperand::ImmTyMatrixAFMT, 0); + } + + int MatrixBFMTIdx = + AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_fmt); + if (MatrixBFMTIdx != -1) { + addOptionalImmOperand(Inst, Operands, OptIdx, + AMDGPUOperand::ImmTyMatrixBFMT, 0); + } + if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_a_reuse)) addOptionalImmOperand(Inst, Operands, OptIdx, AMDGPUOperand::ImmTyMatrixAReuse, 0); diff --git a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp index 98f7e17e9528c..5c1989b345bdc 100644 --- a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp +++ b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp @@ -877,6 +877,9 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size, if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsMAI) convertMAIInst(MI); + if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsWMMA) + convertWMMAInst(MI); + int VDstIn_Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::vdst_in); if (VDstIn_Idx != -1) { @@ -974,10 +977,23 @@ static void adjustMFMA_F8F6F4OpRegClass(const MCRegisterInfo &MRI, return MO.setReg( MRI.getSubReg(MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5)); case 8: + if (MCRegister NewReg = MRI.getSubReg( + MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5_su... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/149684 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits