This revision was landed with ongoing or failed builds.
This revision was automatically updated to reflect the committed changes.
Closed by commit rG88eb3c62f25d: Add FP8 E4M3 support to APFloat. (authored by 
reedwm, committed by bkramer).

Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D137760/new/

https://reviews.llvm.org/D137760

Files:
  clang/include/clang/AST/Stmt.h
  clang/lib/AST/MicrosoftMangle.cpp
  llvm/include/llvm/ADT/APFloat.h
  llvm/lib/Support/APFloat.cpp
  llvm/unittests/ADT/APFloatTest.cpp

Index: llvm/unittests/ADT/APFloatTest.cpp
===================================================================
--- llvm/unittests/ADT/APFloatTest.cpp
+++ llvm/unittests/ADT/APFloatTest.cpp
@@ -1683,6 +1683,7 @@
 TEST(APFloatTest, getLargest) {
   EXPECT_EQ(3.402823466e+38f, APFloat::getLargest(APFloat::IEEEsingle()).convertToFloat());
   EXPECT_EQ(1.7976931348623158e+308, APFloat::getLargest(APFloat::IEEEdouble()).convertToDouble());
+  EXPECT_EQ(448, APFloat::getLargest(APFloat::Float8E4M3FN()).convertToDouble());
 }
 
 TEST(APFloatTest, getSmallest) {
@@ -1766,6 +1767,8 @@
       {&APFloat::x87DoubleExtended(), true, {0, 0x8000ULL}, 2},
       {&APFloat::Float8E5M2(), false, {0, 0}, 1},
       {&APFloat::Float8E5M2(), true, {0x80ULL, 0}, 1},
+      {&APFloat::Float8E4M3FN(), false, {0, 0}, 1},
+      {&APFloat::Float8E4M3FN(), true, {0x80ULL, 0}, 1},
   };
-  const unsigned NumGetZeroTests = 12;
+  const unsigned NumGetZeroTests = std::size(GetZeroTest);
   for (unsigned i = 0; i < NumGetZeroTests; ++i) {
@@ -3665,6 +3668,16 @@
     EXPECT_EQ(f1.mod(f2), APFloat::opOK);
     EXPECT_TRUE(f1.bitwiseIsEqual(expected));
   }
+  {
+    // Test E4M3FN mod where the LHS exponent is maxExponent (8) and the RHS is
+    // the max value whose exponent is minExponent (-6). Scaling the RHS by
+    // 2^14 gives 480, which overflows to NaN, so mod must retry with 2^13.
+    APFloat f1(APFloat::Float8E4M3FN(), "0x1p8");        // 256
+    APFloat f2(APFloat::Float8E4M3FN(), "0x1.ep-6");     // 0.029296875
+    APFloat expected(APFloat::Float8E4M3FN(), "0x1p-8"); // 0.00390625
+    EXPECT_EQ(f1.mod(f2), APFloat::opOK);
+    EXPECT_TRUE(f1.bitwiseIsEqual(expected));
+  }
 }
 
 TEST(APFloatTest, remainder) {
@@ -4756,6 +4769,389 @@
   EXPECT_TRUE(ilogb(F) == -1);
 }
 
+TEST(APFloatTest, ConvertE4M3FNToE5M2) {
+  bool losesInfo;
+  APFloat test(APFloat::Float8E4M3FN(), "1.0");
+  APFloat::opStatus status = test.convert(
+      APFloat::Float8E5M2(), APFloat::rmNearestTiesToEven, &losesInfo);
+  EXPECT_EQ(1.0f, test.convertToFloat());
+  EXPECT_FALSE(losesInfo);
+  EXPECT_EQ(status, APFloat::opOK);
+
+  test = APFloat(APFloat::Float8E4M3FN(), "0.0");
+  status = test.convert(APFloat::Float8E5M2(), APFloat::rmNearestTiesToEven,
+                        &losesInfo);
+  EXPECT_EQ(0.0f, test.convertToFloat());
+  EXPECT_FALSE(losesInfo);
+  EXPECT_EQ(status, APFloat::opOK);
+
+  test = APFloat(APFloat::Float8E4M3FN(), "0x1.2p0"); // 1.125, a tie in E5M2
+  status = test.convert(APFloat::Float8E5M2(), APFloat::rmNearestTiesToEven,
+                        &losesInfo);
+  EXPECT_EQ(0x1.0p0 /* 1.0 */, test.convertToFloat()); // ties-to-even
+  EXPECT_TRUE(losesInfo);
+  EXPECT_EQ(status, APFloat::opInexact);
+
+  test = APFloat(APFloat::Float8E4M3FN(), "0x1.6p0"); // 1.375, a tie in E5M2
+  status = test.convert(APFloat::Float8E5M2(), APFloat::rmNearestTiesToEven,
+                        &losesInfo);
+  EXPECT_EQ(0x1.8p0 /* 1.5 */, test.convertToFloat()); // ties-to-even
+  EXPECT_TRUE(losesInfo);
+  EXPECT_EQ(status, APFloat::opInexact);
+
+  // Convert E4M3 denormal to E5M2 normal. Should not be truncated, despite the
+  // destination format having one fewer significand bit.
+  test = APFloat(APFloat::Float8E4M3FN(), "0x1.Cp-7");
+  status = test.convert(APFloat::Float8E5M2(), APFloat::rmNearestTiesToEven,
+                        &losesInfo);
+  EXPECT_EQ(0x1.Cp-7, test.convertToFloat());
+  EXPECT_FALSE(losesInfo);
+  EXPECT_EQ(status, APFloat::opOK);
+
+  // Test convert from NaN.
+  test = APFloat(APFloat::Float8E4M3FN(), "nan");
+  status = test.convert(APFloat::Float8E5M2(), APFloat::rmNearestTiesToEven,
+                        &losesInfo);
+  EXPECT_TRUE(std::isnan(test.convertToFloat()));
+  EXPECT_FALSE(losesInfo);
+  EXPECT_EQ(status, APFloat::opOK);
+}
+
+TEST(APFloatTest, ConvertE5M2ToE4M3FN) {
+  bool losesInfo;
+  APFloat test(APFloat::Float8E5M2(), "1.0");
+  APFloat::opStatus status = test.convert(
+      APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven, &losesInfo);
+  EXPECT_EQ(1.0f, test.convertToFloat());
+  EXPECT_FALSE(losesInfo);
+  EXPECT_EQ(status, APFloat::opOK);
+
+  test = APFloat(APFloat::Float8E5M2(), "0.0");
+  status = test.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven,
+                        &losesInfo);
+  EXPECT_EQ(0.0f, test.convertToFloat());
+  EXPECT_FALSE(losesInfo);
+  EXPECT_EQ(status, APFloat::opOK);
+
+  test = APFloat(APFloat::Float8E5M2(), "0x1.Cp8"); // 448, E4M3FN max
+  status = test.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven,
+                        &losesInfo);
+  EXPECT_EQ(0x1.Cp8 /* 448 */, test.convertToFloat());
+  EXPECT_FALSE(losesInfo);
+  EXPECT_EQ(status, APFloat::opOK);
+
+  // Test overflow: 512 exceeds 448 and E4M3FN has no Inf, so it becomes NaN
+  test = APFloat(APFloat::Float8E5M2(), "0x1.0p9"); // 512
+  status = test.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven,
+                        &losesInfo);
+  EXPECT_TRUE(std::isnan(test.convertToFloat()));
+  EXPECT_TRUE(losesInfo);
+  EXPECT_EQ(status, APFloat::opOverflow | APFloat::opInexact);
+
+  // Test underflow: 2^-10 ties between 0 and 2^-9 and rounds to even (0)
+  test = APFloat(APFloat::Float8E5M2(), "0x1.0p-10");
+  status = test.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven,
+                        &losesInfo);
+  EXPECT_EQ(0., test.convertToFloat());
+  EXPECT_TRUE(losesInfo);
+  EXPECT_EQ(status, APFloat::opUnderflow | APFloat::opInexact);
+
+  // Test rounding up to smallest denormal number (2^-9)
+  test = APFloat(APFloat::Float8E5M2(), "0x1.8p-10");
+  status = test.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven,
+                        &losesInfo);
+  EXPECT_EQ(0x1.0p-9, test.convertToFloat());
+  EXPECT_TRUE(losesInfo);
+  EXPECT_EQ(status, APFloat::opUnderflow | APFloat::opInexact);
+
+  // Testing inexact rounding to denormal number: 1.5 * 2^-9 ties to 2^-8
+  test = APFloat(APFloat::Float8E5M2(), "0x1.8p-9");
+  status = test.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven,
+                        &losesInfo);
+  EXPECT_EQ(0x1.0p-8, test.convertToFloat());
+  EXPECT_TRUE(losesInfo);
+  EXPECT_EQ(status, APFloat::opUnderflow | APFloat::opInexact);
+
+  APFloat nan = APFloat(APFloat::Float8E4M3FN(), "nan");
+
+  // Testing convert from Inf: E4M3FN has no Inf, so it becomes NaN
+  test = APFloat(APFloat::Float8E5M2(), "inf");
+  status = test.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven,
+                        &losesInfo);
+  EXPECT_TRUE(std::isnan(test.convertToFloat()));
+  EXPECT_TRUE(losesInfo);
+  EXPECT_EQ(status, APFloat::opInexact);
+  EXPECT_TRUE(test.bitwiseIsEqual(nan));
+
+  // Testing convert from quiet NaN
+  test = APFloat(APFloat::Float8E5M2(), "nan");
+  status = test.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven,
+                        &losesInfo);
+  EXPECT_TRUE(std::isnan(test.convertToFloat()));
+  EXPECT_TRUE(losesInfo);
+  EXPECT_EQ(status, APFloat::opOK);
+  EXPECT_TRUE(test.bitwiseIsEqual(nan));
+
+  // Testing convert from signaling NaN
+  test = APFloat(APFloat::Float8E5M2(), "snan");
+  status = test.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven,
+                        &losesInfo);
+  EXPECT_TRUE(std::isnan(test.convertToFloat()));
+  EXPECT_TRUE(losesInfo);
+  EXPECT_EQ(status, APFloat::opInvalidOp);
+  EXPECT_TRUE(test.bitwiseIsEqual(nan));
+}
+
+TEST(APFloatTest, Float8E4M3FNGetInf) {
+  APFloat t = APFloat::getInf(APFloat::Float8E4M3FN()); // No Inf: yields NaN
+  EXPECT_TRUE(t.isNaN());
+  EXPECT_FALSE(t.isInfinity());
+}
+
+TEST(APFloatTest, Float8E4M3FNFromString) {
+  // Exactly representable
+  EXPECT_EQ(448, APFloat(APFloat::Float8E4M3FN(), "448").convertToDouble());
+  // Round down to maximum value: 464 ties between 448 and 480, even is 448
+  EXPECT_EQ(448, APFloat(APFloat::Float8E4M3FN(), "464").convertToDouble());
+  // Round up past the 464 tie point, causing overflow to NaN
+  EXPECT_TRUE(APFloat(APFloat::Float8E4M3FN(), "465").isNaN());
+  // Overflow without rounding (480 is the NaN encoding S.1111.111)
+  EXPECT_TRUE(APFloat(APFloat::Float8E4M3FN(), "480").isNaN());
+  // Inf converted to NaN
+  EXPECT_TRUE(APFloat(APFloat::Float8E4M3FN(), "inf").isNaN());
+  // NaN converted to NaN
+  EXPECT_TRUE(APFloat(APFloat::Float8E4M3FN(), "nan").isNaN());
+}
+
+TEST(APFloatTest, Float8E4M3FNAdd) {
+  APFloat QNaN = APFloat::getNaN(APFloat::Float8E4M3FN(), false);
+
+  auto FromStr = [](StringRef S) {
+    return APFloat(APFloat::Float8E4M3FN(), S);
+  };
+
+  struct {
+    APFloat x;
+    APFloat y;
+    const char *result;
+    int status;
+    int category;
+    APFloat::roundingMode roundingMode = APFloat::rmNearestTiesToEven;
+  } AdditionTests[] = {
+      // Test addition operations involving NaN, overflow, and the max E4M3
+      // value (448) because E4M3 differs from IEEE-754 types in these regards
+      {FromStr("448"), FromStr("16"), "448", APFloat::opInexact, // 464: tie
+       APFloat::fcNormal},
+      {FromStr("448"), FromStr("18"), "NaN", // 466 rounds up to NaN
+       APFloat::opOverflow | APFloat::opInexact, APFloat::fcNaN},
+      {FromStr("448"), FromStr("32"), "NaN",
+       APFloat::opOverflow | APFloat::opInexact, APFloat::fcNaN},
+      {FromStr("-448"), FromStr("-32"), "-NaN",
+       APFloat::opOverflow | APFloat::opInexact, APFloat::fcNaN},
+      {QNaN, FromStr("-448"), "NaN", APFloat::opOK, APFloat::fcNaN},
+      {FromStr("448"), FromStr("-32"), "416", APFloat::opOK, APFloat::fcNormal},
+      {FromStr("448"), FromStr("0"), "448", APFloat::opOK, APFloat::fcNormal},
+      {FromStr("448"), FromStr("32"), "448", APFloat::opInexact,
+       APFloat::fcNormal, APFloat::rmTowardZero},
+      {FromStr("448"), FromStr("448"), "448", APFloat::opInexact, // 896
+       APFloat::fcNormal, APFloat::rmTowardZero},
+  };
+
+  for (size_t i = 0; i < std::size(AdditionTests); ++i) {
+    APFloat x(AdditionTests[i].x);
+    APFloat y(AdditionTests[i].y);
+    APFloat::opStatus status = x.add(y, AdditionTests[i].roundingMode);
+
+    APFloat result(APFloat::Float8E4M3FN(), AdditionTests[i].result);
+
+    EXPECT_TRUE(result.bitwiseIsEqual(x));
+    EXPECT_EQ(AdditionTests[i].status, (int)status);
+    EXPECT_EQ(AdditionTests[i].category, (int)x.getCategory());
+  }
+}
+
+TEST(APFloatTest, Float8E4M3FNDivideByZero) {
+  APFloat x(APFloat::Float8E4M3FN(), "1");
+  APFloat zero(APFloat::Float8E4M3FN(), "0");
+  EXPECT_EQ(x.divide(zero, APFloat::rmNearestTiesToEven), APFloat::opDivByZero);
+  EXPECT_TRUE(x.isNaN()); // No Inf, so x / 0 is NaN
+}
+
+TEST(APFloatTest, Float8E4M3FNNext) {
+  APFloat test(APFloat::Float8E4M3FN(), APFloat::uninitialized);
+  APFloat expected(APFloat::Float8E4M3FN(), APFloat::uninitialized);
+
+  // nextUp on positive numbers; the largest (0x7e) steps to the NaN (0x7f)
+  for (int i = 0; i < 127; i++) {
+    test = APFloat(APFloat::Float8E4M3FN(), APInt(8, i));
+    expected = APFloat(APFloat::Float8E4M3FN(), APInt(8, i + 1));
+    EXPECT_EQ(test.next(false), APFloat::opOK);
+    EXPECT_TRUE(test.bitwiseIsEqual(expected));
+  }
+
+  // nextUp on negative zero
+  test = APFloat::getZero(APFloat::Float8E4M3FN(), true);
+  expected = APFloat::getSmallest(APFloat::Float8E4M3FN(), false);
+  EXPECT_EQ(test.next(false), APFloat::opOK);
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+
+  // nextUp on negative nonzero numbers
+  for (int i = 129; i < 255; i++) {
+    test = APFloat(APFloat::Float8E4M3FN(), APInt(8, i));
+    expected = APFloat(APFloat::Float8E4M3FN(), APInt(8, i - 1));
+    EXPECT_EQ(test.next(false), APFloat::opOK);
+    EXPECT_TRUE(test.bitwiseIsEqual(expected));
+  }
+
+  // nextUp on NaN
+  test = APFloat::getQNaN(APFloat::Float8E4M3FN(), false);
+  expected = APFloat::getQNaN(APFloat::Float8E4M3FN(), false);
+  EXPECT_EQ(test.next(false), APFloat::opOK);
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+
+  // nextDown on positive nonzero finite numbers
+  for (int i = 1; i < 127; i++) {
+    test = APFloat(APFloat::Float8E4M3FN(), APInt(8, i));
+    expected = APFloat(APFloat::Float8E4M3FN(), APInt(8, i - 1));
+    EXPECT_EQ(test.next(true), APFloat::opOK);
+    EXPECT_TRUE(test.bitwiseIsEqual(expected));
+  }
+
+  // nextDown on positive zero (-0 is covered by the loop below, at i = 128)
+  test = APFloat::getZero(APFloat::Float8E4M3FN(), false);
+  expected = APFloat::getSmallest(APFloat::Float8E4M3FN(), true);
+  EXPECT_EQ(test.next(true), APFloat::opOK);
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+
+  // nextDown on negative finite numbers
+  for (int i = 128; i < 255; i++) {
+    test = APFloat(APFloat::Float8E4M3FN(), APInt(8, i));
+    expected = APFloat(APFloat::Float8E4M3FN(), APInt(8, i + 1));
+    EXPECT_EQ(test.next(true), APFloat::opOK);
+    EXPECT_TRUE(test.bitwiseIsEqual(expected));
+  }
+
+  // nextDown on NaN
+  test = APFloat::getQNaN(APFloat::Float8E4M3FN(), false);
+  expected = APFloat::getQNaN(APFloat::Float8E4M3FN(), false);
+  EXPECT_EQ(test.next(true), APFloat::opOK);
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+}
+
+TEST(APFloatTest, Float8E4M3FNExhaustive) {
+  // Test each of the 256 Float8E4M3FN values.
+  for (int i = 0; i < 256; i++) {
+    APFloat test(APFloat::Float8E4M3FN(), APInt(8, i));
+    SCOPED_TRACE("i=" + std::to_string(i));
+
+    // isLargest
+    if (i == 126 || i == 254) { // 0x7e, 0xfe
+      EXPECT_TRUE(test.isLargest());
+      EXPECT_EQ(abs(test).convertToDouble(), 448.);
+    } else {
+      EXPECT_FALSE(test.isLargest());
+    }
+
+    // isSmallest
+    if (i == 1 || i == 129) { // 0x01, 0x81
+      EXPECT_TRUE(test.isSmallest());
+      EXPECT_EQ(abs(test).convertToDouble(), 0x1p-9);
+    } else {
+      EXPECT_FALSE(test.isSmallest());
+    }
+
+    // convert to BFloat; every E4M3FN value is exactly representable
+    APFloat test2 = test;
+    bool loses_info;
+    APFloat::opStatus status = test2.convert(
+        APFloat::BFloat(), APFloat::rmNearestTiesToEven, &loses_info);
+    EXPECT_EQ(status, APFloat::opOK);
+    EXPECT_FALSE(loses_info);
+    if (i == 127 || i == 255) // 0x7f, 0xff: NaN
+      EXPECT_TRUE(test2.isNaN());
+    else
+      EXPECT_EQ(test.convertToFloat(), test2.convertToFloat());
+
+    // bitcastToAPInt
+    EXPECT_EQ(i, test.bitcastToAPInt());
+  }
+}
+
+TEST(APFloatTest, Float8E4M3FNExhaustivePair) {
+  // Test each pair of Float8E4M3FN values against IEEEhalf arithmetic.
+  for (int i = 0; i < 256; i++) {
+    for (int j = 0; j < 256; j++) {
+      SCOPED_TRACE("i=" + std::to_string(i) + ",j=" + std::to_string(j));
+      APFloat x(APFloat::Float8E4M3FN(), APInt(8, i));
+      APFloat y(APFloat::Float8E4M3FN(), APInt(8, j));
+
+      bool losesInfo;
+      APFloat x16 = x; // Widened to IEEEhalf; must be lossless (checked below)
+      x16.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_FALSE(losesInfo);
+      APFloat y16 = y; // Widened to IEEEhalf; must be lossless (checked below)
+      y16.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_FALSE(losesInfo);
+
+      // Add
+      APFloat z = x;
+      z.add(y, APFloat::rmNearestTiesToEven);
+      APFloat z16 = x16;
+      z16.add(y16, APFloat::rmNearestTiesToEven);
+      z16.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_TRUE(z.bitwiseIsEqual(z16));
+
+      // Subtract
+      z = x;
+      z.subtract(y, APFloat::rmNearestTiesToEven);
+      z16 = x16;
+      z16.subtract(y16, APFloat::rmNearestTiesToEven);
+      z16.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_TRUE(z.bitwiseIsEqual(z16));
+
+      // Multiply
+      z = x;
+      z.multiply(y, APFloat::rmNearestTiesToEven);
+      z16 = x16;
+      z16.multiply(y16, APFloat::rmNearestTiesToEven);
+      z16.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_TRUE(z.bitwiseIsEqual(z16)) << "i=" << i << ", j=" << j;
+
+      // Divide
+      z = x;
+      z.divide(y, APFloat::rmNearestTiesToEven);
+      z16 = x16;
+      z16.divide(y16, APFloat::rmNearestTiesToEven);
+      z16.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_TRUE(z.bitwiseIsEqual(z16)) << "i=" << i << ", j=" << j;
+
+      // Mod
+      z = x;
+      z.mod(y);
+      z16 = x16;
+      z16.mod(y16);
+      z16.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_TRUE(z.bitwiseIsEqual(z16)) << "i=" << i << ", j=" << j;
+
+      // Remainder
+      z = x;
+      z.remainder(y);
+      z16 = x16;
+      z16.remainder(y16);
+      z16.convert(APFloat::Float8E4M3FN(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_TRUE(z.bitwiseIsEqual(z16)) << "i=" << i << ", j=" << j;
+    }
+  }
+}
+
 TEST(APFloatTest, IEEEdoubleToDouble) {
   APFloat DPosZero(0.0);
   APFloat DPosZeroToDouble(DPosZero.convertToDouble());
@@ -4937,6 +5333,30 @@
   EXPECT_TRUE(std::isnan(QNaN.convertToDouble()));
 }
 
+TEST(APFloatTest, Float8E4M3FNToDouble) {
+  APFloat One(APFloat::Float8E4M3FN(), "1.0");
+  EXPECT_EQ(1.0, One.convertToDouble());
+  APFloat Two(APFloat::Float8E4M3FN(), "2.0");
+  EXPECT_EQ(2.0, Two.convertToDouble());
+  APFloat PosLargest = APFloat::getLargest(APFloat::Float8E4M3FN(), false);
+  EXPECT_EQ(448., PosLargest.convertToDouble()); // 0x1.Cp8
+  APFloat NegLargest = APFloat::getLargest(APFloat::Float8E4M3FN(), true);
+  EXPECT_EQ(-448., NegLargest.convertToDouble());
+  APFloat PosSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E4M3FN(), false);
+  EXPECT_EQ(0x1.p-6, PosSmallest.convertToDouble()); // 2^minExponent
+  APFloat NegSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E4M3FN(), true);
+  EXPECT_EQ(-0x1.p-6, NegSmallest.convertToDouble());
+
+  APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float8E4M3FN(), false);
+  EXPECT_TRUE(SmallestDenorm.isDenormal());
+  EXPECT_EQ(0x1p-9, SmallestDenorm.convertToDouble()); // 2^-6 * 2^-3
+
+  APFloat QNaN = APFloat::getQNaN(APFloat::Float8E4M3FN());
+  EXPECT_TRUE(std::isnan(QNaN.convertToDouble()));
+}
+
 TEST(APFloatTest, IEEEsingleToFloat) {
   APFloat FPosZero(0.0F);
   APFloat FPosZeroToFloat(FPosZero.convertToFloat());
@@ -5085,4 +5505,36 @@
   EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
 }
 
+TEST(APFloatTest, Float8E4M3FNToFloat) {
+  APFloat PosZero = APFloat::getZero(APFloat::Float8E4M3FN());
+  APFloat PosZeroToFloat(PosZero.convertToFloat());
+  EXPECT_TRUE(PosZeroToFloat.isPosZero());
+  APFloat NegZero = APFloat::getZero(APFloat::Float8E4M3FN(), true);
+  APFloat NegZeroToFloat(NegZero.convertToFloat());
+  EXPECT_TRUE(NegZeroToFloat.isNegZero());
+
+  APFloat One(APFloat::Float8E4M3FN(), "1.0");
+  EXPECT_EQ(1.0F, One.convertToFloat());
+  APFloat Two(APFloat::Float8E4M3FN(), "2.0");
+  EXPECT_EQ(2.0F, Two.convertToFloat());
+
+  APFloat PosLargest = APFloat::getLargest(APFloat::Float8E4M3FN(), false);
+  EXPECT_EQ(448., PosLargest.convertToFloat()); // 0x1.Cp8
+  APFloat NegLargest = APFloat::getLargest(APFloat::Float8E4M3FN(), true);
+  EXPECT_EQ(-448, NegLargest.convertToFloat());
+  APFloat PosSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E4M3FN(), false);
+  EXPECT_EQ(0x1.p-6, PosSmallest.convertToFloat()); // 2^minExponent
+  APFloat NegSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E4M3FN(), true);
+  EXPECT_EQ(-0x1.p-6, NegSmallest.convertToFloat());
+
+  APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float8E4M3FN(), false);
+  EXPECT_TRUE(SmallestDenorm.isDenormal());
+  EXPECT_EQ(0x1.p-9, SmallestDenorm.convertToFloat()); // 2^-6 * 2^-3
+
+  APFloat QNaN = APFloat::getQNaN(APFloat::Float8E4M3FN());
+  EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
+}
+
 } // namespace
Index: llvm/lib/Support/APFloat.cpp
===================================================================
--- llvm/lib/Support/APFloat.cpp
+++ llvm/lib/Support/APFloat.cpp
@@ -50,6 +50,23 @@
 static_assert(APFloatBase::integerPartWidth % 4 == 0, "Part width must be divisible by 4!");
 
 namespace llvm {
+
+  // How the nonfinite values Inf and NaN are represented.
+  enum class fltNonfiniteBehavior {
+    // Represents standard IEEE 754 behavior. A value is nonfinite if the
+    // exponent field is all 1s. In such cases, a value is Inf if the
+    // significand bits are all zero, and NaN otherwise.
+    IEEE754,
+
+    // Only the Float8E4M3FN has this behavior. There is no Inf representation.
+    // A value is NaN if the exponent field and the mantissa field are all 1s.
+    // This behavior matches the FP8 E4M3 type described in
+    // https://arxiv.org/abs/2209.05433. We treat both positive and negative
+    // NaNs as non-signalling, although the paper does not state whether the
+    // NaN values are signalling or not.
+    NanOnly,
+  };
+
   /* Represents floating point arithmetic semantics.  */
   struct fltSemantics {
     /* The largest E such that 2^E is representable; this matches the
@@ -67,8 +84,11 @@
     /* Number of bits actually used in the semantics. */
     unsigned int sizeInBits;
 
+    fltNonfiniteBehavior nonFiniteBehavior = fltNonfiniteBehavior::IEEE754;
+
     // Returns true if any number described by this semantics can be precisely
-    // represented by the specified semantics.
+    // represented by the specified semantics. Does not take into account
+    // the value of fltNonfiniteBehavior.
     bool isRepresentableBy(const fltSemantics &S) const {
       return maxExponent <= S.maxExponent && minExponent >= S.minExponent &&
              precision <= S.precision;
@@ -81,6 +101,8 @@
   static const fltSemantics semIEEEdouble = {1023, -1022, 53, 64};
   static const fltSemantics semIEEEquad = {16383, -16382, 113, 128};
   static const fltSemantics semFloat8E5M2 = {15, -14, 3, 8};
+  static const fltSemantics semFloat8E4M3FN = {8, -6, 4, 8,
+                                               fltNonfiniteBehavior::NanOnly};
   static const fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80};
   static const fltSemantics semBogus = {0, 0, 0, 0};
 
@@ -138,6 +160,8 @@
       return PPCDoubleDouble();
     case S_Float8E5M2:
       return Float8E5M2();
+    case S_Float8E4M3FN:
+      return Float8E4M3FN();
     case S_x87DoubleExtended:
       return x87DoubleExtended();
     }
@@ -160,6 +184,8 @@
       return S_PPCDoubleDouble;
     else if (&Sem == &llvm::APFloat::Float8E5M2())
       return S_Float8E5M2;
+    else if (&Sem == &llvm::APFloat::Float8E4M3FN())
+      return S_Float8E4M3FN;
     else if (&Sem == &llvm::APFloat::x87DoubleExtended())
       return S_x87DoubleExtended;
     else
@@ -183,6 +209,7 @@
     return semPPCDoubleDouble;
   }
   const fltSemantics &APFloatBase::Float8E5M2() { return semFloat8E5M2; }
+  const fltSemantics &APFloatBase::Float8E4M3FN() { return semFloat8E4M3FN; }
   const fltSemantics &APFloatBase::x87DoubleExtended() {
     return semX87DoubleExtended;
   }
@@ -769,6 +796,15 @@
   integerPart *significand = significandParts();
   unsigned numParts = partCount();
 
+  APInt fill_storage;
+  if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) {
+    // The only NaN representation is where the mantissa is all 1s, which is
+    // non-signalling.
+    SNaN = false;
+    fill_storage = APInt::getAllOnes(semantics->precision - 1);
+    fill = &fill_storage;
+  }
+
   // Set the significand bits to the fill.
   if (!fill || fill->getNumWords() < numParts)
     APInt::tcSet(significand, 0, numParts);
@@ -869,6 +905,33 @@
   return true;
 }
 
+bool IEEEFloat::isSignificandAllOnesExceptLSB() const {
+  // Test if the significand excluding the integral bit is all ones except for
+  // the least significant bit.
+  const integerPart *Parts = significandParts();
+  if (Parts[0] & 1)
+    return false;
+
+  // Lower parts must be all ones (ignoring part 0's LSB). Build the mask as an
+  // integerPart: a narrower unsigned mask would leave the high bits unchecked.
+  const unsigned PartCount = partCountForBits(semantics->precision);
+  for (unsigned i = 0; i < PartCount - 1; i++) {
+    if (~Parts[i] & ~integerPart{!i})
+      return false;
+  }
+
+  // Set the unused high bits to all ones when we compare.
+  const unsigned NumHighBits =
+      PartCount * integerPartWidth - semantics->precision + 1;
+  assert(NumHighBits <= integerPartWidth && NumHighBits > 0 &&
+         "Can not have more high bits to fill than integerPartWidth");
+  const integerPart HighBitFill = ~integerPart(0)
+                                  << (integerPartWidth - NumHighBits);
+  if (~(Parts[PartCount - 1] | HighBitFill | 0x1))
+    return false;
+  return true;
+}
+
 bool IEEEFloat::isSignificandAllZeros() const {
   // Test if the significand excluding the integral bit is all zeros. This
   // allows us to test for binade boundaries.
@@ -893,10 +956,18 @@
 }
 
 bool IEEEFloat::isLargest() const {
-  // The largest number by magnitude in our format will be the floating point
-  // number with maximum exponent and with significand that is all ones.
-  return isFiniteNonZero() && exponent == semantics->maxExponent
-    && isSignificandAllOnes();
+  // The largest number by magnitude is the finite number with the maximum
+  // exponent and a significand that is all ones. In NanOnly formats that
+  // exact bit pattern encodes NaN, so there the largest number's significand
+  // is all ones except the LSB.
+  if (!isFiniteNonZero() || exponent != semantics->maxExponent)
+    return false;
+
+  if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly)
+    return isSignificandAllOnesExceptLSB();
+
+  // Standard IEEE 754 nonfinite behavior.
+  return isSignificandAllOnes();
 }
 
 bool IEEEFloat::isInteger() const {
@@ -1315,7 +1386,10 @@
       rounding_mode == rmNearestTiesToAway ||
       (rounding_mode == rmTowardPositive && !sign) ||
       (rounding_mode == rmTowardNegative && sign)) {
-    category = fcInfinity;
+    if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly)
+      makeNaN(false, sign);
+    else
+      category = fcInfinity;
     return (opStatus) (opOverflow | opInexact);
   }
 
@@ -1324,6 +1398,8 @@
   exponent = semantics->maxExponent;
   tcSetLeastSignificantBits(significandParts(), partCount(),
                             semantics->precision);
+  if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly)
+    APInt::tcClearBit(significandParts(), 0);
 
   return opInexact;
 }
@@ -1423,6 +1499,10 @@
     }
   }
 
+  if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly &&
+      exponent == semantics->maxExponent && isSignificandAllOnes())
+    return handleOverflow(rounding_mode);
+
   /* Now round the number according to rounding_mode given the lost
      fraction.  */
 
@@ -1459,6 +1539,10 @@
 
       return opInexact;
     }
+
+    if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly &&
+        exponent == semantics->maxExponent && isSignificandAllOnes())
+      return handleOverflow(rounding_mode);
   }
 
   /* The normal case - we were and are not denormal, and any
@@ -1679,7 +1763,10 @@
     return opOK;
 
   case PackCategoriesIntoKey(fcNormal, fcZero):
-    category = fcInfinity;
+    if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly)
+      makeNaN(false, sign);
+    else
+      category = fcInfinity;
     return opDivByZero;
 
   case PackCategoriesIntoKey(fcInfinity, fcInfinity):
@@ -1965,9 +2052,12 @@
 
   while (isFiniteNonZero() && rhs.isFiniteNonZero() &&
          compareAbsoluteValue(rhs) != cmpLessThan) {
-    IEEEFloat V = scalbn(rhs, ilogb(*this) - ilogb(rhs), rmNearestTiesToEven);
-    if (compareAbsoluteValue(V) == cmpLessThan)
-      V = scalbn(V, -1, rmNearestTiesToEven);
+    int Exp = ilogb(*this) - ilogb(rhs);
+    IEEEFloat V = scalbn(rhs, Exp, rmNearestTiesToEven);
+    // V can overflow to NaN with fltNonfiniteBehavior::NanOnly, so explicitly
+    // check for it.
+    if (V.isNaN() || compareAbsoluteValue(V) == cmpLessThan)
+      V = scalbn(rhs, Exp - 1, rmNearestTiesToEven);
     V.sign = sign;
 
     fs = subtract(V, rmNearestTiesToEven);
@@ -2194,6 +2284,7 @@
   opStatus fs;
   int shift;
   const fltSemantics &fromSemantics = *semantics;
+  bool is_signaling = isSignaling();
 
   lostFraction = lfExactlyZero;
   newPartCount = partCountForBits(toSemantics.precision + 1);
@@ -2235,7 +2326,9 @@
   }
 
   // If this is a truncation, perform the shift before we narrow the storage.
-  if (shift < 0 && (isFiniteNonZero() || category==fcNaN))
+  if (shift < 0 && (isFiniteNonZero() ||
+                    (category == fcNaN && semantics->nonFiniteBehavior !=
+                                              fltNonfiniteBehavior::NanOnly)))
     lostFraction = shiftRight(significandParts(), oldPartCount, -shift);
 
   // Fix the storage so it can hold to new value.
@@ -2269,6 +2362,13 @@
     fs = normalize(rounding_mode, lostFraction);
     *losesInfo = (fs != opOK);
   } else if (category == fcNaN) {
+    if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) {
+      *losesInfo =
+          fromSemantics.nonFiniteBehavior != fltNonfiniteBehavior::NanOnly;
+      makeNaN(false, sign);
+      return is_signaling ? opInvalidOp : opOK;
+    }
+
     *losesInfo = lostFraction != lfExactlyZero || X86SpecialNan;
 
     // For x87 extended precision, we want to make a NaN, not a special NaN if
@@ -2279,12 +2379,17 @@
     // Convert of sNaN creates qNaN and raises an exception (invalid op).
     // This also guarantees that a sNaN does not become Inf on a truncation
     // that loses all payload bits.
-    if (isSignaling()) {
+    if (is_signaling) {
       makeQuiet();
       fs = opInvalidOp;
     } else {
       fs = opOK;
     }
+  } else if (category == fcInfinity &&
+             semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) {
+    makeNaN(false, sign);
+    *losesInfo = true;
+    fs = opInexact;
   } else {
     *losesInfo = false;
     fs = opOK;
@@ -3382,6 +3487,33 @@
                    (mysignificand & 0x3)));
 }
 
+APInt IEEEFloat::convertFloat8E4M3FNAPFloatToAPInt() const {
+  assert(semantics == (const llvm::fltSemantics *)&semFloat8E4M3FN);
+  assert(partCount() == 1);
+
+  uint32_t myexponent, mysignificand; // Packed as S EEEE MMM below.
+
+  if (isFiniteNonZero()) {
+    myexponent = exponent + 7; // bias
+    mysignificand = (uint32_t)*significandParts();
+    if (myexponent == 1 && !(mysignificand & 0x8)) // integer bit clear
+      myexponent = 0; // denormal
+  } else if (category == fcZero) {
+    myexponent = 0;
+    mysignificand = 0;
+  } else if (category == fcInfinity) { // NOTE(review): E4M3FN has no Inf;
+    myexponent = 0xf; // this pattern decodes as 256; presumably unreachable
+    mysignificand = 0;
+  } else {
+    assert(category == fcNaN && "Unknown category!");
+    myexponent = 0xf; // NaN is S.1111.111
+    mysignificand = (uint32_t)*significandParts();
+  }
+
+  return APInt(8, (((sign & 1) << 7) | ((myexponent & 0xf) << 3) |
+                   (mysignificand & 0x7)));
+}
+
 // This function creates an APInt that is just a bit map of the floating
 // point constant as it would appear in memory.  It is not a conversion,
 // and treating the result as a normal integer is unlikely to be useful.
@@ -3408,6 +3540,9 @@
   if (semantics == (const llvm::fltSemantics *)&semFloat8E5M2)
     return convertFloat8E5M2APFloatToAPInt();
 
+  if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FN)
+    return convertFloat8E4M3FNAPFloatToAPInt();
+
   assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended &&
          "unknown format!");
   return convertF80LongDoubleAPFloatToAPInt();
@@ -3663,10 +3798,33 @@
   }
 }
 
-/// Treat api as containing the bits of a floating point number.  Currently
-/// we infer the floating point type from the size of the APInt.  The
-/// isIEEE argument distinguishes between PPC128 and IEEE128 (not meaningful
-/// when the size is anything else).
+void IEEEFloat::initFromFloat8E4M3FNAPInt(const APInt &api) {
+  uint32_t i = (uint32_t)*api.getRawData(); // raw S.EEEE.MMM bits
+  uint32_t myexponent = (i >> 3) & 0xf; // biased exponent field
+  uint32_t mysignificand = i & 0x7; // mantissa, no integer bit
+  // Field 0xf is finite unless the mantissa is all ones (NaN); no Inf.
+  initialize(&semFloat8E4M3FN);
+  assert(partCount() == 1);
+
+  sign = i >> 7;
+  if (myexponent == 0 && mysignificand == 0) { // +/-0
+    makeZero(sign);
+  } else if (myexponent == 0xf && mysignificand == 7) { // the single NaN
+    category = fcNaN;
+    exponent = exponentNaN();
+    *significandParts() = mysignificand;
+  } else {
+    category = fcNormal; // every other pattern is finite (max 448)
+    exponent = myexponent - 7; // remove the bias (7)
+    *significandParts() = mysignificand;
+    if (myexponent == 0) // denormal: no implicit integer bit
+      exponent = -6; // minimum normal exponent, 1 - bias
+    else
+      *significandParts() |= 0x8; // implicit integer bit
+  }
+}
+
+/// Treat api as containing the bits of a floating point number.
 void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
   assert(api.getBitWidth() == Sem->sizeInBits);
   if (Sem == &semIEEEhalf)
@@ -3685,6 +3843,8 @@
     return initFromPPCDoubleDoubleAPInt(api);
   if (Sem == &semFloat8E5M2)
     return initFromFloat8E5M2APInt(api);
+  if (Sem == &semFloat8E4M3FN)
+    return initFromFloat8E4M3FNAPInt(api);
 
   llvm_unreachable(nullptr);
 }
@@ -3712,6 +3872,9 @@
   significand[PartCount - 1] = (NumUnusedHighBits < integerPartWidth)
                                    ? (~integerPart(0) >> NumUnusedHighBits)
                                    : 0;
+
+  if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly)
+    significand[0] &= ~integerPart(1);
 }
 
 /// Make this number the smallest magnitude denormal number in the given
@@ -4085,6 +4248,8 @@
 bool IEEEFloat::isSignaling() const {
   if (!isNaN())
     return false;
+  if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly)
+    return false;
 
   // IEEE-754R 2008 6.2.1: A signaling NaN bit string should be encoded with the
   // first bit of the trailing significand being 0.
@@ -4135,12 +4300,18 @@
       break;
     }
 
-    // nextUp(getLargest()) == INFINITY
     if (isLargest() && !isNegative()) {
-      APInt::tcSet(significandParts(), 0, partCount());
-      category = fcInfinity;
-      exponent = semantics->maxExponent + 1;
-      break;
+      if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) {
+        // nextUp(getLargest()) == NAN
+        makeNaN();
+        break;
+      } else {
+        // nextUp(getLargest()) == INFINITY
+        APInt::tcSet(significandParts(), 0, partCount());
+        category = fcInfinity;
+        exponent = semantics->maxExponent + 1;
+        break;
+      }
     }
 
     // nextUp(normal) == normal + inc.
@@ -4212,6 +4383,8 @@
 }
 
 APFloatBase::ExponentType IEEEFloat::exponentNaN() const {
+  if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly)
+    return semantics->maxExponent; // shared with finite values (no Inf)
   return semantics->maxExponent + 1;
 }
 
@@ -4224,6 +4397,11 @@
 }
 
 void IEEEFloat::makeInf(bool Negative) {
+  if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) {
+    // There is no Inf, so make NaN instead.
+    makeNaN(false, Negative);
+    return;
+  }
   category = fcInfinity;
   sign = Negative;
   exponent = exponentInf();
@@ -4239,7 +4417,8 @@
 
 void IEEEFloat::makeQuiet() {
   assert(isNaN());
-  APInt::tcSetBit(significandParts(), semantics->precision - 2);
+  if (semantics->nonFiniteBehavior != fltNonfiniteBehavior::NanOnly)
+    APInt::tcSetBit(significandParts(), semantics->precision - 2); // quiet bit
 }
 
 int ilogb(const IEEEFloat &Arg) {
Index: llvm/include/llvm/ADT/APFloat.h
===================================================================
--- llvm/include/llvm/ADT/APFloat.h
+++ llvm/include/llvm/ADT/APFloat.h
@@ -156,8 +156,13 @@
     S_IEEEquad,
     S_PPCDoubleDouble,
     // 8-bit floating point number following IEEE-754 conventions with bit
-    // layout S1E5M2 as described in https://arxiv.org/abs/2209.05433
+    // layout S1E5M2 as described in https://arxiv.org/abs/2209.05433.
     S_Float8E5M2,
+    // 8-bit floating point number mostly following IEEE-754 conventions with
+    // bit layout S1E4M3 as described in https://arxiv.org/abs/2209.05433.
+    // Unlike IEEE-754 types, there are no infinity values, and NaN is
+    // represented with the exponent and mantissa bits set to all 1s.
+    S_Float8E4M3FN,
     S_x87DoubleExtended,
     S_MaxSemantics = S_x87DoubleExtended,
   };
@@ -172,6 +177,7 @@
   static const fltSemantics &IEEEquad() LLVM_READNONE;
   static const fltSemantics &PPCDoubleDouble() LLVM_READNONE;
   static const fltSemantics &Float8E5M2() LLVM_READNONE;
+  static const fltSemantics &Float8E4M3FN() LLVM_READNONE;
   static const fltSemantics &x87DoubleExtended() LLVM_READNONE;
 
   /// A Pseudo fltsemantic used to construct APFloats that cannot conflict with
@@ -508,6 +514,7 @@
   void zeroSignificand();
   /// Return true if the significand excluding the integral bit is all ones.
   bool isSignificandAllOnes() const;
+  bool isSignificandAllOnesExceptLSB() const;
   /// Return true if the significand excluding the integral bit is all zeros.
   bool isSignificandAllZeros() const;
 
@@ -557,6 +564,7 @@
   APInt convertF80LongDoubleAPFloatToAPInt() const;
   APInt convertPPCDoubleDoubleAPFloatToAPInt() const;
   APInt convertFloat8E5M2APFloatToAPInt() const;
+  APInt convertFloat8E4M3FNAPFloatToAPInt() const;
   void initFromAPInt(const fltSemantics *Sem, const APInt &api);
   void initFromHalfAPInt(const APInt &api);
   void initFromBFloatAPInt(const APInt &api);
@@ -566,6 +574,7 @@
   void initFromF80LongDoubleAPInt(const APInt &api);
   void initFromPPCDoubleDoubleAPInt(const APInt &api);
   void initFromFloat8E5M2APInt(const APInt &api);
+  void initFromFloat8E4M3FNAPInt(const APInt &api);
 
   void assign(const IEEEFloat &);
   void copySignificand(const IEEEFloat &);
Index: clang/lib/AST/MicrosoftMangle.cpp
===================================================================
--- clang/lib/AST/MicrosoftMangle.cpp
+++ clang/lib/AST/MicrosoftMangle.cpp
@@ -839,6 +839,7 @@
   case APFloat::S_IEEEquad: Out << 'Y'; break;
   case APFloat::S_PPCDoubleDouble: Out << 'Z'; break;
   case APFloat::S_Float8E5M2:
+  case APFloat::S_Float8E4M3FN:
     llvm_unreachable("Tried to mangle unexpected APFloat semantics");
   }
 
Index: clang/include/clang/AST/Stmt.h
===================================================================
--- clang/include/clang/AST/Stmt.h
+++ clang/include/clang/AST/Stmt.h
@@ -22,6 +22,7 @@
 #include "clang/Basic/LangOptions.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Basic/Specifiers.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/BitmaskEnum.h"
 #include "llvm/ADT/PointerIntPair.h"
@@ -389,7 +390,10 @@
 
     unsigned : NumExprBits;
 
-    unsigned Semantics : 3; // Provides semantics for APFloat construction
+    static_assert(
+        llvm::APFloat::S_MaxSemantics < 16,
+        "Too many Semantics enum values to fit in bitfield of size 4");
+    unsigned Semantics : 4; // Provides semantics for APFloat construction
     unsigned IsExact : 1;
   };
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to