yubing created this revision.
Herald added subscribers: pengfei, hiraditya.
yubing requested review of this revision.
Herald added projects: clang, LLVM.
Herald added subscribers: llvm-commits, cfe-commits.

Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D96110

Files:
  clang/include/clang/Basic/BuiltinsX86_64.def
  clang/lib/Headers/amxintrin.h
  llvm/include/llvm/IR/IntrinsicsX86.td
  llvm/lib/Target/X86/X86ExpandPseudo.cpp
  llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
  llvm/lib/Target/X86/X86InstrAMX.td
  llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
  llvm/lib/Target/X86/X86LowerAMXType.cpp
  llvm/lib/Target/X86/X86PreTileConfig.cpp
  llvm/lib/Target/X86/X86RegisterInfo.cpp

Index: llvm/lib/Target/X86/X86RegisterInfo.cpp
===================================================================
--- llvm/lib/Target/X86/X86RegisterInfo.cpp
+++ llvm/lib/Target/X86/X86RegisterInfo.cpp
@@ -873,6 +873,7 @@
   // We only collect the tile shape that is defined.
   case X86::PTILELOADDV:
   case X86::PTDPBSSDV:
+  case X86::PTDPBF16PSV:
   case X86::PTILEZEROV:
     MachineOperand &MO1 = MI->getOperand(1);
     MachineOperand &MO2 = MI->getOperand(2);
Index: llvm/lib/Target/X86/X86PreTileConfig.cpp
===================================================================
--- llvm/lib/Target/X86/X86PreTileConfig.cpp
+++ llvm/lib/Target/X86/X86PreTileConfig.cpp
@@ -127,6 +127,7 @@
     llvm_unreachable("Unexpected machine instruction on tile");
   case X86::PTILELOADDV:
   case X86::PTDPBSSDV:
+  case X86::PTDPBF16PSV:
   case X86::PTILEZEROV:
     MachineOperand &MO1 = const_cast<MachineOperand &>(MI.getOperand(1));
     MachineOperand &MO2 = const_cast<MachineOperand &>(MI.getOperand(2));
@@ -221,6 +222,7 @@
   case X86::PTILELOADDV:
   case X86::PTILESTOREDV:
   case X86::PTDPBSSDV:
+  case X86::PTDPBF16PSV:
   case X86::PTILEZEROV:
     return true;
   }
Index: llvm/lib/Target/X86/X86LowerAMXType.cpp
===================================================================
--- llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -67,7 +67,8 @@
   }
   // a * b + c
   // The shape depends on which operand.
-  case Intrinsic::x86_tdpbssd_internal: {
+  case Intrinsic::x86_tdpbssd_internal:
+  case Intrinsic::x86_tdpbf16ps_internal:{
     switch (OpNo) {
     case 3:
       Row = II->getArgOperand(0);
Index: llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
===================================================================
--- llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
+++ llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
@@ -306,6 +306,111 @@
   return NewVecC;
 }
 
+// Emit a scalar rows x cols x K loop nest between Start and End that computes
+// Acc += LHS * RHS (bf16 pairs, f32 accumulation) on the <256 x i32> vectors
+// behind the x86_amx bitcasts; returns the updated <256 x i32> accumulator.
+static Value *createTileDPBF16PSLoops(BasicBlock *Start, BasicBlock *End,
+                                      IRBuilderBase &B, DomTreeUpdater &DTU,
+                                      LoopInfo &LI, Value *Row, Value *Col,
+                                      Value *K, Value *Acc, Value *LHS,
+                                      Value *RHS) {
+  Loop *RowLoop = LI.AllocateLoop();
+  Loop *ColLoop = LI.AllocateLoop();
+  Loop *InnerLoop = LI.AllocateLoop();
+  ColLoop->addChildLoop(InnerLoop);
+  RowLoop->addChildLoop(ColLoop);
+  if (Loop *ParentL = LI.getLoopFor(Start))
+    ParentL->addChildLoop(RowLoop);
+  else
+    LI.addTopLevelLoop(RowLoop);
+
+  BasicBlock *RowBody =
+      createLoop(Start, End, Row, B.getInt16(1), "tiledpbf16ps.unroll.rows", B,
+                 DTU, RowLoop, LI);
+  BasicBlock *RowLatch = RowBody->getSingleSuccessor();
+
+  BasicBlock *ColBody =
+      createLoop(RowBody, RowLatch, Col, B.getInt16(1),
+                 "tiledpbf16ps.unroll.cols", B, DTU, ColLoop, LI);
+  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
+
+  B.SetInsertPoint(ColBody->getTerminator());
+  BasicBlock *InnerBody =
+      createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
+                 "tiledpbf16ps.unroll.inner", B, DTU, InnerLoop, LI);
+
+  BasicBlock *ColumnLoopHeader = ColBody->getSinglePredecessor();
+  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
+  BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
+  BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
+  Value *CurrentRow = &*RowLoopHeader->begin();
+  Value *CurrentCol = &*ColumnLoopHeader->begin();
+  Value *CurrentInner = &*InnerLoopHeader->begin();
+
+  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
+  // Acc/LHS/RHS must be bitcasts from a vector; cast<> asserts on any other
+  // producer instead of reading an uninitialized pointer.
+  // TODO: otherwise create a bitcast from x86amx to v256i32 (store x86amx to
+  // memory and reload it as a vector). However with -O0, it doesn't happen.
+  Value *VecC = cast<BitCastInst>(Acc)->getOperand(0);
+  Value *VecA = cast<BitCastInst>(LHS)->getOperand(0);
+  Value *VecB = cast<BitCastInst>(RHS)->getOperand(0);
+  assert(VecC->getType()->isVectorTy() && "bitcast from non-v256i32 to x86amx");
+  assert(VecA->getType()->isVectorTy() && "bitcast from non-v256i32 to x86amx");
+  assert(VecB->getType()->isVectorTy() && "bitcast from non-v256i32 to x86amx");
+
+  // tiledpbf16ps.unroll.rows.header:
+  // %vec.phi.rows = phi <256 x i32> [ %vec_c, %continue ], [ %NewVecC,
+  // %tiledpbf16ps.unroll.rows.latch ]
+  B.SetInsertPoint(RowLoopHeader->getTerminator());
+  PHINode *VecPhi_Row_Loop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
+  VecPhi_Row_Loop->addIncoming(VecC, Start);
+
+  // tiledpbf16ps.unroll.cols.header:
+  // %vec.phi.cols = phi <256 x i32> [ %vec.phi.rows,
+  // %tiledpbf16ps.unroll.rows.body ], [ %NewVecC, %tiledpbf16ps.unroll.cols.latch ]
+  B.SetInsertPoint(ColumnLoopHeader->getTerminator());
+  PHINode *VecPhi_Col_Loop = B.CreatePHI(V256I32Ty, 2, "vec.phi.col");
+  VecPhi_Col_Loop->addIncoming(VecPhi_Row_Loop, RowBody);
+
+  // Generate PHI vector for C.
+  B.SetInsertPoint(InnerLoopHeader->getTerminator());
+  PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
+  VecCPhi->addIncoming(VecPhi_Col_Loop, ColBody);
+
+  // Generate accumulate multiply in innerbody.
+  B.SetInsertPoint(InnerBody->getTerminator());
+  Value *IdxC =
+      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
+  Value *IdxA =
+      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
+  Value *IdxB =
+      B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
+
+  FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
+  FixedVectorType *V2I32Ty = FixedVectorType::get(B.getInt32Ty(), 2);
+  FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
+  Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
+  Value *C_F32 = B.CreateBitCast(EltC, B.getFloatTy());
+  Value *EltA = B.CreateExtractElement(VecA, IdxA);
+  Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
+  Value *EltB = B.CreateExtractElement(VecB, IdxB);
+  Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
+  // Widen each 16-bit half into the upper 16 bits of a 32-bit float pattern.
+  Value *Shift16 = B.CreateVectorSplat(2, B.getInt32(16));
+  Value *A_V2F32 = B.CreateBitCast(
+      B.CreateShl(B.CreateZExt(SubVecA, V2I32Ty), Shift16), V2F32Ty);
+  Value *B_V2F32 = B.CreateBitCast(
+      B.CreateShl(B.CreateZExt(SubVecB, V2I32Ty), Shift16), V2F32Ty);
+  Value *SubVecR = B.CreateFAddReduce(C_F32, B.CreateFMul(A_V2F32, B_V2F32));
+  Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
+  Value *NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
+  VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
+  VecPhi_Row_Loop->addIncoming(NewVecC, RowLatch);
+  VecPhi_Col_Loop->addIncoming(NewVecC, ColLoopLatch);
+
+  return NewVecC;
+}
+
 namespace {
 class X86LowerAMXIntrinsics {
   Function &Func;
@@ -320,6 +425,7 @@
   LoopInfo *LI;
   bool lowerTileLoad(Instruction *TileLoad);
   bool lowerTileDPBSSD(Instruction *TileDPBSSD);
+  bool lowerTileDPBF16PS(Instruction *TileDPBSSD);
   bool lowerTileStore(Instruction *TileStore);
   bool lowerTileZero(Instruction *TileZero);
 };
@@ -359,6 +465,41 @@
   return true;
 }
 
+// Lower one tdpbf16ps.internal call to scalar loops and erase it together with
+// its bitcast users; returns false, leaving the IR untouched, on a mismatch.
+bool X86LowerAMXIntrinsics::lowerTileDPBF16PS(Instruction *TileDPBF16PS) {
+  Value *M, *N, *K, *C, *A, *B;
+  if (!match(TileDPBF16PS, m_Intrinsic<Intrinsic::x86_tdpbf16ps_internal>(
+                               m_Value(M), m_Value(N), m_Value(K), m_Value(C),
+                               m_Value(A), m_Value(B))))
+    return false;
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  IRBuilder<> Builder_Prepare(TileDPBF16PS);
+  // We visit the loop with (m, n/4, k/4):
+  // %n_dword = udiv i16 %n, 4
+  // %k_dword = udiv i16 %k, 4
+  Value *N_DWord = Builder_Prepare.CreateUDiv(N, Builder_Prepare.getInt16(4));
+  Value *K_DWord = Builder_Prepare.CreateUDiv(K, Builder_Prepare.getInt16(4));
+  BasicBlock *Start = TileDPBF16PS->getParent();
+  BasicBlock *End =
+      SplitBlock(Start, TileDPBF16PS, DT, LI, nullptr, "continue");
+  IRBuilder<> Builder(TileDPBF16PS);
+  Value *ResVec = createTileDPBF16PSLoops(Start, End, Builder, DTU, *LI, M,
+                                          N_DWord, K_DWord, C, A, B);
+
+  // Delete tdpbf16ps intrinsic and bitcast instruction.
+  for (auto UI = TileDPBF16PS->use_begin(), UE = TileDPBF16PS->use_end();
+       UI != UE;) {
+    Instruction *I = cast<Instruction>((UI++)->getUser());
+    if (match(I, m_BitCast(m_Value()))) {
+      I->replaceAllUsesWith(ResVec);
+      I->eraseFromParent();
+    }
+  }
+  TileDPBF16PS->eraseFromParent();
+  return true;
+}
+
 bool X86LowerAMXIntrinsics::lowerTileLoad(Instruction *TileLoad) {
   Value *M, *N, *Ptr, *Stride;
   match(TileLoad, m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
@@ -432,6 +573,7 @@
 bool X86LowerAMXIntrinsics::visit() {
   bool C;
   SmallVector<Instruction *, 8> TileDPBSSDs;
+  SmallVector<Instruction *, 8> TileDPBF16PSs;
   SmallVector<Instruction *, 8> TileLoads;
   SmallVector<Instruction *, 8> TileStores;
   SmallVector<Instruction *, 8> TileZeros;
@@ -446,6 +588,12 @@
         //                                                x86_amx, %amx1, ...)
         // %vec2 = bitcast x86_amx %res to <256 x i32>
         TileDPBSSDs.push_back(&Inst);
+      } else if (match(&Inst, m_Intrinsic<Intrinsic::x86_tdpbf16ps_internal>())) {
+        // %amx1 = bitcast <256 x i32> %vec to x86_amx
+        // %res = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 m, i16 n, i16 k,
+        //                                                x86_amx, %amx1, ...)
+        // %vec2 = bitcast x86_amx %res to <256 x i32>
+        TileDPBF16PSs.push_back(&Inst);
       } else if (match(&Inst,
                        m_Intrinsic<Intrinsic::x86_tileloadd64_internal>())) {
         // %17 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %13, i16 %14,
@@ -473,6 +621,9 @@
   for (auto *Inst : TileDPBSSDs) {
     C |= lowerTileDPBSSD(Inst);
   }
+  for (auto *Inst : TileDPBF16PSs) {
+    C |= lowerTileDPBF16PS(Inst);
+  }
   for (auto *Inst : TileStores) {
     C |= lowerTileStore(Inst);
   }
Index: llvm/lib/Target/X86/X86InstrAMX.td
===================================================================
--- llvm/lib/Target/X86/X86InstrAMX.td
+++ llvm/lib/Target/X86/X86InstrAMX.td
@@ -136,5 +136,10 @@
                                [(int_x86_tdpbf16ps timm:$src1,
                                  timm:$src2, timm:$src3)]>;
     }
+    // Pseudo instruction for RA.
+    let Constraints = "$src4 = $dst" in
+    def PTDPBF16PSV : PseudoI<(outs TILE: $dst), (ins GR16:$src1,
+                            GR16:$src2, GR16:$src3, TILE:$src4,
+                            TILE:$src5, TILE:$src6), []>;
   }
 } // HasAMXTILE, HasAMXBF16
Index: llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
===================================================================
--- llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -4638,6 +4638,23 @@
       ReplaceNode(Node, CNode);
       return;
     }
+    case Intrinsic::x86_tdpbf16ps_internal: {
+      if (!Subtarget->hasAMXTILE())
+        break;
+      SDValue Chain = Node->getOperand(0);
+      unsigned Opc = X86::PTDPBF16PSV;
+      SDValue Ops[] = {Node->getOperand(2),
+                       Node->getOperand(3),
+                       Node->getOperand(4),
+                       Node->getOperand(5),
+                       Node->getOperand(6),
+                       Node->getOperand(7),
+                       Chain};
+      MachineSDNode *CNode =
+          CurDAG->getMachineNode(Opc, dl, {MVT::x86amx, MVT::Other}, Ops);
+      ReplaceNode(Node, CNode);
+      return;
+    }
     case Intrinsic::x86_tilezero_internal: {
       if (!Subtarget->hasAMXTILE())
         break;
Index: llvm/lib/Target/X86/X86ExpandPseudo.cpp
===================================================================
--- llvm/lib/Target/X86/X86ExpandPseudo.cpp
+++ llvm/lib/Target/X86/X86ExpandPseudo.cpp
@@ -475,6 +475,14 @@
     MI.tieOperands(0, 1);
     return true;
   }
+  case X86::PTDPBF16PSV: {
+    MI.untieRegOperand(4);
+    for (unsigned i = 3; i > 0; --i)
+      MI.RemoveOperand(i);
+    MI.setDesc(TII->get(X86::TDPBF16PS));
+    MI.tieOperands(0, 1);
+    return true;
+  }
   case X86::PTILESTOREDV: {
     for (int i = 1; i >= 0; --i)
       MI.RemoveOperand(i);
Index: llvm/include/llvm/IR/IntrinsicsX86.td
===================================================================
--- llvm/include/llvm/IR/IntrinsicsX86.td
+++ llvm/include/llvm/IR/IntrinsicsX86.td
@@ -5053,6 +5053,12 @@
                         [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
                          llvm_x86amx_ty, llvm_x86amx_ty,
                          llvm_x86amx_ty], []>;
+  def int_x86_tdpbf16ps_internal :
+              GCCBuiltin<"__builtin_ia32_tdpbf16ps_internal">,
+              Intrinsic<[llvm_x86amx_ty],
+                        [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+                         llvm_x86amx_ty, llvm_x86amx_ty,
+                         llvm_x86amx_ty], []>;
   def int_x86_tilestored64_internal :
               GCCBuiltin<"__builtin_ia32_tilestored64_internal">,
               Intrinsic<[], [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty,
Index: clang/lib/Headers/amxintrin.h
===================================================================
--- clang/lib/Headers/amxintrin.h
+++ clang/lib/Headers/amxintrin.h
@@ -224,6 +224,9 @@
 #define __DEFAULT_FN_ATTRS_INT8                                                \
   __attribute__((__always_inline__, __nodebug__, __target__("amx-int8")))
 
+#define __DEFAULT_FN_ATTRS_BF16                                                \
+  __attribute__((__always_inline__, __nodebug__, __target__("amx-bf16")))
+
 typedef int _tile1024i __attribute__((__vector_size__(1024), __aligned__(64)));
 static __inline__ _tile1024i __DEFAULT_FN_ATTRS_INT8
 _tile_loadd_internal(unsigned short m, unsigned short n, const void *base,
@@ -238,6 +241,12 @@
   return __builtin_ia32_tdpbssd_internal(m, n, k, dst, src1, src2);
 }
 
+static __inline__ _tile1024i __DEFAULT_FN_ATTRS_BF16 // builtin needs amx-bf16
+_tile_dpbf16ps_internal(unsigned short m, unsigned short n, unsigned short k,
+                        _tile1024i dst, _tile1024i src1, _tile1024i src2) {
+  return __builtin_ia32_tdpbf16ps_internal(m, n, k, dst, src1, src2);
+}
+
 static __inline__ void __DEFAULT_FN_ATTRS_INT8
 _tile_stored_internal(unsigned short m, unsigned short n, void *base,
                       __SIZE_TYPE__ stride, _tile1024i tile) {
@@ -264,6 +273,13 @@
                                     src1.tile, src2.tile);
 }
 
+__DEFAULT_FN_ATTRS_BF16 // bf16 dot product into *dst; targets amx-bf16
+static void __tile_dpbf16ps(__tile1024i *dst, __tile1024i src1,
+                            __tile1024i src2) {
+  dst->tile = _tile_dpbf16ps_internal(src1.row, src2.col, src1.col, dst->tile,
+                                      src1.tile, src2.tile);
+}
+
 __DEFAULT_FN_ATTRS_TILE
 static void __tile_stored(void *base, __SIZE_TYPE__ stride, __tile1024i src) {
   _tile_stored_internal(src.row, src.col, base, stride, src.tile);
Index: clang/include/clang/Basic/BuiltinsX86_64.def
===================================================================
--- clang/include/clang/Basic/BuiltinsX86_64.def
+++ clang/include/clang/Basic/BuiltinsX86_64.def
@@ -103,6 +103,7 @@
 // AMX internal builtin
 TARGET_BUILTIN(__builtin_ia32_tileloadd64_internal, "V256iUsUsvC*z", "n", "amx-tile")
 TARGET_BUILTIN(__builtin_ia32_tdpbssd_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-int8")
+TARGET_BUILTIN(__builtin_ia32_tdpbf16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-bf16")
 TARGET_BUILTIN(__builtin_ia32_tilestored64_internal, "vUsUsv*zV256i", "n", "amx-tile")
 TARGET_BUILTIN(__builtin_ia32_tilezero_internal, "V256iUsUs", "n", "amx-tile")
 // AMX
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to