================
@@ -69,11 +82,295 @@ FunctionPass *llvm::createAMDGPURegBankLegalizePass() {
   return new AMDGPURegBankLegalize();
 }
 
-using namespace AMDGPU;
+const RegBankLegalizeRules &getRules(const GCNSubtarget &ST,
+                                     MachineRegisterInfo &MRI) {
+  static std::mutex GlobalMutex;
+  static SmallDenseMap<unsigned, std::unique_ptr<RegBankLegalizeRules>>
+      CacheForRuleSet;
+  std::lock_guard<std::mutex> Lock(GlobalMutex);
+  if (!CacheForRuleSet.contains(ST.getGeneration())) {
+    auto Rules = std::make_unique<RegBankLegalizeRules>(ST, MRI);
+    CacheForRuleSet[ST.getGeneration()] = std::move(Rules);
+  } else {
+    CacheForRuleSet[ST.getGeneration()]->refreshRefs(ST, MRI);
+  }
+  return *CacheForRuleSet[ST.getGeneration()];
+}
+
+class AMDGPURegBankLegalizeCombiner {
+  MachineIRBuilder &B;
+  MachineRegisterInfo &MRI;
+  const SIRegisterInfo &TRI;
+  const RegisterBank *SgprRB;
+  const RegisterBank *VgprRB;
+  const RegisterBank *VccRB;
+
+  static constexpr LLT S1 = LLT::scalar(1);
+  static constexpr LLT S16 = LLT::scalar(16);
+  static constexpr LLT S32 = LLT::scalar(32);
+  static constexpr LLT S64 = LLT::scalar(64);
+
+public:
+  AMDGPURegBankLegalizeCombiner(MachineIRBuilder &B, const SIRegisterInfo &TRI,
+                                const RegisterBankInfo &RBI)
+      : B(B), MRI(*B.getMRI()), TRI(TRI),
+        SgprRB(&RBI.getRegBank(AMDGPU::SGPRRegBankID)),
+        VgprRB(&RBI.getRegBank(AMDGPU::VGPRRegBankID)),
+        VccRB(&RBI.getRegBank(AMDGPU::VCCRegBankID)) {};
+
+  bool isLaneMask(Register Reg) {
+    const RegisterBank *RB = MRI.getRegBankOrNull(Reg);
+    if (RB && RB->getID() == AMDGPU::VCCRegBankID)
+      return true;
+
+    const TargetRegisterClass *RC = MRI.getRegClassOrNull(Reg);
+    return RC && TRI.isSGPRClass(RC) && MRI.getType(Reg) == LLT::scalar(1);
+  }
+
+  void cleanUpAfterCombine(MachineInstr &MI, MachineInstr *Optional0) {
+    MI.eraseFromParent();
+    if (Optional0 && isTriviallyDead(*Optional0, MRI))
+      Optional0->eraseFromParent();
+  }
+
+  std::pair<MachineInstr *, Register> tryMatch(Register Src, unsigned Opcode) {
+    MachineInstr *MatchMI = MRI.getVRegDef(Src);
+    if (MatchMI->getOpcode() != Opcode)
+      return {nullptr, Register()};
+    return {MatchMI, MatchMI->getOperand(1).getReg()};
+  }
+
+  void tryCombineCopy(MachineInstr &MI) {
+    using namespace llvm::MIPatternMatch;
+    Register Dst = MI.getOperand(0).getReg();
+    Register Src = MI.getOperand(1).getReg();
+    // Skip copies of physical registers.
+    if (!Dst.isVirtual() || !Src.isVirtual())
+      return;
+
+    // This is a cross bank copy, sgpr S1 to lane mask.
+    //
+    // %Src:sgpr(s1) = G_TRUNC %TruncS32Src:sgpr(s32)
+    // %Dst:lane-mask(s1) = COPY %Src:sgpr(s1)
+    // ->
+    // %Dst:lane-mask(s1) = G_COPY_VCC_SCC %TruncS32Src:sgpr(s32)
+    if (isLaneMask(Dst) && MRI.getRegBankOrNull(Src) == SgprRB) {
+      auto [Trunc, TruncS32Src] = tryMatch(Src, AMDGPU::G_TRUNC);
+      assert(Trunc && MRI.getType(TruncS32Src) == S32 &&
+             "sgpr S1 must be result of G_TRUNC of sgpr S32");
+
+      B.setInstr(MI);
+      // Ensure that truncated bits in BoolSrc are 0.
+      auto One = B.buildConstant({SgprRB, S32}, 1);
+      auto BoolSrc = B.buildAnd({SgprRB, S32}, TruncS32Src, One);
+      B.buildInstr(AMDGPU::G_COPY_VCC_SCC, {Dst}, {BoolSrc});
+      cleanUpAfterCombine(MI, Trunc);
+      return;
+    }
+
+    // Src = G_READANYLANE RALSrc
+    // Dst = COPY Src
+    // ->
+    // Dst = RALSrc
+    if (MRI.getRegBankOrNull(Dst) == VgprRB &&
+        MRI.getRegBankOrNull(Src) == SgprRB) {
+      auto [RAL, RALSrc] = tryMatch(Src, AMDGPU::G_READANYLANE);
+      if (!RAL)
+        return;
+
+      assert(MRI.getRegBank(RALSrc) == VgprRB);
+      MRI.replaceRegWith(Dst, RALSrc);
+      cleanUpAfterCombine(MI, RAL);
+      return;
+    }
+  }
+
+  void tryCombineS1AnyExt(MachineInstr &MI) {
+    // %Src:sgpr(S1) = G_TRUNC %TruncSrc
+    // %Dst = G_ANYEXT %Src:sgpr(S1)
+    // ->
+    // %Dst = G_... %TruncSrc
+    Register Dst = MI.getOperand(0).getReg();
+    Register Src = MI.getOperand(1).getReg();
+    if (MRI.getType(Src) != S1)
+      return;
+
+    auto [Trunc, TruncSrc] = tryMatch(Src, AMDGPU::G_TRUNC);
+    if (!Trunc)
+      return;
+
+    LLT DstTy = MRI.getType(Dst);
+    LLT TruncSrcTy = MRI.getType(TruncSrc);
+
+    if (DstTy == TruncSrcTy) {
+      MRI.replaceRegWith(Dst, TruncSrc);
+      cleanUpAfterCombine(MI, Trunc);
+      return;
+    }
+
+    B.setInstr(MI);
+
+    if (DstTy == S32 && TruncSrcTy == S64) {
+      auto Unmerge = B.buildUnmerge({SgprRB, S32}, TruncSrc);
+      MRI.replaceRegWith(Dst, Unmerge.getReg(0));
+      cleanUpAfterCombine(MI, Trunc);
+      return;
+    }
+
+    if (DstTy == S32 && TruncSrcTy == S16) {
+      B.buildAnyExt(Dst, TruncSrc);
+      cleanUpAfterCombine(MI, Trunc);
+      return;
+    }
+
+    if (DstTy == S16 && TruncSrcTy == S32) {
+      B.buildTrunc(Dst, TruncSrc);
+      cleanUpAfterCombine(MI, Trunc);
+      return;
+    }
+
+    llvm_unreachable("missing anyext + trunc combine");
+  }
+};
+
+// Search through MRI for virtual registers with sgpr register bank and S1 LLT.
+[[maybe_unused]] static Register getAnySgprS1(const MachineRegisterInfo &MRI) {
+  const LLT S1 = LLT::scalar(1);
+  for (unsigned i = 0; i < MRI.getNumVirtRegs(); ++i) {
+    Register Reg = Register::index2VirtReg(i);
+    if (MRI.def_empty(Reg))
+      continue;
+
+    if (MRI.getType(Reg) != S1)
+      continue;
----------------
arsenm wrote:

```suggestion
    if (MRI.getType(Reg) != S1 || MRI.def_empty(Reg))
      continue;
```

https://github.com/llvm/llvm-project/pull/112864
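
For context, here is how the helper would read with the suggested fold applied. This is a sketch only, assuming the includes and using-declarations already present in AMDGPURegBankLegalize.cpp; the loop tail (the sgpr-bank check and return) is not part of the quoted hunk, which ends at the `continue`, and is reconstructed here from the helper's doc comment:

```cpp
// Sketch of getAnySgprS1 with the two early-continues folded into one,
// per the review suggestion. Assumes the surrounding file's context
// (llvm namespace, AMDGPU register-bank IDs, GlobalISel headers).
[[maybe_unused]] static Register getAnySgprS1(const MachineRegisterInfo &MRI) {
  const LLT S1 = LLT::scalar(1);
  for (unsigned i = 0; i < MRI.getNumVirtRegs(); ++i) {
    Register Reg = Register::index2VirtReg(i);
    // One branch now rejects both non-S1 values and registers with no
    // definition, instead of two separate checks.
    if (MRI.getType(Reg) != S1 || MRI.def_empty(Reg))
      continue;

    // Assumed tail, inferred from the doc comment rather than the quoted
    // hunk: accept the first S1 virtual register on the sgpr bank.
    const RegisterBank *RB = MRI.getRegBankOrNull(Reg);
    if (RB && RB->getID() == AMDGPU::SGPRRegBankID)
      return Reg;
  }
  return {};
}
```

Both checks are side-effect-free, so folding them changes no behavior; it simply drops a redundant branch and matches the single-early-continue style used elsewhere in the pass.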