yubing updated this revision to Diff 325170.
yubing edited the summary of this revision.
yubing added a comment.

Address comments above and refactor some code


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D96110/new/

https://reviews.llvm.org/D96110

Files:
  clang/include/clang/Basic/BuiltinsX86_64.def
  clang/lib/Headers/amxintrin.h
  llvm/include/llvm/IR/IntrinsicsX86.td
  llvm/lib/Target/X86/X86ExpandPseudo.cpp
  llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
  llvm/lib/Target/X86/X86InstrAMX.td
  llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
  llvm/lib/Target/X86/X86LowerAMXType.cpp
  llvm/lib/Target/X86/X86PreTileConfig.cpp
  llvm/lib/Target/X86/X86RegisterInfo.cpp

Index: llvm/lib/Target/X86/X86RegisterInfo.cpp
===================================================================
--- llvm/lib/Target/X86/X86RegisterInfo.cpp
+++ llvm/lib/Target/X86/X86RegisterInfo.cpp
@@ -878,6 +878,7 @@
   // We only collect the tile shape that is defined.
   case X86::PTILELOADDV:
   case X86::PTDPBSSDV:
+  case X86::PTDPBF16PSV:
   case X86::PTILEZEROV:
     MachineOperand &MO1 = MI->getOperand(1);
     MachineOperand &MO2 = MI->getOperand(2);
Index: llvm/lib/Target/X86/X86PreTileConfig.cpp
===================================================================
--- llvm/lib/Target/X86/X86PreTileConfig.cpp
+++ llvm/lib/Target/X86/X86PreTileConfig.cpp
@@ -127,6 +127,7 @@
     llvm_unreachable("Unexpected machine instruction on tile");
   case X86::PTILELOADDV:
   case X86::PTDPBSSDV:
+  case X86::PTDPBF16PSV:
   case X86::PTILEZEROV:
     MachineOperand &MO1 = const_cast<MachineOperand &>(MI.getOperand(1));
     MachineOperand &MO2 = const_cast<MachineOperand &>(MI.getOperand(2));
@@ -221,6 +222,7 @@
   case X86::PTILELOADDV:
   case X86::PTILESTOREDV:
   case X86::PTDPBSSDV:
+  case X86::PTDPBF16PSV:
   case X86::PTILEZEROV:
     return true;
   }
Index: llvm/lib/Target/X86/X86LowerAMXType.cpp
===================================================================
--- llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -69,7 +69,8 @@
   }
   // a * b + c
   // The shape depends on which operand.
-  case Intrinsic::x86_tdpbssd_internal: {
+  case Intrinsic::x86_tdpbssd_internal:
+  case Intrinsic::x86_tdpbf16ps_internal: {
     switch (OpNo) {
     case 3:
       Row = II->getArgOperand(0);
Index: llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
===================================================================
--- llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
+++ llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
@@ -22,7 +22,6 @@
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/Passes.h"
-
 #include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/IR/DataLayout.h"
@@ -209,11 +208,11 @@
   B.CreateStore(Elt, EltPtr);
 }
 
-static Value *createTileDPBSSDLoops(BasicBlock *Start, BasicBlock *End,
-                                    IRBuilderBase &B, DomTreeUpdater &DTU,
-                                    LoopInfo &LI, Value *Row, Value *Col,
-                                    Value *K, Value *Acc, Value *LHS,
-                                    Value *RHS) {
+template <Intrinsic::ID IntrID>
+static Value *createTileDPLoops(BasicBlock *Start, BasicBlock *End,
+                                IRBuilderBase &B, DomTreeUpdater &DTU,
+                                LoopInfo &LI, Value *Row, Value *Col, Value *K,
+                                Value *Acc, Value *LHS, Value *RHS) {
   Loop *RowLoop = LI.AllocateLoop();
   Loop *ColLoop = LI.AllocateLoop();
   Loop *InnerLoop = LI.AllocateLoop();
@@ -321,17 +320,40 @@
       B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
   Value *IdxB =
       B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
-
-  FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
-  FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
-  Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
-  Value *EltA = B.CreateExtractElement(VecA, IdxA);
-  Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
-  Value *EltB = B.CreateExtractElement(VecB, IdxB);
-  Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
-  Value *SubVecR = B.CreateAddReduce(B.CreateMul(
-      B.CreateSExt(SubVecA, V4I32Ty), B.CreateSExt(SubVecB, V4I32Ty)));
-  Value *ResElt = B.CreateAdd(EltC, SubVecR);
+  Value *ResElt = nullptr;
+  if (IntrID == Intrinsic::x86_tdpbssd_internal) {
+    FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
+    FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
+    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
+    Value *EltA = B.CreateExtractElement(VecA, IdxA);
+    Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
+    Value *EltB = B.CreateExtractElement(VecB, IdxB);
+    Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
+    Value *SubVecR = B.CreateAddReduce(B.CreateMul(
+        B.CreateSExt(SubVecA, V4I32Ty), B.CreateSExt(SubVecB, V4I32Ty)));
+    ResElt = B.CreateAdd(EltC, SubVecR);
+  } else if (IntrID == Intrinsic::x86_tdpbf16ps_internal) {
+    FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
+    FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
+    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
+    Value *C_F32 = B.CreateBitCast(EltC, B.getFloatTy());
+    Value *EltA = B.CreateExtractElement(VecA, IdxA);
+    Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
+    Value *EltB = B.CreateExtractElement(VecB, IdxB);
+    Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
+    Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
+    int ShuffleMask[4] = {2, 0, 3, 1};
+    Value *A_V2F32 = B.CreateBitCast(
+        B.CreateShuffleVector(SubVecA, ZeroV2I16, makeArrayRef(ShuffleMask)),
+        V2F32Ty);
+    Value *B_V2F32 = B.CreateBitCast(
+        B.CreateShuffleVector(SubVecB, ZeroV2I16, makeArrayRef(ShuffleMask)),
+        V2F32Ty);
+    Value *SubVecR = B.CreateFAddReduce(C_F32, B.CreateFMul(A_V2F32, B_V2F32));
+    ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
+  } else {
+    llvm_unreachable("it is not a tdpb intrinsic");
+  }
   Value *NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
   Value *NewVecD = B.CreateInsertElement(VecDPhi, ResElt, IdxC);
 
@@ -358,20 +380,20 @@
   DominatorTree *DT;
   LoopInfo *LI;
   bool lowerTileLoad(Instruction *TileLoad);
-  bool lowerTileDPBSSD(Instruction *TileDPBSSD);
+  template <Intrinsic::ID IntrID> bool lowerTileDP(Instruction *TileDP);
   bool lowerTileStore(Instruction *TileStore);
   bool lowerTileZero(Instruction *TileZero);
 };
 
-bool X86LowerAMXIntrinsics::lowerTileDPBSSD(Instruction *TileDPBSSD) {
+template <Intrinsic::ID IntrID>
+bool X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
   Value *M, *N, *K, *C, *A, *B;
-  match(TileDPBSSD, m_Intrinsic<Intrinsic::x86_tdpbssd_internal>(
-                        m_Value(M), m_Value(N), m_Value(K), m_Value(C),
-                        m_Value(A), m_Value(B)));
+  match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
+                                    m_Value(C), m_Value(A), m_Value(B)));
   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
-  Instruction *InsertI = TileDPBSSD;
-  IRBuilder<> BuilderPrepare(TileDPBSSD);
-  BuilderPrepare.SetInsertPoint(TileDPBSSD);
+  Instruction *InsertI = TileDP;
+  IRBuilder<> BuilderPrepare(TileDP);
+  BuilderPrepare.SetInsertPoint(TileDP);
   // We visit the loop with (m, n/4, k/4):
   // %n_dword = udiv i16 %n, 4
   // %k_dword = udiv i16 %k, 4
@@ -380,17 +402,16 @@
   BasicBlock *Start = InsertI->getParent();
   BasicBlock *End =
       SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
-  IRBuilder<> Builder(TileDPBSSD);
-  Value *ResVec = createTileDPBSSDLoops(Start, End, Builder, DTU, *LI, M,
-                                        NDWord, KDWord, C, A, B);
-  // we cannot assume there always be bitcast after tiledpbssd. So we need to
+  IRBuilder<> Builder(TileDP);
+  Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, DTU, *LI, M,
+                                            NDWord, KDWord, C, A, B);
+  // We cannot assume there will always be a bitcast after TileDP, so we need to
   // insert one bitcast as required
   Builder.SetInsertPoint(End->getFirstNonPHI());
   Value *ResAMX =
       Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
-  // Delete tiledpbssd intrinsic and do some clean-up.
-  for (auto UI = TileDPBSSD->use_begin(), UE = TileDPBSSD->use_end();
-       UI != UE;) {
+  // Delete TileDP intrinsic and do some clean-up.
+  for (auto UI = TileDP->use_begin(), UE = TileDP->use_end(); UI != UE;) {
     Instruction *I = cast<Instruction>((UI++)->getUser());
     Value *Vec;
     if (match(I, m_BitCast(m_Value(Vec)))) {
@@ -398,8 +419,8 @@
       I->eraseFromParent();
     }
   }
-  TileDPBSSD->replaceAllUsesWith(ResAMX);
-  TileDPBSSD->eraseFromParent();
+  TileDP->replaceAllUsesWith(ResAMX);
+  TileDP->eraseFromParent();
   return true;
 }
 
@@ -481,6 +502,7 @@
 bool X86LowerAMXIntrinsics::visit() {
   bool C = false;
   SmallVector<Instruction *, 8> TileDPBSSDs;
+  SmallVector<Instruction *, 8> TileDPBF16PSs;
   SmallVector<Instruction *, 8> TileLoads;
   SmallVector<Instruction *, 8> TileStores;
   SmallVector<Instruction *, 8> TileZeros;
@@ -489,6 +511,7 @@
     for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
       Instruction &Inst = *II++;
       if (match(&Inst, m_Intrinsic<Intrinsic::x86_tdpbssd_internal>()) ||
+          match(&Inst, m_Intrinsic<Intrinsic::x86_tdpbf16ps_internal>()) ||
           match(&Inst, m_Intrinsic<Intrinsic::x86_tileloadd64_internal>()) ||
           match(&Inst, m_Intrinsic<Intrinsic::x86_tilestored64_internal>()) ||
           match(&Inst, m_Intrinsic<Intrinsic::x86_tilezero_internal>()))
@@ -504,7 +527,13 @@
       // %res = call x86_amx @llvm.x86.tdpbssd.internal(i16 m, i16 n, i16 k,
       //                                                x86_amx, %amx1, ...)
       // %vec2 = bitcast x86_amx %res to <256 x i32>
-      C = lowerTileDPBSSD(Inst) || C;
+      C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
+    else if (match(Inst, m_Intrinsic<Intrinsic::x86_tdpbf16ps_internal>()))
+      // %amx1 = bitcast <256 x i32> %vec to x86_amx
+      // %res = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 m, i16 n, i16 k,
+      //                                                x86_amx, %amx1, ...)
+      // %vec2 = bitcast x86_amx %res to <256 x i32>
+      C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
     else if (match(Inst, m_Intrinsic<Intrinsic::x86_tileloadd64_internal>()))
       // %17 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %13, i16 %14,
       //                                                   i8* %15, i64 %16)
Index: llvm/lib/Target/X86/X86InstrAMX.td
===================================================================
--- llvm/lib/Target/X86/X86InstrAMX.td
+++ llvm/lib/Target/X86/X86InstrAMX.td
@@ -136,5 +136,10 @@
                                [(int_x86_tdpbf16ps timm:$src1,
                                  timm:$src2, timm:$src3)]>;
     }
+    // Pseudo instruction for RA.
+    let Constraints = "$src4 = $dst" in
+    def PTDPBF16PSV : PseudoI<(outs TILE: $dst), (ins GR16:$src1,
+                            GR16:$src2, GR16:$src3, TILE:$src4,
+                            TILE:$src5, TILE:$src6), []>;
   }
 } // HasAMXTILE, HasAMXBF16
Index: llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
===================================================================
--- llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -4638,6 +4638,23 @@
       ReplaceNode(Node, CNode);
       return;
     }
+    case Intrinsic::x86_tdpbf16ps_internal: {
+      // TDPBF16PS requires the AMX-BF16 feature, not just AMX-TILE
+      // (the instruction is gated on HasAMXBF16 in X86InstrAMX.td).
+      if (!Subtarget->hasAMXBF16())
+        break;
+      SDValue Chain = Node->getOperand(0);
+      unsigned Opc = X86::PTDPBF16PSV;
+      SDValue Ops[] = {Node->getOperand(2),
+                       Node->getOperand(3),
+                       Node->getOperand(4),
+                       Node->getOperand(5),
+                       Node->getOperand(6),
+                       Node->getOperand(7),
+                       Chain};
+      MachineSDNode *CNode =
+          CurDAG->getMachineNode(Opc, dl, {MVT::x86amx, MVT::Other}, Ops);
+      ReplaceNode(Node, CNode);
+      return;
+    }
     case Intrinsic::x86_tilezero_internal: {
       if (!Subtarget->hasAMXTILE())
         break;
Index: llvm/lib/Target/X86/X86ExpandPseudo.cpp
===================================================================
--- llvm/lib/Target/X86/X86ExpandPseudo.cpp
+++ llvm/lib/Target/X86/X86ExpandPseudo.cpp
@@ -475,6 +475,14 @@
     MI.tieOperands(0, 1);
     return true;
   }
+  case X86::PTDPBF16PSV: {
+    MI.untieRegOperand(4);
+    for (unsigned i = 3; i > 0; --i)
+      MI.RemoveOperand(i);
+    MI.setDesc(TII->get(X86::TDPBF16PS));
+    MI.tieOperands(0, 1);
+    return true;
+  }
   case X86::PTILESTOREDV: {
     for (int i = 1; i >= 0; --i)
       MI.RemoveOperand(i);
Index: llvm/include/llvm/IR/IntrinsicsX86.td
===================================================================
--- llvm/include/llvm/IR/IntrinsicsX86.td
+++ llvm/include/llvm/IR/IntrinsicsX86.td
@@ -5053,6 +5053,12 @@
                         [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
                          llvm_x86amx_ty, llvm_x86amx_ty,
                          llvm_x86amx_ty], []>;
+  def int_x86_tdpbf16ps_internal :
+              GCCBuiltin<"__builtin_ia32_tdpbf16ps_internal">,
+              Intrinsic<[llvm_x86amx_ty],
+                        [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+                         llvm_x86amx_ty, llvm_x86amx_ty,
+                         llvm_x86amx_ty], []>;
   def int_x86_tilestored64_internal :
               GCCBuiltin<"__builtin_ia32_tilestored64_internal">,
               Intrinsic<[], [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty,
Index: clang/lib/Headers/amxintrin.h
===================================================================
--- clang/lib/Headers/amxintrin.h
+++ clang/lib/Headers/amxintrin.h
@@ -224,6 +224,9 @@
 #define __DEFAULT_FN_ATTRS_INT8                                                \
   __attribute__((__always_inline__, __nodebug__, __target__("amx-int8")))
 
+#define __DEFAULT_FN_ATTRS_BF16                                                \
+  __attribute__((__always_inline__, __nodebug__, __target__("amx-bf16")))
+
 typedef int _tile1024i __attribute__((__vector_size__(1024), __aligned__(64)));
 static __inline__ _tile1024i __DEFAULT_FN_ATTRS_INT8
 _tile_loadd_internal(unsigned short m, unsigned short n, const void *base,
@@ -238,6 +241,12 @@
   return __builtin_ia32_tdpbssd_internal(m, n, k, dst, src1, src2);
 }
 
+static __inline__ _tile1024i __DEFAULT_FN_ATTRS_BF16
+_tile_dpbf16ps_internal(unsigned short m, unsigned short n, unsigned short k,
+                        _tile1024i dst, _tile1024i src1, _tile1024i src2) {
+  return __builtin_ia32_tdpbf16ps_internal(m, n, k, dst, src1, src2);
+}
+
 static __inline__ void __DEFAULT_FN_ATTRS_INT8
 _tile_stored_internal(unsigned short m, unsigned short n, void *base,
                       __SIZE_TYPE__ stride, _tile1024i tile) {
@@ -264,6 +273,13 @@
                                     src1.tile, src2.tile);
 }
 
+__DEFAULT_FN_ATTRS_BF16
+static void __tile_dpbf16ps(__tile1024i *dst, __tile1024i src1,
+                            __tile1024i src2) {
+  dst->tile = _tile_dpbf16ps_internal(src1.row, src2.col, src1.col, dst->tile,
+                                      src1.tile, src2.tile);
+}
+
 __DEFAULT_FN_ATTRS_TILE
 static void __tile_stored(void *base, __SIZE_TYPE__ stride, __tile1024i src) {
   _tile_stored_internal(src.row, src.col, base, stride, src.tile);
Index: clang/include/clang/Basic/BuiltinsX86_64.def
===================================================================
--- clang/include/clang/Basic/BuiltinsX86_64.def
+++ clang/include/clang/Basic/BuiltinsX86_64.def
@@ -103,6 +103,7 @@
 // AMX internal builtin
 TARGET_BUILTIN(__builtin_ia32_tileloadd64_internal, "V256iUsUsvC*z", "n", "amx-tile")
 TARGET_BUILTIN(__builtin_ia32_tdpbssd_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-int8")
+TARGET_BUILTIN(__builtin_ia32_tdpbf16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-bf16")
 TARGET_BUILTIN(__builtin_ia32_tilestored64_internal, "vUsUsv*zV256i", "n", "amx-tile")
 TARGET_BUILTIN(__builtin_ia32_tilezero_internal, "V256iUsUs", "n", "amx-tile")
 // AMX
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to