================ @@ -0,0 +1,816 @@ +//===- AArch64LoopIdiomTransform.cpp - Loop idiom recognition -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass implements a pass that recognizes certain loop idioms and +// transforms them into more optimized versions of the same loop. In cases +// where this happens, it can be a significant performance win. +// +// We currently only recognize one loop that finds the first mismatched byte +// in an array and returns the index, i.e. something like: +// +// while (++i != n) { +// if (a[i] != b[i]) +// break; +// } +// +// In this example we can actually vectorize the loop despite the early exit, +// although the loop vectorizer does not support it. It requires some extra +// checks to deal with the possibility of faulting loads when crossing page +// boundaries. However, even with these checks it is still profitable to do the +// transformation. +// +//===----------------------------------------------------------------------===// +// +// TODO List: +// +// * When optimizing for code size we may want to avoid some transformations. +// * We can also support the inverse case where we scan for a matching element. +// +//===----------------------------------------------------------------------===// + +#include "AArch64LoopIdiomTransform.h" +#include "llvm/Analysis/DomTreeUpdater.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/InitializePasses.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +using namespace llvm; + +#define DEBUG_TYPE "aarch64-loop-idiom-transform" + +static cl::opt<bool> + DisableAll("disable-aarch64-lit-all", cl::Hidden, cl::init(false), + cl::desc("Disable AArch64 Loop Idiom Transform Pass.")); + +static cl::opt<bool> DisableByteCmp( + "disable-aarch64-lit-bytecmp", cl::Hidden, cl::init(false), + cl::desc("Proceed with AArch64 Loop Idiom Transform Pass, but do " + "not convert byte-compare loop(s).")); + +static cl::opt<bool> VerifyLoops( + "aarch64-lit-verify", cl::Hidden, cl::init(false), + cl::desc("Verify loops generated AArch64 Loop Idiom Transform Pass.")); + +namespace llvm { + +void initializeAArch64LoopIdiomTransformLegacyPassPass(PassRegistry &); +Pass *createAArch64LoopIdiomTransformPass(); + +} // end namespace llvm + +namespace { + +class AArch64LoopIdiomTransform { + Loop *CurLoop = nullptr; + DominatorTree *DT; + LoopInfo *LI; + const TargetTransformInfo *TTI; + const DataLayout *DL; + +public: + explicit AArch64LoopIdiomTransform(DominatorTree *DT, LoopInfo *LI, + const TargetTransformInfo *TTI, + const DataLayout *DL) + : DT(DT), LI(LI), TTI(TTI), DL(DL) {} + + bool run(Loop *L); + +private: + /// \name Countable Loop Idiom Handling + /// @{ + + bool runOnCountableLoop(); + bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount, + SmallVectorImpl<BasicBlock *> &ExitBlocks); + + bool recognizeByteCompare(); + Value *expandFindMismatch(IRBuilder<> &Builder, GetElementPtrInst *GEPA, + GetElementPtrInst *GEPB, Instruction *Index, + Value *Start, Value *MaxLen); + void transformByteCompare(GetElementPtrInst *GEPA, 
GetElementPtrInst *GEPB, + PHINode *IndPhi, Value *MaxLen, Instruction *Index, + Value *Start, bool IncIdx, BasicBlock *FoundBB, + BasicBlock *EndBB); + /// @} +}; + +class AArch64LoopIdiomTransformLegacyPass : public LoopPass { +public: + static char ID; + + explicit AArch64LoopIdiomTransformLegacyPass() : LoopPass(ID) { + initializeAArch64LoopIdiomTransformLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + StringRef getPassName() const override { + return "Transform AArch64-specific loop idioms"; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<LoopInfoWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override; +}; + +bool AArch64LoopIdiomTransformLegacyPass::runOnLoop(Loop *L, + LPPassManager &LPM) { + + if (skipLoop(L)) + return false; + + auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI( + *L->getHeader()->getParent()); + return AArch64LoopIdiomTransform( + DT, LI, &TTI, &L->getHeader()->getModule()->getDataLayout()) + .run(L); +} + +} // end anonymous namespace + +char AArch64LoopIdiomTransformLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN( + AArch64LoopIdiomTransformLegacyPass, "aarch64-lit", + "Transform specific loop idioms into optimized vector forms", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopSimplify) +INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_END( + AArch64LoopIdiomTransformLegacyPass, "aarch64-lit", + "Transform specific loop idioms into optimized vector forms", false, false) + +Pass *llvm::createAArch64LoopIdiomTransformPass() { + return new AArch64LoopIdiomTransformLegacyPass(); +} + +PreservedAnalyses +AArch64LoopIdiomTransformPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &) { + if (DisableAll) + return PreservedAnalyses::all(); + + const auto *DL = &L.getHeader()->getModule()->getDataLayout(); + + AArch64LoopIdiomTransform LIT(&AR.DT, &AR.LI, &AR.TTI, DL); + if (!LIT.run(&L)) + return PreservedAnalyses::all(); + + return PreservedAnalyses::none(); +} + +//===----------------------------------------------------------------------===// +// +// Implementation of AArch64LoopIdiomTransform +// +//===----------------------------------------------------------------------===// + +bool AArch64LoopIdiomTransform::run(Loop *L) { + CurLoop = L; + + if (DisableAll) + return false; + + // If the loop could not be converted to canonical form, it must have an + // indirectbr in it, just give up. + if (!L->getLoopPreheader()) + return false; + + LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" + << CurLoop->getHeader()->getParent()->getName() + << "] Loop %" << CurLoop->getHeader()->getName() << "\n"); + + return recognizeByteCompare(); +} + + +bool AArch64LoopIdiomTransform::recognizeByteCompare() { + // Currently the transformation only works on scalable vector types, although + // there is no fundamental reason why it cannot be made to work for fixed + // width too. + + // We also need to know the minimum page size for the target in order to + // generate runtime memory checks to ensure the vector version won't fault. 
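+  //
+  // As an illustration (the 4096 figure is only an example; the real value
+  // comes from TTI->getMinPageSize()): with a 4KB minimum page size the
+  // runtime check emitted later simply compares (Addr >> 12) for the first
+  // and last byte accessed from each array, and the vector path is only
+  // entered when both pairs land on the same page.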
+ if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() || + DisableByteCmp) + return false; + + BasicBlock *Header = CurLoop->getHeader(); + + // In AArch64LoopIdiomTransform::run we have already checked that the loop + // has a preheader so we can assume it's in a canonical form. + if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2) + return false; + + PHINode *PN = dyn_cast<PHINode>(&Header->front()); + if (!PN || PN->getNumIncomingValues() != 2) + return false; + + auto LoopBlocks = CurLoop->getBlocks(); + // The first block in the loop should contain only 4 instructions, e.g. + // + // while.cond: + // %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ] + // %inc = add i32 %res.phi, 1 + // %cmp.not = icmp eq i32 %inc, %n + // br i1 %cmp.not, label %while.end, label %while.body + // + auto CondBBInsts = LoopBlocks[0]->instructionsWithoutDebug(); + if (std::distance(CondBBInsts.begin(), CondBBInsts.end()) > 4) + return false; + + // The second block should contain 7 instructions, e.g. + // + // while.body: + // %idx = zext i32 %inc to i64 + // %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx + // %load.a = load i8, ptr %idx.a + // %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx + // %load.b = load i8, ptr %idx.b + // %cmp.not.ld = icmp eq i8 %load.a, %load.b + // br i1 %cmp.not.ld, label %while.cond, label %while.end + // + auto LoopBBInsts = LoopBlocks[1]->instructionsWithoutDebug(); + if (std::distance(LoopBBInsts.begin(), LoopBBInsts.end()) > 7) + return false; + + using namespace PatternMatch; + + // The incoming value to the PHI node from the loop should be an add of 1. + Value *StartIdx = nullptr; + Instruction *Index = nullptr; + if (!CurLoop->contains(PN->getIncomingBlock(0))) { + StartIdx = PN->getIncomingValue(0); + Index = dyn_cast<Instruction>(PN->getIncomingValue(1)); + } else { + StartIdx = PN->getIncomingValue(1); + Index = dyn_cast<Instruction>(PN->getIncomingValue(0)); + } + + // Limit to 32-bit types for now + if (!Index || !Index->getType()->isIntegerTy(32) || + !match(Index, m_c_Add(m_Specific(PN), m_One()))) + return false; + + // If we match the pattern, PN and Index will be replaced with the result of + // the cttz.elts intrinsic. If any other instructions are used outside of + // the loop, we cannot replace it. + for (BasicBlock *BB : LoopBlocks) + for (Instruction &I : *BB) + if (&I != PN && &I != Index) + for (User *U : I.users()) + if (!CurLoop->contains(cast<Instruction>(U))) + return false; + + // Match the branch instruction for the header + ICmpInst::Predicate Pred; + Value *MaxLen; + BasicBlock *EndBB, *WhileBB; + if (!match(Header->getTerminator(), + m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)), + m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) || + Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(WhileBB)) + return false; + + // WhileBB should contain the pattern of load & compare instructions. Match + // the pattern and find the GEP instructions used by the loads. 
+ ICmpInst::Predicate WhilePred; + BasicBlock *FoundBB; + BasicBlock *TrueBB; + Value *LoadA, *LoadB; + if (!match(WhileBB->getTerminator(), + m_Br(m_ICmp(WhilePred, m_Value(LoadA), m_Value(LoadB)), + m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) || + WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(TrueBB)) + return false; + + Value *A, *B; + if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B)))) + return false; + + GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A); + GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B); + + if (!GEPA || !GEPB) + return false; + + Value *PtrA = GEPA->getPointerOperand(); + Value *PtrB = GEPB->getPointerOperand(); + + // Check we are loading i8 values from two loop invariant pointers + if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) || + !GEPA->getResultElementType()->isIntegerTy(8) || + !GEPB->getResultElementType()->isIntegerTy(8) || + !cast<LoadInst>(LoadA)->getType()->isIntegerTy(8) || + !cast<LoadInst>(LoadB)->getType()->isIntegerTy(8) || PtrA == PtrB) + return false; + + // Check that the index to the GEPs is the index we found earlier + if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1) + return false; + + Value *IdxA = GEPA->getOperand(GEPA->getNumIndices()); + Value *IdxB = GEPB->getOperand(GEPB->getNumIndices()); + if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index)))) + return false; + + // Ensure that when the Found and End blocks are identical the PHIs have the + // supported format. We don't currently allow cases like this: + // while.cond: + // ... + // br i1 %cmp.not, label %while.end, label %while.body + // + // while.body: + // ... + // br i1 %cmp.not2, label %while.cond, label %while.end + // + // while.end: + // %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ] + // + // Where the incoming values for %final_ptr are unique and from each of the + // loop blocks, but not actually defined in the loop. This requires extra + // work setting up the byte.compare block, i.e. by introducing a select to + // choose the correct value. + // TODO: We could add support for this in future. + if (FoundBB == EndBB) { + for (PHINode &PN : EndBB->phis()) { + Value *LastValue = nullptr; + for (unsigned I = 0; I < PN.getNumIncomingValues(); I++) { + BasicBlock *BB = PN.getIncomingBlock(I); + if (CurLoop->contains(BB)) { + Value *V = PN.getIncomingValue(I); + if (!LastValue) + LastValue = V; + else if (LastValue != V) + return false; + } + } + } + } + + LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n" + << *(EndBB->getParent()) << "\n\n"); + + // The index is incremented before the GEP/Load pair so we need to + // add 1 to the start value. + transformByteCompare(GEPA, GEPB, PN, MaxLen, Index, StartIdx, /*IncIdx=*/true, + FoundBB, EndBB); + return true; +} + +Value *AArch64LoopIdiomTransform::expandFindMismatch( + IRBuilder<> &Builder, GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, + Instruction *Index, Value *Start, Value *MaxLen) { + Value *PtrA = GEPA->getPointerOperand(); + Value *PtrB = GEPB->getPointerOperand(); + + // Get the arguments and types for the intrinsic. + BasicBlock *Preheader = CurLoop->getLoopPreheader(); + BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator()); + LLVMContext &Ctx = PHBranch->getContext(); + Type *LoadType = Type::getInt8Ty(Ctx); + Type *ResType = Builder.getInt32Ty(); + + // Split block in the original loop preheader. 
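+  // SplitBlock moves the preheader's unconditional branch (PHBranch) into a
+  // new "mismatch_end" block; the emptied preheader is redirected below to
+  // the first of the check blocks, and "mismatch_end" becomes the common
+  // join point on the way back to the original loop.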
+ DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); + BasicBlock *EndBlock = + SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end"); + + // Create the blocks that we're going to need: + // 1. A block for checking the zero-extended length exceeds 0 + // 2. A block to check that the start and end addresses of a given array + // lie on the same page. + // 3. The SVE loop preheader. + // 4. The first SVE loop block. + // 5. The SVE loop increment block. + // 6. A block we can jump to from the SVE loop when a mismatch is found. + // 7. The first block of the scalar loop itself, containing PHIs , loads + // and cmp. + // 8. A scalar loop increment block to increment the PHIs and go back + // around the loop. + + BasicBlock *MinItCheckBlock = BasicBlock::Create( + Ctx, "mismatch_min_it_check", EndBlock->getParent(), EndBlock); + + // Update the terminator added by SplitBlock to branch to the first block + Preheader->getTerminator()->setSuccessor(0, MinItCheckBlock); + + BasicBlock *MemCheckBlock = BasicBlock::Create( + Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock); + + BasicBlock *SVELoopPreheaderBlock = BasicBlock::Create( + Ctx, "mismatch_sve_loop_preheader", EndBlock->getParent(), EndBlock); + + BasicBlock *SVELoopStartBlock = BasicBlock::Create( + Ctx, "mismatch_sve_loop", EndBlock->getParent(), EndBlock); + + BasicBlock *SVELoopIncBlock = BasicBlock::Create( + Ctx, "mismatch_sve_loop_inc", EndBlock->getParent(), EndBlock); + + BasicBlock *SVELoopMismatchBlock = BasicBlock::Create( + Ctx, "mismatch_sve_loop_found", EndBlock->getParent(), EndBlock); + + BasicBlock *LoopPreHeaderBlock = BasicBlock::Create( + Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock); + + BasicBlock *LoopStartBlock = + BasicBlock::Create(Ctx, "mismatch_loop", EndBlock->getParent(), EndBlock); + + BasicBlock *LoopIncBlock = BasicBlock::Create( + Ctx, "mismatch_loop_inc", EndBlock->getParent(), EndBlock); + + DTU.applyUpdates({{DominatorTree::Insert, Preheader, MinItCheckBlock}, + {DominatorTree::Delete, Preheader, EndBlock}}); + + // Update LoopInfo with the new SVE & scalar loops. + auto SVELoop = LI->AllocateLoop(); + auto ScalarLoop = LI->AllocateLoop(); + + if (CurLoop->getParentLoop()) { + CurLoop->getParentLoop()->addBasicBlockToLoop(MinItCheckBlock, *LI); + CurLoop->getParentLoop()->addBasicBlockToLoop(MemCheckBlock, *LI); + CurLoop->getParentLoop()->addBasicBlockToLoop(SVELoopPreheaderBlock, *LI); + CurLoop->getParentLoop()->addChildLoop(SVELoop); + CurLoop->getParentLoop()->addBasicBlockToLoop(SVELoopMismatchBlock, *LI); + CurLoop->getParentLoop()->addBasicBlockToLoop(LoopPreHeaderBlock, *LI); + CurLoop->getParentLoop()->addChildLoop(ScalarLoop); + } else { + LI->addTopLevelLoop(SVELoop); + LI->addTopLevelLoop(ScalarLoop); + } + + // Add the new basic blocks to their associated loops. + SVELoop->addBasicBlockToLoop(SVELoopStartBlock, *LI); + SVELoop->addBasicBlockToLoop(SVELoopIncBlock, *LI); + + ScalarLoop->addBasicBlockToLoop(LoopStartBlock, *LI); + ScalarLoop->addBasicBlockToLoop(LoopIncBlock, *LI); + + // Set up some types and constants that we intend to reuse. + Type *I64Type = Builder.getInt64Ty(); + + // Check the zero-extended iteration count > 0 + Builder.SetInsertPoint(MinItCheckBlock); + Value *ExtStart = Builder.CreateZExt(Start, I64Type); + Value *ExtEnd = Builder.CreateZExt(MaxLen, I64Type); + // This check doesn't really cost us very much. 
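+  // For example, if the original start index was 0 the incremented start is
+  // 1, so for %n == 0 the unsigned Start <= MaxLen test below fails and we
+  // fall back to the scalar loop instead of entering the SVE path with
+  // Start > End.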
+ + Value *LimitCheck = Builder.CreateICmpULE(Start, MaxLen); + BranchInst *MinItCheckBr = + BranchInst::Create(MemCheckBlock, LoopPreHeaderBlock, LimitCheck); + MinItCheckBr->setMetadata( + LLVMContext::MD_prof, + MDBuilder(MinItCheckBr->getContext()).createBranchWeights(99, 1)); + Builder.Insert(MinItCheckBr); + + DTU.applyUpdates( + {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock}, + {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}}); + + // For each of the arrays, check the start/end addresses are on the same + // page. + Builder.SetInsertPoint(MemCheckBlock); + + // The early exit in the original loop means that when performing vector + // loads we are potentially reading ahead of the early exit. So we could + // fault if crossing a page boundary. Therefore, we create runtime memory + // checks based on the minimum page size as follows: + // 1. Calculate the addresses of the first memory accesses in the loop, + // i.e. LhsStart and RhsStart. + // 2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd. + // 3. Determine which pages correspond to all the memory accesses, i.e + // LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage. + // 4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then + // we know we won't cross any page boundaries in the loop so we can + // enter the vector loop! Otherwise we fall back on the scalar loop. + Value *LhsStartGEP = Builder.CreateGEP(LoadType, PtrA, ExtStart); + Value *RhsStartGEP = Builder.CreateGEP(LoadType, PtrB, ExtStart); + Value *RhsStart = Builder.CreatePtrToInt(RhsStartGEP, I64Type); + Value *LhsStart = Builder.CreatePtrToInt(LhsStartGEP, I64Type); + Value *LhsEndGEP = Builder.CreateGEP(LoadType, PtrA, ExtEnd); + Value *RhsEndGEP = Builder.CreateGEP(LoadType, PtrB, ExtEnd); + Value *LhsEnd = Builder.CreatePtrToInt(LhsEndGEP, I64Type); + Value *RhsEnd = Builder.CreatePtrToInt(RhsEndGEP, I64Type); + + const uint64_t MinPageSize = TTI->getMinPageSize().value(); + const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize); + Value *LhsStartPage = Builder.CreateLShr(LhsStart, AddrShiftAmt); + Value *LhsEndPage = Builder.CreateLShr(LhsEnd, AddrShiftAmt); + Value *RhsStartPage = Builder.CreateLShr(RhsStart, AddrShiftAmt); + Value *RhsEndPage = Builder.CreateLShr(RhsEnd, AddrShiftAmt); + Value *LhsPageCmp = Builder.CreateICmpNE(LhsStartPage, LhsEndPage); + Value *RhsPageCmp = Builder.CreateICmpNE(RhsStartPage, RhsEndPage); + + Value *CombinedPageCmp = Builder.CreateOr(LhsPageCmp, RhsPageCmp); + BranchInst *CombinedPageCmpCmpBr = BranchInst::Create( + LoopPreHeaderBlock, SVELoopPreheaderBlock, CombinedPageCmp); + CombinedPageCmpCmpBr->setMetadata( + LLVMContext::MD_prof, MDBuilder(CombinedPageCmpCmpBr->getContext()) + .createBranchWeights(10, 90)); + Builder.Insert(CombinedPageCmpCmpBr); + + DTU.applyUpdates( + {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock}, + {DominatorTree::Insert, MemCheckBlock, SVELoopPreheaderBlock}}); + + // Set up the SVE loop preheader, i.e. calculate initial loop predicate, + // zero-extend MaxLen to 64-bits, determine the number of vector elements + // processed in each iteration, etc. + Builder.SetInsertPoint(SVELoopPreheaderBlock); + + // At this point we know two things must be true: + // 1. Start <= End + // 2. ExtMaxLen <= MinPageSize due to the page checks. + // Therefore, we know that we can use a 64-bit induction variable that + // starts from 0 -> ExtMaxLen and it will not overflow. 
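+  // The preheader built here boils down to IR along these lines (value
+  // names are illustrative):
+  //   %init.pred = call <vscale x 16 x i1>
+  //       @llvm.get.active.lane.mask.nxv16i1.i64(i64 %ext.start, i64 %ext.end)
+  //   %vscale    = call i64 @llvm.vscale.i64()
+  //   %vec.len   = mul nuw nsw i64 %vscale, 16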
+ ScalableVectorType *PredVTy = + ScalableVectorType::get(Builder.getInt1Ty(), 16); + + Value *InitialPred = Builder.CreateIntrinsic( + Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd}); + + Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {}); + VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "", + /*HasNUW=*/true, /*HasNSW=*/true); + + Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(), + Builder.getInt1(false)); + + BranchInst *JumpToSVELoop = BranchInst::Create(SVELoopStartBlock); + Builder.Insert(JumpToSVELoop); + + DTU.applyUpdates( + {{DominatorTree::Insert, SVELoopPreheaderBlock, SVELoopStartBlock}}); + + // Set up the first SVE loop block by creating the PHIs, doing the vector + // loads and comparing the vectors. + Builder.SetInsertPoint(SVELoopStartBlock); + PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_sve_loop_pred"); + LoopPred->addIncoming(InitialPred, SVELoopPreheaderBlock); + PHINode *SVEIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_sve_index"); + SVEIndexPhi->addIncoming(ExtStart, SVELoopPreheaderBlock); + Type *SVELoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16); + Value *GepOffset = SVEIndexPhi; + Value *Passthru = ConstantInt::getNullValue(SVELoadType); + + Value *SVELhsGep = Builder.CreateGEP(LoadType, PtrA, GepOffset); + if (GEPA->isInBounds()) + cast<GetElementPtrInst>(SVELhsGep)->setIsInBounds(true); + Value *SVELhsLoad = Builder.CreateMaskedLoad(SVELoadType, SVELhsGep, Align(1), + LoopPred, Passthru); + + Value *SVERhsGep = Builder.CreateGEP(LoadType, PtrB, GepOffset); + if (GEPB->isInBounds()) + cast<GetElementPtrInst>(SVERhsGep)->setIsInBounds(true); + Value *SVERhsLoad = Builder.CreateMaskedLoad(SVELoadType, SVERhsGep, Align(1), + LoopPred, Passthru); + + Value *SVEMatchCmp = Builder.CreateICmpNE(SVELhsLoad, SVERhsLoad); + SVEMatchCmp = Builder.CreateSelect(LoopPred, SVEMatchCmp, PFalse); + Value *SVEMatchHasActiveLanes = Builder.CreateOrReduce(SVEMatchCmp); + BranchInst *SVEEarlyExit = BranchInst::Create( + SVELoopMismatchBlock, SVELoopIncBlock, SVEMatchHasActiveLanes); + Builder.Insert(SVEEarlyExit); + + DTU.applyUpdates( + {{DominatorTree::Insert, SVELoopStartBlock, SVELoopMismatchBlock}, + {DominatorTree::Insert, SVELoopStartBlock, SVELoopIncBlock}}); + + // Increment the index counter and calculate the predicate for the next + // iteration of the loop. We branch back to the start of the loop if there + // is at least one active lane. + Builder.SetInsertPoint(SVELoopIncBlock); + Value *NewSVEIndexPhi = Builder.CreateAdd(SVEIndexPhi, VecLen, "", + /*HasNUW=*/true, /*HasNSW=*/true); + SVEIndexPhi->addIncoming(NewSVEIndexPhi, SVELoopIncBlock); + Value *NewPred = + Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask, + {PredVTy, I64Type}, {NewSVEIndexPhi, ExtEnd}); + LoopPred->addIncoming(NewPred, SVELoopIncBlock); + + Value *PredHasActiveLanes = + Builder.CreateExtractElement(NewPred, uint64_t(0)); + BranchInst *SVELoopBranchBack = + BranchInst::Create(SVELoopStartBlock, EndBlock, PredHasActiveLanes); + Builder.Insert(SVELoopBranchBack); + + DTU.applyUpdates({{DominatorTree::Insert, SVELoopIncBlock, SVELoopStartBlock}, + {DominatorTree::Insert, SVELoopIncBlock, EndBlock}}); + + // If we found a mismatch then we need to calculate which lane in the vector + // had a mismatch and add that on to the current loop index. 
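+  // For example, if the vector loop index was 32 when the first active
+  // mismatching lane was lane 3, @llvm.experimental.cttz.elts returns 3 and
+  // the reported mismatch index becomes 32 + 3 = 35.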
+ Builder.SetInsertPoint(SVELoopMismatchBlock); + PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_sve_found_pred"); + FoundPred->addIncoming(SVEMatchCmp, SVELoopStartBlock); + PHINode *LastLoopPred = + Builder.CreatePHI(PredVTy, 1, "mismatch_sve_last_loop_pred"); + LastLoopPred->addIncoming(LoopPred, SVELoopStartBlock); + PHINode *SVEFoundIndex = + Builder.CreatePHI(I64Type, 1, "mismatch_sve_found_index"); + SVEFoundIndex->addIncoming(SVEIndexPhi, SVELoopStartBlock); + + Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred); + Value *Ctz = Builder.CreateIntrinsic( + Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()}, + {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)}); + Ctz = Builder.CreateZExt(Ctz, I64Type); + Value *SVELoopRes64 = Builder.CreateAdd(SVEFoundIndex, Ctz, "", + /*HasNUW=*/true, /*HasNSW=*/true); + Value *SVELoopRes = Builder.CreateTrunc(SVELoopRes64, ResType); + + Builder.Insert(BranchInst::Create(EndBlock)); + + DTU.applyUpdates({{DominatorTree::Insert, SVELoopMismatchBlock, EndBlock}}); + + // Generate code for scalar loop. + Builder.SetInsertPoint(LoopPreHeaderBlock); + Builder.Insert(BranchInst::Create(LoopStartBlock)); + + DTU.applyUpdates( + {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}}); + + Builder.SetInsertPoint(LoopStartBlock); + PHINode *IndexPhi = Builder.CreatePHI(ResType, 2, "mismatch_index"); + IndexPhi->addIncoming(Start, LoopPreHeaderBlock); + + // Otherwise compare the values + // Load bytes from each array and compare them. + GepOffset = Builder.CreateZExt(IndexPhi, I64Type); + + Value *LhsGep = Builder.CreateGEP(LoadType, PtrA, GepOffset); + if (GEPA->isInBounds()) + cast<GetElementPtrInst>(LhsGep)->setIsInBounds(true); + Value *LhsLoad = Builder.CreateLoad(LoadType, LhsGep); + + Value *RhsGep = Builder.CreateGEP(LoadType, PtrB, GepOffset); + if (GEPB->isInBounds()) + cast<GetElementPtrInst>(RhsGep)->setIsInBounds(true); + Value *RhsLoad = Builder.CreateLoad(LoadType, RhsGep); + + Value *MatchCmp = Builder.CreateICmpEQ(LhsLoad, RhsLoad); + // If we have a mismatch then exit the loop ... + BranchInst *MatchCmpBr = BranchInst::Create(LoopIncBlock, EndBlock, MatchCmp); + Builder.Insert(MatchCmpBr); + + DTU.applyUpdates({{DominatorTree::Insert, LoopStartBlock, LoopIncBlock}, + {DominatorTree::Insert, LoopStartBlock, EndBlock}}); + + // Have we reached the maximum permitted length for the loop? + Builder.SetInsertPoint(LoopIncBlock); + Value *PhiInc = Builder.CreateAdd(IndexPhi, ConstantInt::get(ResType, 1), "", + /*HasNUW=*/Index->hasNoUnsignedWrap(), + /*HasNSW=*/Index->hasNoSignedWrap()); + IndexPhi->addIncoming(PhiInc, LoopIncBlock); + Value *IVCmp = Builder.CreateICmpEQ(PhiInc, MaxLen); + BranchInst *IVCmpBr = BranchInst::Create(EndBlock, LoopStartBlock, IVCmp); + Builder.Insert(IVCmpBr); + + DTU.applyUpdates({{DominatorTree::Insert, LoopIncBlock, EndBlock}, + {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}}); + + // In the end block we need to insert a PHI node to deal with three cases: + // 1. We didn't find a mismatch in the scalar loop, so we return MaxLen. + // 2. We exitted the scalar loop early due to a mismatch and need to return + // the index that we found. + // 3. We didn't find a mismatch in the SVE loop, so we return MaxLen. + // 4. We exitted the SVE loop early due to a mismatch and need to return + // the index that we found. 
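+  // That is, a PHI along the lines of (value names illustrative):
+  //   %mismatch_result = phi i32 [ %max.len, %mismatch_loop_inc ],
+  //                              [ %mismatch_index, %mismatch_loop ],
+  //                              [ %max.len, %mismatch_sve_loop_inc ],
+  //                              [ %sve.res, %mismatch_sve_loop_found ]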
+ Builder.SetInsertPoint(EndBlock, EndBlock->getFirstInsertionPt()); + PHINode *ResPhi = Builder.CreatePHI(ResType, 4, "mismatch_result"); + ResPhi->addIncoming(MaxLen, LoopIncBlock); + ResPhi->addIncoming(IndexPhi, LoopStartBlock); + ResPhi->addIncoming(MaxLen, SVELoopIncBlock); + ResPhi->addIncoming(SVELoopRes, SVELoopMismatchBlock); + + Value *FinalRes = Builder.CreateTrunc(ResPhi, ResType); + + if (VerifyLoops) { + ScalarLoop->verifyLoop(); + SVELoop->verifyLoop(); + if (!SVELoop->isRecursivelyLCSSAForm(*DT, *LI)) + report_fatal_error("Loops must remain in LCSSA form!"); + if (!ScalarLoop->isRecursivelyLCSSAForm(*DT, *LI)) + report_fatal_error("Loops must remain in LCSSA form!"); + } + + return FinalRes; +} + +void AArch64LoopIdiomTransform::transformByteCompare( + GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, PHINode *IndPhi, + Value *MaxLen, Instruction *Index, Value *Start, bool IncIdx, + BasicBlock *FoundBB, BasicBlock *EndBB) { + + // Insert the byte compare code at the end of the preheader block + BasicBlock *Preheader = CurLoop->getLoopPreheader(); + BasicBlock *Header = CurLoop->getHeader(); + BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator()); + IRBuilder<> Builder(PHBranch); + Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc()); + + // Increment the pointer if this was done before the loads in the loop. + if (IncIdx) + Start = Builder.CreateAdd(Start, ConstantInt::get(Start->getType(), 1)); + + Value *ByteCmpRes = + expandFindMismatch(Builder, GEPA, GEPB, Index, Start, MaxLen); + + // Replaces uses of index & induction Phi with intrinsic (we already + // checked that the the first instruction of Header is the Phi above). + IndPhi->replaceAllUsesWith(ByteCmpRes); + Index->replaceAllUsesWith(ByteCmpRes); + + assert(PHBranch->isUnconditional() && + "Expected preheader to terminate with an unconditional branch."); + + // If no mismatch was found, we can jump to the end block. Create a + // new basic block for the compare instruction. + auto *CmpBB = BasicBlock::Create(Preheader->getContext(), "byte.compare", ---------------- david-arm wrote:
Given the way that expandFindMismatch funnels all paths from both the scalar and vector loops into a single mismatch.end block, we then need a final byte.compare block to check whether the result is the end or not. That decides which successor block to jump to - FoundBB or EndBB. I admit for the case when FoundBB == EndBB we don't need to do this, but I think it's simpler just to be consistent with the FoundBB != EndBB case.

https://github.com/llvm/llvm-project/pull/72273
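(For reference, a rough sketch of what that byte.compare block ends up containing — value and label names are assumed here, since the hunk above is cut off before the block is populated:

  byte.compare:
    %cmp.end = icmp eq i32 %mismatch_result, %n
    br i1 %cmp.end, label %while.end, label %while.found

i.e. if the index returned by expandFindMismatch equals the length, no mismatch was found and we branch to EndBB, otherwise to FoundBB.)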