llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-llvm-ir

Author: Changpeng Fang (changpeng)

<details>
<summary>Changes</summary>



---

Patch is 132.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149684.diff


31 Files Affected:

- (modified) clang/include/clang/Basic/BuiltinsAMDGPU.def (+1) 
- (modified) clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp (+5) 
- (modified) clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl (+12) 
- (modified) clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl (+7) 
- (modified) llvm/include/llvm/IR/IntrinsicsAMDGPU.td (+15) 
- (modified) llvm/lib/IR/Verifier.cpp (+48) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp (+41) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp (+1) 
- (modified) llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp (+79) 
- (modified) llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp (+46-1) 
- (modified) llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.h (+1) 
- (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp (+42) 
- (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h (+6) 
- (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp (+2) 
- (modified) llvm/lib/Target/AMDGPU/SIDefines.h (+10) 
- (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.td (+6) 
- (modified) llvm/lib/Target/AMDGPU/SIRegisterInfo.td (+1) 
- (modified) llvm/lib/Target/AMDGPU/SISchedule.td (+15) 
- (modified) llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp (+23) 
- (modified) llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h (+8) 
- (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+89-33) 
- (modified) llvm/lib/Target/AMDGPU/VOPInstructions.td (+2) 
- (modified) llvm/test/Analysis/UniformityAnalysis/AMDGPU/intrinsics.ll (+9) 
- (modified) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.wmma.gfx1250.w32.ll (+552) 
- (modified) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.wmma.imm.gfx1250.w32.ll (+105) 
- (modified) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.wmma.imod.gfx1250.w32.ll (+67) 
- (modified) llvm/test/MC/AMDGPU/gfx1250_asm_wmma_w32.s (+65) 
- (modified) llvm/test/MC/AMDGPU/gfx1250_asm_wmma_w32_err.s (+76) 
- (modified) llvm/test/MC/Disassembler/AMDGPU/gfx1250_dasm_wmma_w32.txt (+39) 
- (added) llvm/test/Transforms/InstCombine/AMDGPU/wmma-f8f6f4.ll (+158) 
- (added) llvm/test/Verifier/AMDGPU/wmma-f8f6f4.ll (+165) 


``````````diff
diff --git a/clang/include/clang/Basic/BuiltinsAMDGPU.def b/clang/include/clang/Basic/BuiltinsAMDGPU.def
index d4fef5d46af73..878543566f0e3 100644
--- a/clang/include/clang/Basic/BuiltinsAMDGPU.def
+++ b/clang/include/clang/Basic/BuiltinsAMDGPU.def
@@ -705,6 +705,7 @@ TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_fp8_fp8, "V8hV16iV16iIsV8hIbI
 TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_fp8_bf8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
 TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_bf8_fp8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
 TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4, "V8fIiV16iIiV16iIsV8f", "nc", "gfx1250-insts,wavefrontsize32")
 TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_fp8_fp8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
 TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_fp8_bf8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
 TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_bf8_fp8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
diff --git a/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp b/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp
index ee736a2816218..7dccf82b1a7a3 100644
--- a/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp
@@ -855,6 +855,7 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
   case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_bf8_fp8:
   case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_bf8_bf8:
   case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x64_iu8:
+  case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4:
   case AMDGPU::BI__builtin_amdgcn_wmma_f32_32x16x128_f4:
   case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x64_f16:
   case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x64_bf16:
@@ -1118,6 +1119,10 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
       ArgsForMatchingMatrixTypes = {4, 1};
       BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x64_iu8;
       break;
+    case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4:
+      ArgsForMatchingMatrixTypes = {5, 1, 3};
+      BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4;
+      break;
     case AMDGPU::BI__builtin_amdgcn_wmma_f32_32x16x128_f4:
       ArgsForMatchingMatrixTypes = {3, 0, 1};
       BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_32x16x128_f4;
diff --git a/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl b/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl
index e4ef3defdb341..86c27d48ab0d4 100644
--- a/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl
+++ b/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl
@@ -157,6 +157,18 @@ void test_amdgcn_wmma_i32_16x16x64_iu8(global v8i* out, v8i a, v8i b, v8i c)
   *out = __builtin_amdgcn_wmma_i32_16x16x64_iu8(0, a, 0, b, c, false, true);
 }
 
+// CHECK-GFX1250-LABEL: @test_amdgcn_wmma_f32_16x16x128_f8f6f4(
+// CHECK-GFX1250-NEXT:  entry:
+// CHECK-GFX1250-NEXT:    [[TMP0:%.*]] = shufflevector <16 x i32> [[B:%.*]], <16 x i32> poison, <12 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11>
+// CHECK-GFX1250-NEXT:    [[TMP1:%.*]] = tail call <8 x float> @llvm.amdgcn.wmma.f32.16x16x128.f8f6f4.v8f32.v16i32.v12i32(i32 1, <16 x i32> [[A:%.*]], i32 2, <12 x i32> [[TMP0]], i16 0, <8 x float> [[C:%.*]])
+// CHECK-GFX1250-NEXT:    store <8 x float> [[TMP1]], ptr addrspace(1) [[OUT:%.*]], align 32, !tbaa [[TBAA4]]
+// CHECK-GFX1250-NEXT:    ret void
+//
+void test_amdgcn_wmma_f32_16x16x128_f8f6f4(global v8f* out, v16i a, v16i b, v8f c)
+{
+  *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, 2, b, 0, c);
+}
+
 // CHECK-GFX1250-LABEL: @test_amdgcn_wmma_f32_16x16x32_f16(
 // CHECK-GFX1250-NEXT:  entry:
 // CHECK-GFX1250-NEXT:    [[TMP0:%.*]] = tail call <8 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v8f32.v16f16(i1 false, <16 x half> [[A:%.*]], i1 false, <16 x half> [[B:%.*]], i16 0, <8 x float> [[C:%.*]], i1 false, i1 true)
diff --git a/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl b/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl
index 55d705e6ad238..8aa7c34672783 100644
--- a/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl
+++ b/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl
@@ -114,6 +114,13 @@ void test_amdgcn_wmma_i32_16x16x64_iu8(global v8i* out, v8i a, v8i b, v8i c, int
   *out = __builtin_amdgcn_wmma_i32_16x16x64_iu8(0, a, 0, b, c, false, mod); // expected-error {{'__builtin_amdgcn_wmma_i32_16x16x64_iu8' must be a constant integer}}
 }
 
+void test_amdgcn_wmma_f32_16x16x128_f8f6f4(global v8f* out, v16i a, v16i b, v8f c, int mod)
+{
+  *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(mod, a, 2, b, 0, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
+  *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, mod, b, 0, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
+  *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, 2, b, mod, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
+}
+
 void test_amdgcn_wmma_f32_16x16x32_f16(global v8f* out, v16h a, v16h b, v8f c, int mod)
 {
   *out = __builtin_amdgcn_wmma_f32_16x16x32_f16(mod, a, 0, b, 0, c, false, false); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x32_f16' must be a constant integer}}
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index ecda6c4efefe3..8bfa34584c3a4 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -3717,6 +3717,20 @@ class AMDGPUWmmaIntrinsicModsAllDiff<LLVMType DstTy, LLVMType AB, LLVMType C> :
      IntrWillReturn, IntrNoCallback, IntrNoFree]
 >;
 
+class AMDGPUWmmaIntrinsicModsC_MatrixFMT :
+  Intrinsic<
+    [llvm_anyfloat_ty], // %D
+    [
+      llvm_i32_ty,      // matrix_a_fmt
+      llvm_anyint_ty,   // %A
+      llvm_i32_ty,      // matrix_b_fmt
+      llvm_anyint_ty,   // %B
+      llvm_i16_ty,      // %C_mod: 0 - none, 1 - neg, 2 - abs, 3 - neg(abs)
+      LLVMMatchType<0>, // %C
+    ],
+    [IntrNoMem, IntrConvergent, ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<4>>, IntrWillReturn, IntrNoCallback, IntrNoFree]
+>;
+
 defset list<Intrinsic> AMDGPUWMMAIntrinsicsGFX1250 = {
 def int_amdgcn_wmma_f32_16x16x4_f32       : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
 def int_amdgcn_wmma_f32_16x16x32_bf16     : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
@@ -3741,6 +3755,7 @@ def int_amdgcn_wmma_f32_16x16x128_fp8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint
 def int_amdgcn_wmma_f32_16x16x128_bf8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
 def int_amdgcn_wmma_f32_16x16x128_bf8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
 def int_amdgcn_wmma_i32_16x16x64_iu8      : AMDGPUWmmaIntrinsicModsAB<llvm_anyint_ty, llvm_anyint_ty>;
+def int_amdgcn_wmma_f32_16x16x128_f8f6f4  : AMDGPUWmmaIntrinsicModsC_MatrixFMT;
 def int_amdgcn_wmma_f32_32x16x128_f4       : AMDGPUWmmaIntrinsicF4ModsC<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty>;
 }
 
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 8c8ed3c5e47ba..40d20ee92e4af 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6627,6 +6627,54 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
           "invalid vector type for format", &Call, Src1, Call.getArgOperand(5));
     break;
   }
+  case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: {
+    Value *Src0 = Call.getArgOperand(1);
+    Value *Src1 = Call.getArgOperand(3);
+
+    unsigned FmtA = cast<ConstantInt>(Call.getArgOperand(0))->getZExtValue();
+    unsigned FmtB = cast<ConstantInt>(Call.getArgOperand(2))->getZExtValue();
+    Check(FmtA <= 4, "invalid value for matrix format", Call,
+          Call.getArgOperand(0));
+    Check(FmtB <= 4, "invalid value for matrix format", Call,
+          Call.getArgOperand(2));
+
+    // AMDGPU::MatrixFMT values
+    auto getFormatNumRegs = [](unsigned FormatVal) {
+      switch (FormatVal) {
+      case 0:
+      case 1:
+        return 16u;
+      case 2:
+      case 3:
+        return 12u;
+      case 4:
+        return 8u;
+      default:
+        llvm_unreachable("invalid format value");
+      }
+    };
+
+    auto isValidSrcASrcBVector = [](FixedVectorType *Ty) {
+      if (!Ty || !Ty->getElementType()->isIntegerTy(32))
+        return false;
+      unsigned NumElts = Ty->getNumElements();
+      return NumElts == 16 || NumElts == 12 || NumElts == 8;
+    };
+
+    auto *Src0Ty = dyn_cast<FixedVectorType>(Src0->getType());
+    auto *Src1Ty = dyn_cast<FixedVectorType>(Src1->getType());
+    Check(isValidSrcASrcBVector(Src0Ty),
+          "operand 1 must be 8, 12 or 16 element i32 vector", &Call, Src0);
+    Check(isValidSrcASrcBVector(Src1Ty),
+          "operand 3 must be 8, 12 or 16 element i32 vector", &Call, Src1);
+
+    // Permit excess registers for the format.
+    Check(Src0Ty->getNumElements() >= getFormatNumRegs(FmtA),
+          "invalid vector type for format", &Call, Src0, Call.getArgOperand(0));
+    Check(Src1Ty->getNumElements() >= getFormatNumRegs(FmtB),
+          "invalid vector type for format", &Call, Src1, Call.getArgOperand(2));
+    break;
+  }
   case Intrinsic::nvvm_setmaxnreg_inc_sync_aligned_u32:
   case Intrinsic::nvvm_setmaxnreg_dec_sync_aligned_u32: {
     Value *V = Call.getArgOperand(0);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index e2c2e8912c715..f2207ff4cb1c4 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -1694,6 +1694,47 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
     NewII->takeName(&II);
     return IC.replaceInstUsesWith(II, NewII);
   }
+  case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: {
+    Value *Src0 = II.getArgOperand(1);
+    Value *Src1 = II.getArgOperand(3);
+    unsigned FmtA = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
+    uint64_t FmtB = cast<ConstantInt>(II.getArgOperand(2))->getZExtValue();
+    auto *Src0Ty = cast<FixedVectorType>(Src0->getType());
+    auto *Src1Ty = cast<FixedVectorType>(Src1->getType());
+
+    bool MadeChange = false;
+    unsigned Src0NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtA);
+    unsigned Src1NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtB);
+
+    // Depending on the used format, fewer registers are required so shrink the
+    // vector type.
+    if (Src0Ty->getNumElements() > Src0NumElts) {
+      Src0 = IC.Builder.CreateExtractVector(
+          FixedVectorType::get(Src0Ty->getElementType(), Src0NumElts), Src0,
+          IC.Builder.getInt64(0));
+      MadeChange = true;
+    }
+
+    if (Src1Ty->getNumElements() > Src1NumElts) {
+      Src1 = IC.Builder.CreateExtractVector(
+          FixedVectorType::get(Src1Ty->getElementType(), Src1NumElts), Src1,
+          IC.Builder.getInt64(0));
+      MadeChange = true;
+    }
+
+    if (!MadeChange)
+      return std::nullopt;
+
+    SmallVector<Value *, 13> Args(II.args());
+    Args[1] = Src0;
+    Args[3] = Src1;
+
+    CallInst *NewII = IC.Builder.CreateIntrinsic(
+        IID, {II.getArgOperand(5)->getType(), Src0->getType(), Src1->getType()},
+        Args, &II);
+    NewII->takeName(&II);
+    return IC.replaceInstUsesWith(II, NewII);
+  }
   }
   if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr =
             AMDGPU::getImageDimIntrinsicInfo(II.getIntrinsicID())) {
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index bf2f37bddb9ed..8ef4745bd0b6b 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -4714,6 +4714,7 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
     case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_fp8:
     case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_bf8:
     case Intrinsic::amdgcn_wmma_i32_16x16x64_iu8:
+    case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4:
     case Intrinsic::amdgcn_wmma_f32_32x16x128_f4:
     case Intrinsic::amdgcn_swmmac_f16_16x16x64_f16:
     case Intrinsic::amdgcn_swmmac_bf16_16x16x64_bf16:
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index 43d4e8db791b0..b83d88b55e72a 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -176,6 +176,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     ImmTyWaitVAVDst,
     ImmTyWaitVMVSrc,
     ImmTyBitOp3,
+    ImmTyMatrixAFMT,
+    ImmTyMatrixBFMT,
     ImmTyMatrixAReuse,
     ImmTyMatrixBReuse,
     ImmTyByteSel,
@@ -423,6 +425,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
   bool isIndexKey8bit() const { return isImmTy(ImmTyIndexKey8bit); }
   bool isIndexKey16bit() const { return isImmTy(ImmTyIndexKey16bit); }
   bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); }
+  bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); }
+  bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); }
   bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); }
   bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); }
   bool isTFE() const { return isImmTy(ImmTyTFE); }
@@ -1174,6 +1178,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     case ImmTyWaitVAVDst: OS << "WaitVAVDst"; break;
     case ImmTyWaitVMVSrc: OS << "WaitVMVSrc"; break;
     case ImmTyBitOp3: OS << "BitOp3"; break;
+    case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break;
+    case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break;
     case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break;
     case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break;
     case ImmTyByteSel: OS << "ByteSel" ; break;
@@ -1714,6 +1720,10 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
   ParseStatus parseIndexKey8bit(OperandVector &Operands);
   ParseStatus parseIndexKey16bit(OperandVector &Operands);
   ParseStatus parseIndexKey32bit(OperandVector &Operands);
+  ParseStatus tryParseMatrixFMT(OperandVector &Operands, StringRef Name,
+                                AMDGPUOperand::ImmTy Type);
+  ParseStatus parseMatrixAFMT(OperandVector &Operands);
+  ParseStatus parseMatrixBFMT(OperandVector &Operands);
 
   ParseStatus parseDfmtNfmt(int64_t &Format);
   ParseStatus parseUfmt(int64_t &Format);
@@ -1849,6 +1859,7 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
                               const unsigned CPol);
   bool validateTFE(const MCInst &Inst, const OperandVector &Operands);
   std::optional<StringRef> validateLdsDirect(const MCInst &Inst);
+  bool validateWMMA(const MCInst &Inst, const OperandVector &Operands);
   unsigned getConstantBusLimit(unsigned Opcode) const;
   bool usesConstantBus(const MCInst &Inst, unsigned OpIdx);
   bool isInlineConstant(const MCInst &Inst, unsigned OpIdx) const;
@@ -5400,6 +5411,37 @@ bool AMDGPUAsmParser::validateTFE(const MCInst &Inst,
   return true;
 }
 
+bool AMDGPUAsmParser::validateWMMA(const MCInst &Inst,
+                                   const OperandVector &Operands) {
+  unsigned Opc = Inst.getOpcode();
+  const MCRegisterInfo *TRI = getContext().getRegisterInfo();
+  const MCInstrDesc &Desc = MII.get(Opc);
+
+  auto validateFmt = [&](AMDGPU::OpName FmtOp, AMDGPU::OpName SrcOp) -> bool {
+    int FmtIdx = AMDGPU::getNamedOperandIdx(Opc, FmtOp);
+    if (FmtIdx == -1)
+      return true;
+    unsigned Fmt = Inst.getOperand(FmtIdx).getImm();
+    int SrcIdx = AMDGPU::getNamedOperandIdx(Opc, SrcOp);
+    unsigned RegSize =
+        TRI->getRegClass(Desc.operands()[SrcIdx].RegClass).getSizeInBits();
+
+    if (RegSize == AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(Fmt) * 32)
+      return true;
+
+    static const char *FmtNames[] = {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8",
+                                     "MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
+                                     "MATRIX_FMT_FP4"};
+
+    Error(getRegLoc(mc2PseudoReg(Inst.getOperand(SrcIdx).getReg()), Operands),
+          "wrong register tuple size for " + Twine(FmtNames[Fmt]));
+    return false;
+  };
+
+  return validateFmt(AMDGPU::OpName::matrix_a_fmt, AMDGPU::OpName::src0) &&
+         validateFmt(AMDGPU::OpName::matrix_b_fmt, AMDGPU::OpName::src1);
+}
+
 bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
                                           const SMLoc &IDLoc,
                                           const OperandVector &Operands) {
@@ -5533,6 +5575,9 @@ bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
   if (!validateTFE(Inst, Operands)) {
     return false;
   }
+  if (!validateWMMA(Inst, Operands)) {
+    return false;
+  }
 
   return true;
 }
@@ -7191,6 +7236,26 @@ ParseStatus AMDGPUAsmParser::parseIndexKey32bit(OperandVector &Operands) {
   return tryParseIndexKey(Operands, AMDGPUOperand::ImmTyIndexKey32bit);
 }
 
+ParseStatus AMDGPUAsmParser::tryParseMatrixFMT(OperandVector &Operands,
+                                               StringRef Name,
+                                               AMDGPUOperand::ImmTy Type) {
+  return parseStringOrIntWithPrefix(Operands, Name,
+                                    {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8",
+                                     "MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
+                                     "MATRIX_FMT_FP4"},
+                                    Type);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixAFMT(OperandVector &Operands) {
+  return tryParseMatrixFMT(Operands, "matrix_a_fmt",
+                           AMDGPUOperand::ImmTyMatrixAFMT);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixBFMT(OperandVector &Operands) {
+  return tryParseMatrixFMT(Operands, "matrix_b_fmt",
+                           AMDGPUOperand::ImmTyMatrixBFMT);
+}
+
 // dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their
 // values to live in a joint format operand in the MCInst encoding.
 ParseStatus AMDGPUAsmParser::parseDfmtNfmt(int64_t &Format) {
@@ -9292,6 +9357,20 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands,
                           DefaultVal);
   }
 
+  int MatrixAFMTIdx =
+      AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_fmt);
+  if (MatrixAFMTIdx != -1) {
+    addOptionalImmOperand(Inst, Operands, OptIdx,
+                          AMDGPUOperand::ImmTyMatrixAFMT, 0);
+  }
+
+  int MatrixBFMTIdx =
+      AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_fmt);
+  if (MatrixBFMTIdx != -1) {
+    addOptionalImmOperand(Inst, Operands, OptIdx,
+                          AMDGPUOperand::ImmTyMatrixBFMT, 0);
+  }
+
   if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_a_reuse))
     addOptionalImmOperand(Inst, Operands, OptIdx,
                           AMDGPUOperand::ImmTyMatrixAReuse, 0);
diff --git a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
index 98f7e17e9528c..5c1989b345bdc 100644
--- a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
+++ b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
@@ -877,6 +877,9 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
   if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsMAI)
     convertMAIInst(MI);
 
+  if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsWMMA)
+    convertWMMAInst(MI);
+
   int VDstIn_Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(),
                                               AMDGPU::OpName::vdst_in);
   if (VDstIn_Idx != -1) {
@@ -974,10 +977,23 @@ static void adjustMFMA_F8F6F4OpRegClass(const MCRegisterInfo &MRI,
     return MO.setReg(
         MRI.getSubReg(MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5));
   case 8:
+    if (MCRegister NewReg = MRI.getSubReg(
+            MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5_su...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/149684
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to