================
@@ -652,3 +652,140 @@ void VPlanTransforms::attachCheckBlock(VPlan &Plan, Value 
*Cond,
     Term->addMetadata(LLVMContext::MD_prof, BranchWeights);
   }
 }
+
+bool VPlanTransforms::handleMaxMinNumReductionsWithoutFastMath(VPlan &Plan) {
+  VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion();
+  VPReductionPHIRecipe *RedPhiR = nullptr;
+  VPValue *MinMaxOp = nullptr;
+  bool HasUnsupportedPhi = false;
+
+  auto GetMinMaxCompareValue = [](VPSingleDefRecipe *MinMaxOp,
+                                  VPReductionPHIRecipe *RedPhi) -> VPValue * {
+    auto *RepR = dyn_cast<VPReplicateRecipe>(MinMaxOp);
+    if (!isa<VPWidenIntrinsicRecipe>(MinMaxOp) &&
+        !(RepR && (isa<IntrinsicInst>(RepR->getUnderlyingInstr()))))
+      return nullptr;
+
+    if (MinMaxOp->getOperand(0) == RedPhi)
+      return MinMaxOp->getOperand(1);
+    assert(MinMaxOp->getOperand(1) == RedPhi &&
+           "Reduction phi operand expected");
+    return MinMaxOp->getOperand(0);
+  };
+
+  for (auto &R : LoopRegion->getEntryBasicBlock()->phis()) {
+    // TODO: Also support first-order recurrence phis.
+    HasUnsupportedPhi |=
+        !isa<VPCanonicalIVPHIRecipe, VPWidenIntOrFpInductionRecipe,
+             VPReductionPHIRecipe>(&R);
+    auto *Cur = dyn_cast<VPReductionPHIRecipe>(&R);
+    if (!Cur)
+      continue;
+    // For now, only a single reduction is supported.
+    // TODO: Support multiple MaxNum/MinNum reductions and other reductions.
+    if (RedPhiR)
+      return false;
+    if (Cur->getRecurrenceKind() != RecurKind::FMaxNum &&
+        Cur->getRecurrenceKind() != RecurKind::FMinNum)
+      continue;
+
+    RedPhiR = Cur;
+    auto *MinMaxR = dyn_cast<VPRecipeWithIRFlags>(
+        RedPhiR->getBackedgeValue()->getDefiningRecipe());
+    if (!MinMaxR)
+      return false;
+    MinMaxOp = GetMinMaxCompareValue(MinMaxR, RedPhiR);
+    if (!MinMaxOp)
+      return false;
+  }
+
+  if (!RedPhiR)
+    return true;
+
+  if (HasUnsupportedPhi || !Plan.hasScalarTail())
+    return false;
+
+  /// Check if the vector loop of \p Plan can early exit and restart
+  /// execution of last vector iteration in the scalar loop. This requires all
+  /// recipes up to early exit point be side-effect free as they are
+  /// re-executed. Currently we check that the loop is free of any recipe that
+  /// may write to memory. Expected to operate on an early VPlan w/o nested
+  /// regions.
+  for (VPBlockBase *VPB : vp_depth_first_shallow(
+           Plan.getVectorLoopRegion()->getEntryBasicBlock())) {
+    auto *VPBB = cast<VPBasicBlock>(VPB);
+    for (auto &R : *VPBB) {
+      if (match(&R, m_BranchOnCount(m_VPValue(), m_VPValue())))
+        continue;
+      if (R.mayWriteToMemory())
+        return false;
+    }
+  }
+
+  auto *MiddleVPBB = Plan.getMiddleBlock();
+  auto *RdxResult = dyn_cast<VPInstruction>(&MiddleVPBB->front());
+  if (!RdxResult ||
+      RdxResult->getOpcode() != VPInstruction::ComputeReductionResult ||
+      RdxResult->getOperand(0) != RedPhiR)
+    return false;
+
+  // Create a new reduction phi recipe with either FMin/FMax, replacing
+  // FMinNum/FMaxNum.
+  RecurKind NewRK = RedPhiR->getRecurrenceKind() == RecurKind::FMinNum
+                        ? RecurKind::FMin
+                        : RecurKind::FMax;
+  auto *NewRedPhiR = new VPReductionPHIRecipe(
+      cast<PHINode>(RedPhiR->getUnderlyingValue()), NewRK,
+      *RedPhiR->getStartValue(), RedPhiR->isInLoop(), RedPhiR->isOrdered());
+  NewRedPhiR->addOperand(RedPhiR->getOperand(1));
+  NewRedPhiR->insertBefore(RedPhiR);
+  RedPhiR->replaceAllUsesWith(NewRedPhiR);
+  RedPhiR->eraseFromParent();
+
+  // Update the loop exit condition to exit if either any of the inputs is NaN
+  // or the vector trip count is reached.
+  VPBasicBlock *LatchVPBB = LoopRegion->getExitingBasicBlock();
+  VPBuilder Builder(LatchVPBB->getTerminator());
+  auto *LatchExitingBranch = cast<VPInstruction>(LatchVPBB->getTerminator());
+  assert(LatchExitingBranch->getOpcode() == VPInstruction::BranchOnCount &&
+         "Unexpected terminator");
+  auto *IsLatchExitTaken =
+      Builder.createICmp(CmpInst::ICMP_EQ, LatchExitingBranch->getOperand(0),
+                         LatchExitingBranch->getOperand(1));
+
+  VPValue *IsNaN = Builder.createFCmp(CmpInst::FCMP_UNO, MinMaxOp, MinMaxOp);
+  VPValue *AnyNaN = Builder.createNaryOp(VPInstruction::AnyOf, {IsNaN});
+  auto *AnyExitTaken =
+      Builder.createNaryOp(Instruction::Or, {AnyNaN, IsLatchExitTaken});
+  Builder.createNaryOp(VPInstruction::BranchOnCond, AnyExitTaken);
+  LatchExitingBranch->eraseFromParent();
+
+  // If we exit early due to NaNs, compute the final reduction result based on
+  // the reduction phi at the beginning of the last vector iteration.
+  Builder.setInsertPoint(MiddleVPBB, MiddleVPBB->begin());
+  auto *NewSel =
+      Builder.createSelect(AnyNaN, NewRedPhiR, RdxResult->getOperand(1));
+  RdxResult->setOperand(1, NewSel);
+
+  auto *ScalarPH = Plan.getScalarPreheader();
+  // Update the resume phis for inductions in the scalar preheader. If AnyNaN 
is
+  // true, the resume from the start of the last vector iteration via the
+  // canonical IV, otherwise from the original value.
+  for (auto &R : ScalarPH->phis()) {
+    auto *ResumeR = cast<VPPhi>(&R);
+    VPValue *VecV = ResumeR->getOperand(0);
+    if (VecV == RdxResult)
+      continue;
+    if (VecV != &Plan.getVectorTripCount())
+      return false;
----------------
fhahn wrote:

Thanks, added a debug message + comment!

https://github.com/llvm/llvm-project/pull/148239
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to