================ @@ -1694,6 +1698,84 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg, return Result; } +// Since pre-1.6 SPIRV has no DotProductInput4x8BitPacked implementation, +// extract the elements of the packed inputs, multiply them and add the result +// to the accumulator. +template <bool Signed> +bool SPIRVInstructionSelector::selectDot4AddPacked(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + assert(I.getNumOperands() == 5); + assert(I.getOperand(2).isReg()); + assert(I.getOperand(3).isReg()); + assert(I.getOperand(4).isReg()); + MachineBasicBlock &BB = *I.getParent(); + + bool Result = false; + + // Acc = C + Register Acc = I.getOperand(4).getReg(); + SPIRVType *EltType = GR.getOrCreateSPIRVIntegerType(8, I, TII); + auto ExtractOp = + Signed ? SPIRV::OpBitFieldSExtract : SPIRV::OpBitFieldUExtract; + + // Extract the i8 element, multiply and add it to the accumulator + for (unsigned i = 0; i < 4; i++) { + // A[i] + Register AElt = MRI->createVirtualRegister(&SPIRV::IDRegClass); + Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) + .addDef(AElt) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(I.getOperand(2).getReg()) + .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII)) + .addImm(8) + .constrainAllUses(TII, TRI, RBI); + + // B[i] + Register BElt = MRI->createVirtualRegister(&SPIRV::IDRegClass); + Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) + .addDef(BElt) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(I.getOperand(3).getReg()) + .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII)) + .addImm(8) + .constrainAllUses(TII, TRI, RBI); + + // A[i] * B[i] + Register Mul = MRI->createVirtualRegister(&SPIRV::IDRegClass); + Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulS)) + .addDef(Mul) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(AElt) + .addUse(BElt) + .constrainAllUses(TII, TRI, RBI); + + // Discard 24 highest-bits so that stored i32 register is i8 equivalent + Register MaskMul = MRI->createVirtualRegister(&SPIRV::IDRegClass); + Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) + .addDef(MaskMul) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(Mul) + .addUse(GR.getOrCreateConstInt(0, I, EltType, TII)) + .addImm(8) ---------------- s-perron wrote:
Same here. https://github.com/llvm/llvm-project/pull/113623 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits