================
@@ -605,15 +608,117 @@ static Value *emitRsqIEEE1ULP(IRBuilder<> &Builder, Value *Src,
   return Builder.CreateFMul(Rsq, OutputScaleFactor);
 }
 
+/// Emit inverse sqrt expansion for f64 with a correction sequence on top of
+/// v_rsq_f64. This should give a 1ulp result.
+Value *AMDGPUCodeGenPrepareImpl::emitRsqF64(IRBuilder<> &Builder, Value *X,
+                                            FastMathFlags SqrtFMF,
+                                            FastMathFlags DivFMF,
+                                            const Instruction *CtxI,
+                                            bool IsNegative) const {
+  // rsq(x):
+  //   double y0 = BUILTIN_AMDGPU_RSQRT_F64(x);
+  //   double e = MATH_MAD(-y0 * (x == PINF_F64 || x == 0.0 ? y0 : x), y0, 
1.0);
+  //   return MATH_MAD(y0*e, MATH_MAD(e, 0.375, 0.5), y0);
+  //
+  // The rsq instruction handles the special cases correctly. We need to check
+  // for the edge case conditions to ensure the special case propagates through
+  // the later instructions.
+
+  Value *Y0 = Builder.CreateUnaryIntrinsic(Intrinsic::amdgcn_rsq, X);
+
+  // Try to elide the edge case check.
+  //
+  // Fast math flags imply:
+  //   sqrt ninf => !isinf(x)
+  //   sqrt nnan => not helpful
+  //   fdiv ninf => x != 0, !isinf(x)
+  //   fdiv nnan => x != 0
+  bool MaybePosInf = !SqrtFMF.noInfs() && !DivFMF.noInfs();
+  bool MaybeZero = !DivFMF.noInfs() && !DivFMF.noNaNs();
+
+  DenormalMode DenormMode;
+  FPClassTest Interested = fcNone;
+  if (MaybeZero)
+    Interested = fcZero;
+  if (MaybePosInf)
+    Interested = fcPosInf;
+
+  if (Interested != fcNone) {
+    KnownFPClass KnownSrc = computeKnownFPClass(X, Interested, CtxI);
+    if (KnownSrc.isKnownNeverPosInfinity())
+      MaybePosInf = false;
+
+    DenormMode = F.getDenormalMode(X->getType()->getFltSemantics());
+    if (KnownSrc.isKnownNeverLogicalZero(DenormMode))
+      MaybeZero = false;
+  }
+
+  Value *SpecialOrRsq = Y0;
----------------
dtcxzyw wrote:

This doesn't match your code above. If I understand correctly, it should be `(x == PINF_F64 || x == 0.0 ? x : y0)`.
See also https://github.com/Multi2Sim/m2s-bench-cudasdk-6.5/blob/1f0f416b45d918936c598f41ec3b80434f6502ed/include/math_functions_dbl_ptx3.h#L710-L727



https://github.com/llvm/llvm-project/pull/172053
_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to