================ @@ -652,3 +652,140 @@ void VPlanTransforms::attachCheckBlock(VPlan &Plan, Value *Cond, Term->addMetadata(LLVMContext::MD_prof, BranchWeights); } } + +bool VPlanTransforms::handleMaxMinNumReductionsWithoutFastMath(VPlan &Plan) { + VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion(); + VPReductionPHIRecipe *RedPhiR = nullptr; + VPValue *MinMaxOp = nullptr; + bool HasUnsupportedPhi = false; + + auto GetMinMaxCompareValue = [](VPSingleDefRecipe *MinMaxOp, + VPReductionPHIRecipe *RedPhi) -> VPValue * { + auto *RepR = dyn_cast<VPReplicateRecipe>(MinMaxOp); + if (!isa<VPWidenIntrinsicRecipe>(MinMaxOp) && + !(RepR && (isa<IntrinsicInst>(RepR->getUnderlyingInstr())))) + return nullptr; + + if (MinMaxOp->getOperand(0) == RedPhi) + return MinMaxOp->getOperand(1); + assert(MinMaxOp->getOperand(1) == RedPhi && + "Reduction phi operand expected"); + return MinMaxOp->getOperand(0); + }; + + for (auto &R : LoopRegion->getEntryBasicBlock()->phis()) { + // TODO: Also support first-order recurrence phis. + HasUnsupportedPhi |= + !isa<VPCanonicalIVPHIRecipe, VPWidenIntOrFpInductionRecipe, + VPReductionPHIRecipe>(&R); + auto *Cur = dyn_cast<VPReductionPHIRecipe>(&R); + if (!Cur) + continue; + // For now, only a single reduction is supported. + // TODO: Support multiple MaxNum/MinNum reductions and other reductions. + if (RedPhiR) + return false; + if (Cur->getRecurrenceKind() != RecurKind::FMaxNum && + Cur->getRecurrenceKind() != RecurKind::FMinNum) + continue; + + RedPhiR = Cur; + auto *MinMaxR = dyn_cast<VPRecipeWithIRFlags>( + RedPhiR->getBackedgeValue()->getDefiningRecipe()); + if (!MinMaxR) + return false; + MinMaxOp = GetMinMaxCompareValue(MinMaxR, RedPhiR); + if (!MinMaxOp) + return false; + } + + if (!RedPhiR) + return true; + + if (HasUnsupportedPhi || !Plan.hasScalarTail()) + return false; + + /// Check if the vector loop of \p Plan can early exit and restart + /// execution of last vector iteration in the scalar loop. 
This requires all + /// recipes up to early exit point be side-effect free as they are + /// re-executed. Currently we check that the loop is free of any recipe that + /// may write to memory. Expected to operate on an early VPlan w/o nested + /// regions. + for (VPBlockBase *VPB : vp_depth_first_shallow( + Plan.getVectorLoopRegion()->getEntryBasicBlock())) { + auto *VPBB = cast<VPBasicBlock>(VPB); + for (auto &R : *VPBB) { + if (match(&R, m_BranchOnCount(m_VPValue(), m_VPValue()))) + continue; + if (R.mayWriteToMemory()) + return false; + } + } + + auto *MiddleVPBB = Plan.getMiddleBlock(); + auto *RdxResult = dyn_cast<VPInstruction>(&MiddleVPBB->front()); + if (!RdxResult || + RdxResult->getOpcode() != VPInstruction::ComputeReductionResult || + RdxResult->getOperand(0) != RedPhiR) + return false; + + // Create a new reduction phi recipe with either FMin/FMax, replacing + // FMinNum/FMaxNum. + RecurKind NewRK = RedPhiR->getRecurrenceKind() == RecurKind::FMinNum + ? RecurKind::FMin + : RecurKind::FMax; + auto *NewRedPhiR = new VPReductionPHIRecipe( + cast<PHINode>(RedPhiR->getUnderlyingValue()), NewRK, + *RedPhiR->getStartValue(), RedPhiR->isInLoop(), RedPhiR->isOrdered()); + NewRedPhiR->addOperand(RedPhiR->getOperand(1)); + NewRedPhiR->insertBefore(RedPhiR); + RedPhiR->replaceAllUsesWith(NewRedPhiR); + RedPhiR->eraseFromParent(); + + // Update the loop exit condition to exit if either any of the inputs is NaN + // or the vector trip count is reached. 
+ VPBasicBlock *LatchVPBB = LoopRegion->getExitingBasicBlock(); + VPBuilder Builder(LatchVPBB->getTerminator()); + auto *LatchExitingBranch = cast<VPInstruction>(LatchVPBB->getTerminator()); + assert(LatchExitingBranch->getOpcode() == VPInstruction::BranchOnCount && + "Unexpected terminator"); + auto *IsLatchExitTaken = + Builder.createICmp(CmpInst::ICMP_EQ, LatchExitingBranch->getOperand(0), + LatchExitingBranch->getOperand(1)); + + VPValue *IsNaN = Builder.createFCmp(CmpInst::FCMP_UNO, MinMaxOp, MinMaxOp); + VPValue *AnyNaN = Builder.createNaryOp(VPInstruction::AnyOf, {IsNaN}); + auto *AnyExitTaken = + Builder.createNaryOp(Instruction::Or, {AnyNaN, IsLatchExitTaken}); + Builder.createNaryOp(VPInstruction::BranchOnCond, AnyExitTaken); + LatchExitingBranch->eraseFromParent(); + + // If we exit early due to NaNs, compute the final reduction result based on + // the reduction phi at the beginning of the last vector iteration. + Builder.setInsertPoint(MiddleVPBB, MiddleVPBB->begin()); + auto *NewSel = + Builder.createSelect(AnyNaN, NewRedPhiR, RdxResult->getOperand(1)); + RdxResult->setOperand(1, NewSel); + + auto *ScalarPH = Plan.getScalarPreheader(); + // Update the resume phis for inductions in the scalar preheader. If AnyNaN is + // true, resume from the start of the last vector iteration via the + // canonical IV, otherwise from the original value. + for (auto &R : ScalarPH->phis()) { + auto *ResumeR = cast<VPPhi>(&R); + VPValue *VecV = ResumeR->getOperand(0); + if (VecV == RdxResult) + continue; + if (VecV != &Plan.getVectorTripCount()) + return false; ---------------- fhahn wrote:
Thanks, added a debug message + comment! https://github.com/llvm/llvm-project/pull/148239 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits