yubing created this revision.
Herald added subscribers: pengfei, hiraditya.
yubing requested review of this revision.
Herald added projects: clang, LLVM.
Herald added subscribers: llvm-commits, cfe-commits.
Repository:
rG LLVM Github Monorepo
https://reviews.llvm.org/D96110
Files:
clang/include/clang/Basic/BuiltinsX86_64.def
clang/lib/Headers/amxintrin.h
llvm/include/llvm/IR/IntrinsicsX86.td
llvm/lib/Target/X86/X86ExpandPseudo.cpp
llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
llvm/lib/Target/X86/X86InstrAMX.td
llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
llvm/lib/Target/X86/X86LowerAMXType.cpp
llvm/lib/Target/X86/X86PreTileConfig.cpp
llvm/lib/Target/X86/X86RegisterInfo.cpp
Index: llvm/lib/Target/X86/X86RegisterInfo.cpp
===================================================================
--- llvm/lib/Target/X86/X86RegisterInfo.cpp
+++ llvm/lib/Target/X86/X86RegisterInfo.cpp
@@ -873,6 +873,7 @@
// We only collect the tile shape that is defined.
case X86::PTILELOADDV:
case X86::PTDPBSSDV:
+ case X86::PTDPBF16PSV:
case X86::PTILEZEROV:
MachineOperand &MO1 = MI->getOperand(1);
MachineOperand &MO2 = MI->getOperand(2);
Index: llvm/lib/Target/X86/X86PreTileConfig.cpp
===================================================================
--- llvm/lib/Target/X86/X86PreTileConfig.cpp
+++ llvm/lib/Target/X86/X86PreTileConfig.cpp
@@ -127,6 +127,7 @@
llvm_unreachable("Unexpected machine instruction on tile");
case X86::PTILELOADDV:
case X86::PTDPBSSDV:
+ case X86::PTDPBF16PSV:
case X86::PTILEZEROV:
MachineOperand &MO1 = const_cast<MachineOperand &>(MI.getOperand(1));
MachineOperand &MO2 = const_cast<MachineOperand &>(MI.getOperand(2));
@@ -221,6 +222,7 @@
case X86::PTILELOADDV:
case X86::PTILESTOREDV:
case X86::PTDPBSSDV:
+ case X86::PTDPBF16PSV:
case X86::PTILEZEROV:
return true;
}
Index: llvm/lib/Target/X86/X86LowerAMXType.cpp
===================================================================
--- llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -67,7 +67,8 @@
}
// a * b + c
// The shape depends on which operand.
- case Intrinsic::x86_tdpbssd_internal: {
+ case Intrinsic::x86_tdpbssd_internal:
+ case Intrinsic::x86_tdpbf16ps_internal:{
switch (OpNo) {
case 3:
Row = II->getArgOperand(0);
Index: llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
===================================================================
--- llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
+++ llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
@@ -306,6 +306,111 @@
return NewVecC;
}
+/// Emit a scalar (rows x cols x inner) loop nest that computes the AMX-BF16
+/// dot product C += A * B on <256 x i32> vectors, used when the
+/// tdpbf16ps intrinsic is lowered without tile registers.
+///
+/// \p Row, \p Col and \p K are i16 trip counts; \p Col and \p K are expected
+/// in dwords (the caller divides the byte counts by 4). \p Acc, \p LHS and
+/// \p RHS must be x86_amx values produced by a bitcast of a <256 x i32>
+/// vector. Each i32 lane of A and B holds two bf16 values; each lane of C
+/// holds one fp32 value. Returns the updated accumulator vector.
+static Value *createTileDPBF16PSLoops(BasicBlock *Start, BasicBlock *End,
+                                      IRBuilderBase &B, DomTreeUpdater &DTU,
+                                      LoopInfo &LI, Value *Row, Value *Col,
+                                      Value *K, Value *Acc, Value *LHS,
+                                      Value *RHS) {
+  Loop *RowLoop = LI.AllocateLoop();
+  Loop *ColLoop = LI.AllocateLoop();
+  Loop *InnerLoop = LI.AllocateLoop();
+  ColLoop->addChildLoop(InnerLoop);
+  RowLoop->addChildLoop(ColLoop);
+  if (Loop *ParentL = LI.getLoopFor(Start))
+    ParentL->addChildLoop(RowLoop);
+  else
+    LI.addTopLevelLoop(RowLoop);
+
+  BasicBlock *RowBody =
+      createLoop(Start, End, Row, B.getInt16(1), "tiledpbf16ps.unroll.rows", B,
+                 DTU, RowLoop, LI);
+  BasicBlock *RowLatch = RowBody->getSingleSuccessor();
+
+  BasicBlock *ColBody =
+      createLoop(RowBody, RowLatch, Col, B.getInt16(1),
+                 "tiledpbf16ps.unroll.cols", B, DTU, ColLoop, LI);
+  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
+
+  B.SetInsertPoint(ColBody->getTerminator());
+  BasicBlock *InnerBody =
+      createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
+                 "tiledpbf16ps.unroll.inner", B, DTU, InnerLoop, LI);
+
+  BasicBlock *ColumnLoopHeader = ColBody->getSinglePredecessor();
+  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
+  BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
+  BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
+  // Assumes createLoop puts the induction-variable PHI first in each header.
+  Value *CurrentRow = &*RowLoopHeader->begin();
+  Value *CurrentCol = &*ColumnLoopHeader->begin();
+  Value *CurrentInner = &*InnerLoopHeader->begin();
+
+  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
+
+  // Peel the `bitcast <256 x i32> ... to x86_amx` that feeds each operand.
+  // Start from null so that a non-bitcast operand trips the assertions below
+  // instead of reading an uninitialized pointer.
+  // TODO: otherwise create a bitcast from x86_amx to <256 x i32> (store the
+  // tile to memory and reload it as a vector). This doesn't happen at -O0.
+  Value *VecC = nullptr, *VecA = nullptr, *VecB = nullptr;
+  if (auto *BitCast = dyn_cast<BitCastInst>(Acc))
+    VecC = BitCast->getOperand(0);
+  assert(VecC && VecC->getType()->isVectorTy() &&
+         "bitcast from non-v256i32 to x86amx");
+  if (auto *BitCast = dyn_cast<BitCastInst>(LHS))
+    VecA = BitCast->getOperand(0);
+  assert(VecA && VecA->getType()->isVectorTy() &&
+         "bitcast from non-v256i32 to x86amx");
+  if (auto *BitCast = dyn_cast<BitCastInst>(RHS))
+    VecB = BitCast->getOperand(0);
+  assert(VecB && VecB->getType()->isVectorTy() &&
+         "bitcast from non-v256i32 to x86amx");
+
+  // tiledpbf16ps.unroll.rows.header:
+  //   %vec.phi.row = phi <256 x i32> [ %vec_c, %<Start> ],
+  //                  [ %NewVecC, %tiledpbf16ps.unroll.rows.latch ]
+  B.SetInsertPoint(RowLoopHeader->getTerminator());
+  PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
+  VecCPhiRowLoop->addIncoming(VecC, Start);
+
+  // tiledpbf16ps.unroll.cols.header:
+  //   %vec.phi.col = phi <256 x i32> [ %vec.phi.row,
+  //                  %tiledpbf16ps.unroll.rows.body ],
+  //                  [ %NewVecC, %tiledpbf16ps.unroll.cols.latch ]
+  B.SetInsertPoint(ColumnLoopHeader->getTerminator());
+  PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.col");
+  VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
+
+  // tiledpbf16ps.unroll.inner.header:
+  //   %vec.phi = phi <256 x i32> [ %vec.phi.col,
+  //              %tiledpbf16ps.unroll.cols.body ],
+  //              [ %NewVecC, %tiledpbf16ps.unroll.inner.latch ]
+  B.SetInsertPoint(InnerLoopHeader->getTerminator());
+  PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
+  VecCPhi->addIncoming(VecCPhiColLoop, ColBody);
+
+  // Accumulate one dword of the dot product in the inner body. A row spans
+  // 16 dwords, so element (r, c) is vector lane r * 16 + c.
+  B.SetInsertPoint(InnerBody->getTerminator());
+  Value *IdxC =
+      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
+  Value *IdxA =
+      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
+  Value *IdxB =
+      B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
+
+  FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
+  FixedVectorType *V2I32Ty = FixedVectorType::get(B.getInt32Ty(), 2);
+  FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
+  // bf16 is the upper half of an fp32, so widening is zext + shl 16 + bitcast.
+  Value *Shift16 = B.CreateVectorSplat(2, B.getInt32(16));
+  Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
+  Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
+  Value *EltA = B.CreateExtractElement(VecA, IdxA);
+  Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
+  Value *EltB = B.CreateExtractElement(VecB, IdxB);
+  Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
+  Value *AV2F32 = B.CreateBitCast(
+      B.CreateShl(B.CreateZExt(SubVecA, V2I32Ty), Shift16), V2F32Ty);
+  Value *BV2F32 = B.CreateBitCast(
+      B.CreateShl(B.CreateZExt(SubVecB, V2I32Ty), Shift16), V2F32Ty);
+  // Without fast-math flags the fadd reduction is ordered: (C + a0*b0) + a1*b1.
+  // NOTE(review): hardware DAZ/FTZ denormal handling is not modeled here.
+  Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
+  Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
+  Value *NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
+  VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
+  VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
+  VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
+
+  return NewVecC;
+}
+
namespace {
class X86LowerAMXIntrinsics {
Function &Func;
@@ -320,6 +425,7 @@
LoopInfo *LI;
bool lowerTileLoad(Instruction *TileLoad);
bool lowerTileDPBSSD(Instruction *TileDPBSSD);
+ bool lowerTileDPBF16PS(Instruction *TileDPBSSD);
bool lowerTileStore(Instruction *TileStore);
bool lowerTileZero(Instruction *TileZero);
};
@@ -359,6 +465,41 @@
return true;
}
+/// Lower one llvm.x86.tdpbf16ps.internal call into scalar loops over
+/// <256 x i32> vectors. Each bitcast user of the call is replaced with the
+/// computed vector, and the call is then erased. Always returns true.
+bool X86LowerAMXIntrinsics::lowerTileDPBF16PS(Instruction *TileDPBF16PS) {
+  Value *M, *N, *K, *C, *A, *B;
+  bool Matched =
+      match(TileDPBF16PS, m_Intrinsic<Intrinsic::x86_tdpbf16ps_internal>(
+                              m_Value(M), m_Value(N), m_Value(K), m_Value(C),
+                              m_Value(A), m_Value(B)));
+  (void)Matched;
+  assert(Matched && "expected a call to llvm.x86.tdpbf16ps.internal");
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  // The IRBuilder constructor already inserts before TileDPBF16PS.
+  IRBuilder<> PreBuilder(TileDPBF16PS);
+  // We visit the loop with (m, n/4, k/4): the loops step over dwords.
+  //   %n_dword = udiv i16 %n, 4
+  //   %k_dword = udiv i16 %k, 4
+  Value *NDWord = PreBuilder.CreateUDiv(N, PreBuilder.getInt16(4));
+  Value *KDWord = PreBuilder.CreateUDiv(K, PreBuilder.getInt16(4));
+  BasicBlock *Start = TileDPBF16PS->getParent();
+  BasicBlock *End =
+      SplitBlock(Start, TileDPBF16PS, DT, LI, nullptr, "continue");
+  IRBuilder<> Builder(TileDPBF16PS);
+  Value *ResVec = createTileDPBF16PSLoops(Start, End, Builder, DTU, *LI, M,
+                                          NDWord, KDWord, C, A, B);
+
+  // Replace each `bitcast x86_amx %res to <256 x i32>` user with the computed
+  // vector and delete it; then delete the tdpbf16ps intrinsic itself.
+  for (auto UI = TileDPBF16PS->use_begin(), UE = TileDPBF16PS->use_end();
+       UI != UE;) {
+    Instruction *I = cast<Instruction>((UI++)->getUser());
+    Value *Vec;
+    if (match(I, m_BitCast(m_Value(Vec)))) {
+      I->replaceAllUsesWith(ResVec);
+      I->eraseFromParent();
+    }
+  }
+  assert(TileDPBF16PS->use_empty() &&
+         "tdpbf16ps result has users other than bitcasts");
+  TileDPBF16PS->eraseFromParent();
+  return true;
+}
+
bool X86LowerAMXIntrinsics::lowerTileLoad(Instruction *TileLoad) {
Value *M, *N, *Ptr, *Stride;
match(TileLoad, m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
@@ -432,6 +573,7 @@
bool X86LowerAMXIntrinsics::visit() {
bool C;
SmallVector<Instruction *, 8> TileDPBSSDs;
+ SmallVector<Instruction *, 8> TileDPBF16PSs;
SmallVector<Instruction *, 8> TileLoads;
SmallVector<Instruction *, 8> TileStores;
SmallVector<Instruction *, 8> TileZeros;
@@ -446,6 +588,12 @@
// x86_amx, %amx1, ...)
// %vec2 = bitcast x86_amx %res to <256 x i32>
TileDPBSSDs.push_back(&Inst);
+ } else if (match(&Inst, m_Intrinsic<Intrinsic::x86_tdpbf16ps_internal>())) {
+ // %amx1 = bitcast <256 x i32> %vec to x86_amx
+ // %res = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 m, i16 n, i16 k,
+ // x86_amx, %amx1, ...)
+ // %vec2 = bitcast x86_amx %res to <256 x i32>
+ TileDPBF16PSs.push_back(&Inst);
} else if (match(&Inst,
m_Intrinsic<Intrinsic::x86_tileloadd64_internal>())) {
// %17 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %13, i16 %14,
@@ -473,6 +621,9 @@
for (auto *Inst : TileDPBSSDs) {
C |= lowerTileDPBSSD(Inst);
}
+ for (auto *Inst : TileDPBF16PSs) {
+ C |= lowerTileDPBF16PS(Inst);
+ }
for (auto *Inst : TileStores) {
C |= lowerTileStore(Inst);
}
Index: llvm/lib/Target/X86/X86InstrAMX.td
===================================================================
--- llvm/lib/Target/X86/X86InstrAMX.td
+++ llvm/lib/Target/X86/X86InstrAMX.td
@@ -136,5 +136,10 @@
[(int_x86_tdpbf16ps timm:$src1,
timm:$src2, timm:$src3)]>;
}
+ // Pseudo instruction for RA; $src1-$src3 carry the (m, n, k) tile shape.
+ let Constraints = "$src4 = $dst" in
+ def PTDPBF16PSV : PseudoI<(outs TILE: $dst), (ins GR16:$src1,
+ GR16:$src2, GR16:$src3, TILE:$src4,
+ TILE:$src5, TILE:$src6), []>;
}
} // HasAMXTILE, HasAMXBF16
Index: llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
===================================================================
--- llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -4638,6 +4638,23 @@
ReplaceNode(Node, CNode);
return;
}
+ case Intrinsic::x86_tdpbf16ps_internal: {
+ if (!Subtarget->hasAMXTILE())
+ break;
+ SDValue Chain = Node->getOperand(0);
+ unsigned Opc = X86::PTDPBF16PSV;
+ SDValue Ops[] = {Node->getOperand(2),
+ Node->getOperand(3),
+ Node->getOperand(4),
+ Node->getOperand(5),
+ Node->getOperand(6),
+ Node->getOperand(7),
+ Chain};
+ MachineSDNode *CNode =
+ CurDAG->getMachineNode(Opc, dl, {MVT::x86amx, MVT::Other}, Ops);
+ ReplaceNode(Node, CNode);
+ return;
+ }
case Intrinsic::x86_tilezero_internal: {
if (!Subtarget->hasAMXTILE())
break;
Index: llvm/lib/Target/X86/X86ExpandPseudo.cpp
===================================================================
--- llvm/lib/Target/X86/X86ExpandPseudo.cpp
+++ llvm/lib/Target/X86/X86ExpandPseudo.cpp
@@ -475,6 +475,14 @@
MI.tieOperands(0, 1);
return true;
}
+ case X86::PTDPBF16PSV: {
+ MI.untieRegOperand(4);
+ for (unsigned i = 3; i > 0; --i)
+ MI.RemoveOperand(i);
+ MI.setDesc(TII->get(X86::TDPBF16PS));
+ MI.tieOperands(0, 1);
+ return true;
+ }
case X86::PTILESTOREDV: {
for (int i = 1; i >= 0; --i)
MI.RemoveOperand(i);
Index: llvm/include/llvm/IR/IntrinsicsX86.td
===================================================================
--- llvm/include/llvm/IR/IntrinsicsX86.td
+++ llvm/include/llvm/IR/IntrinsicsX86.td
@@ -5053,6 +5053,12 @@
[llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
llvm_x86amx_ty, llvm_x86amx_ty,
llvm_x86amx_ty], []>;
+ // Shape-carrying tdpbf16ps: (i16 m, i16 n, i16 k, dst, src1, src2) -> new
+ // dst tile. Selected to the PTDPBF16PSV pseudo instruction.
+ def int_x86_tdpbf16ps_internal :
+ GCCBuiltin<"__builtin_ia32_tdpbf16ps_internal">,
+ Intrinsic<[llvm_x86amx_ty],
+ [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+ llvm_x86amx_ty, llvm_x86amx_ty,
+ llvm_x86amx_ty], []>;
def int_x86_tilestored64_internal :
GCCBuiltin<"__builtin_ia32_tilestored64_internal">,
Intrinsic<[], [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty,
Index: clang/lib/Headers/amxintrin.h
===================================================================
--- clang/lib/Headers/amxintrin.h
+++ clang/lib/Headers/amxintrin.h
@@ -224,6 +224,9 @@
#define __DEFAULT_FN_ATTRS_INT8 \
__attribute__((__always_inline__, __nodebug__, __target__("amx-int8")))
+/// Attributes for AMX-BF16 helpers: the "amx-bf16" target feature is required
+/// to call __builtin_ia32_tdpbf16ps_internal.
+#define __DEFAULT_FN_ATTRS_BF16 \
+ __attribute__((__always_inline__, __nodebug__, __target__("amx-bf16")))
+
typedef int _tile1024i __attribute__((__vector_size__(1024), __aligned__(64)));
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_INT8
_tile_loadd_internal(unsigned short m, unsigned short n, const void *base,
@@ -238,6 +241,12 @@
return __builtin_ia32_tdpbssd_internal(m, n, k, dst, src1, src2);
}
+/// This is an internal intrinsic; C/C++ users should not call it directly.
+/// Wraps __builtin_ia32_tdpbf16ps_internal with an explicit (m, n, k) shape.
+/// It must carry the "amx-bf16" target feature (not "amx-int8"), because the
+/// builtin it calls is declared as requiring amx-bf16.
+static __inline__ _tile1024i __DEFAULT_FN_ATTRS_BF16
+_tile_dpbf16ps_internal(unsigned short m, unsigned short n, unsigned short k,
+                        _tile1024i dst, _tile1024i src1, _tile1024i src2) {
+  return __builtin_ia32_tdpbf16ps_internal(m, n, k, dst, src1, src2);
+}
+
static __inline__ void __DEFAULT_FN_ATTRS_INT8
_tile_stored_internal(unsigned short m, unsigned short n, void *base,
__SIZE_TYPE__ stride, _tile1024i tile) {
@@ -264,6 +273,13 @@
src1.tile, src2.tile);
}
+/// Compute the BF16 dot product of \p src1 and \p src2, accumulate it into
+/// \p dst, and store the resulting tile back into \p dst->tile.
+/// The shape passed down is (src1.row, src2.col, src1.col). Uses the amx-bf16
+/// attribute set, which its callee requires.
+__DEFAULT_FN_ATTRS_BF16
+static void __tile_dpbf16ps(__tile1024i *dst, __tile1024i src1,
+                            __tile1024i src2) {
+  dst->tile = _tile_dpbf16ps_internal(src1.row, src2.col, src1.col, dst->tile,
+                                      src1.tile, src2.tile);
+}
+
__DEFAULT_FN_ATTRS_TILE
static void __tile_stored(void *base, __SIZE_TYPE__ stride, __tile1024i src) {
_tile_stored_internal(src.row, src.col, base, stride, src.tile);
Index: clang/include/clang/Basic/BuiltinsX86_64.def
===================================================================
--- clang/include/clang/Basic/BuiltinsX86_64.def
+++ clang/include/clang/Basic/BuiltinsX86_64.def
@@ -103,6 +103,7 @@
// AMX internal builtin
TARGET_BUILTIN(__builtin_ia32_tileloadd64_internal, "V256iUsUsvC*z", "n", "amx-tile")
TARGET_BUILTIN(__builtin_ia32_tdpbssd_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-int8")
+TARGET_BUILTIN(__builtin_ia32_tdpbf16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-bf16")
TARGET_BUILTIN(__builtin_ia32_tilestored64_internal, "vUsUsv*zV256i", "n", "amx-tile")
TARGET_BUILTIN(__builtin_ia32_tilezero_internal, "V256iUsUs", "n", "amx-tile")
// AMX
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits