https://github.com/ckoparkar updated 
https://github.com/llvm/llvm-project/pull/152919

>From 24139055908afd5d5f32c2ec975c5d831fa07174 Mon Sep 17 00:00:00 2001
From: Chaitanya Koparkar <ckopar...@gmail.com>
Date: Sat, 9 Aug 2025 07:18:20 -0400
Subject: [PATCH] [clang] Enable constexpr handling for
 __builtin_elementwise_fma

---
 clang/docs/LanguageExtensions.rst            |  5 +-
 clang/include/clang/Basic/Builtins.td        |  2 +-
 clang/lib/AST/ByteCode/InterpBuiltin.cpp     | 69 ++++++++++++++++++++
 clang/lib/AST/ExprConstant.cpp               | 39 +++++++++++
 clang/test/Sema/constant-builtins-vector.cpp | 22 +++++++
 5 files changed, 134 insertions(+), 3 deletions(-)

diff --git a/clang/docs/LanguageExtensions.rst b/clang/docs/LanguageExtensions.rst
index b5bb198ca637a..e2aa2ad58a41e 100644
--- a/clang/docs/LanguageExtensions.rst
+++ b/clang/docs/LanguageExtensions.rst
@@ -757,9 +757,10 @@ elementwise to the input.
 
 Unless specified otherwise operation(±0) = ±0 and operation(±infinity) = ±infinity
 
-The integer elementwise intrinsics, including ``__builtin_elementwise_popcount``,
+The elementwise intrinsics ``__builtin_elementwise_popcount``,
 ``__builtin_elementwise_bitreverse``, ``__builtin_elementwise_add_sat``,
-``__builtin_elementwise_sub_sat`` can be called in a ``constexpr`` context.
+``__builtin_elementwise_sub_sat``, and ``__builtin_elementwise_fma``
+can be called in a ``constexpr`` context.
 
 No implicit promotion of integer types takes place. The mixing of integer types
 of different sizes and signs is forbidden in binary and ternary builtins.
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index c81714e9b009d..0e6a0af34b5da 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -1498,7 +1498,7 @@ def ElementwiseCopysign : Builtin {
 
 def ElementwiseFma : Builtin {
   let Spellings = ["__builtin_elementwise_fma"];
-  let Attributes = [NoThrow, Const, CustomTypeChecking];
+  let Attributes = [NoThrow, Const, CustomTypeChecking, Constexpr];
   let Prototype = "void(...)";
 }
 
diff --git a/clang/lib/AST/ByteCode/InterpBuiltin.cpp b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
index c835bd4fb6088..22dfafecca383 100644
--- a/clang/lib/AST/ByteCode/InterpBuiltin.cpp
+++ b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
@@ -141,6 +141,16 @@ static void diagnoseNonConstexprBuiltin(InterpState &S, CodePtr OpPC,
     S.CCEDiag(Loc, diag::note_invalid_subexpr_in_const_expr);
 }
 
+// Same implementation as Compiler::getRoundingMode: yields the rounding mode in effect for E, with Dynamic mapped to NearestTiesToEven.
+static llvm::RoundingMode getRoundingMode(const InterpState &S, const Expr *E) 
{
+  FPOptions FPO = E->getFPFeaturesInEffect(S.Ctx.getLangOpts()); // FP options active at E under the current language options
+
+  if (FPO.getRoundingMode() == llvm::RoundingMode::Dynamic) // a dynamic mode is replaced by the fixed default below
+    return llvm::RoundingMode::NearestTiesToEven;
+
+  return FPO.getRoundingMode(); // any non-dynamic mode is passed through unchanged
+}
+
 static bool interp__builtin_is_constant_evaluated(InterpState &S, CodePtr OpPC,
                                                   const InterpFrame *Frame,
                                                   const CallExpr *Call) {
@@ -2320,6 +2330,62 @@ static bool interp__builtin_elementwise_sat(InterpState &S, CodePtr OpPC,
   return true;
 }
 
+static bool interp__builtin_elementwise_fma(InterpState &S, CodePtr OpPC,
+                                            const CallExpr *Call) {
+  assert(Call->getNumArgs() == 3); // operands X, Y, Z; the result is X * Y + Z (per element for vectors)
+
+  llvm::RoundingMode RM = getRoundingMode(S, Call); // rounding mode derived from the call expression's FP options
+  const QualType Arg1Type = Call->getArg(0)->getType();
+  const QualType Arg2Type = Call->getArg(1)->getType();
+  const QualType Arg3Type = Call->getArg(2)->getType();
+
+  // Non-vector floating point types.
+  if (!Arg1Type->isVectorType()) {
+    assert(!Arg2Type->isVectorType());
+    assert(!Arg3Type->isVectorType());
+
+    const Floating &Z = S.Stk.pop<Floating>(); // operands are popped in reverse argument order: Z, then Y, then X
+    const Floating &Y = S.Stk.pop<Floating>();
+    const Floating &X = S.Stk.pop<Floating>();
+    APFloat F = X.getAPFloat(); // work on a copy of X's value
+    F.fusedMultiplyAdd(Y.getAPFloat(), Z.getAPFloat(), RM); // F = X * Y + Z in place; the returned status is not checked
+    Floating Result = S.allocFloat(X.getSemantics()); // result storage uses X's float semantics
+    Result.copy(F);
+    S.Stk.push<Floating>(Result); // the scalar result replaces the three popped operands on the stack
+    return true;
+  }
+
+  // Vector type.
+  assert(Arg1Type->isVectorType() &&
+         Arg2Type->isVectorType() &&
+         Arg3Type->isVectorType());
+
+  const VectorType *VecT = Arg1Type->castAs<VectorType>(); // the first operand supplies the element type and count
+  const QualType ElemT = VecT->getElementType();
+  unsigned NumElems = VecT->getNumElements();
+
+  assert(ElemT == Arg2Type->castAs<VectorType>()->getElementType() && // all three vectors must share one element type
+         ElemT == Arg3Type->castAs<VectorType>()->getElementType());
+  assert(NumElems == Arg2Type->castAs<VectorType>()->getNumElements() && // ... and one element count
+         NumElems == Arg3Type->castAs<VectorType>()->getNumElements());
+  assert(ElemT->isRealFloatingType());
+
+  const Pointer &VZ = S.Stk.pop<Pointer>(); // operand vectors are popped in reverse argument order
+  const Pointer &VY = S.Stk.pop<Pointer>();
+  const Pointer &VX = S.Stk.pop<Pointer>();
+  const Pointer &Dst = S.Stk.peek<Pointer>(); // destination is peeked rather than popped; presumably it stays on the stack as the result
+  for (unsigned I = 0; I != NumElems; ++I) {
+    using T = PrimConv<PT_Float>::T;
+    APFloat X = VX.elem<T>(I).getAPFloat();
+    APFloat Y = VY.elem<T>(I).getAPFloat();
+    APFloat Z = VZ.elem<T>(I).getAPFloat();
+    (void)X.fusedMultiplyAdd(Y, Z, RM); // X = X * Y + Z in place; status explicitly discarded
+    Dst.elem<Floating>(I) = Floating(X); // store element I of the result
+  }
+  Dst.initializeAllElements(); // flag the whole destination vector as initialized once every element is written
+  return true;
+}
+
 bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
                       uint32_t BuiltinID) {
   if (!S.getASTContext().BuiltinInfo.isConstantEvaluated(BuiltinID))
@@ -2727,6 +2793,9 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
   case Builtin::BI__builtin_elementwise_sub_sat:
     return interp__builtin_elementwise_sat(S, OpPC, Call, BuiltinID);
 
+  case Builtin::BI__builtin_elementwise_fma:
+    return interp__builtin_elementwise_fma(S, OpPC, Call);
+
   default:
     S.FFDiag(S.Current->getLocation(OpPC),
              diag::note_invalid_subexpr_in_const_expr)
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 3679327da7b0c..a7293415af0ce 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -11658,6 +11658,29 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
 
     return Success(APValue(ResultElements.data(), ResultElements.size()), E);
   }
+  case Builtin::BI__builtin_elementwise_fma: {
+    APValue SourceX, SourceY, SourceZ;
+    if (!EvaluateAsRValue(Info, E->getArg(0), SourceX) ||
+        !EvaluateAsRValue(Info, E->getArg(1), SourceY) ||
+        !EvaluateAsRValue(Info, E->getArg(2), SourceZ))
+      return false;
+
+    unsigned SourceLen = SourceX.getVectorLength();
+    SmallVector<APValue> ResultElements;
+    ResultElements.reserve(SourceLen);
+    llvm::RoundingMode RM = getActiveRoundingMode(getEvalInfo(), E);
+
+    for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) {
+      APFloat X = SourceX.getVectorElt(EltNum).getFloat();
+      APFloat Y = SourceY.getVectorElt(EltNum).getFloat();
+      APFloat Z = SourceZ.getVectorElt(EltNum).getFloat();
+      APFloat Result(X);
+      (void)Result.fusedMultiplyAdd(Y, Z, RM);
+      ResultElements.push_back(APValue(Result));
+    }
+
+    return Success(APValue(ResultElements.data(), ResultElements.size()), E);
+  }
   }
 }
 
@@ -15878,6 +15901,22 @@ bool FloatExprEvaluator::VisitCallExpr(const CallExpr *E) {
     Result = minimumnum(Result, RHS);
     return true;
   }
+
+  case Builtin::BI__builtin_elementwise_fma: {
+    if(!E->getArg(0)->isPRValue() ||
+       !E->getArg(1)->isPRValue() ||
+       !E->getArg(2)->isPRValue()) {
+      return false;
+    }
+    APFloat SourceY(0.), SourceZ(0.);
+    if (!EvaluateFloat(E->getArg(0), Result, Info) ||
+        !EvaluateFloat(E->getArg(1), SourceY, Info) ||
+        !EvaluateFloat(E->getArg(2), SourceZ, Info))
+      return false;
+    llvm::RoundingMode RM = getActiveRoundingMode(getEvalInfo(), E);
+    (void)Result.fusedMultiplyAdd(SourceY, SourceZ, RM);
+    return true;
+  }
   }
 }
 
diff --git a/clang/test/Sema/constant-builtins-vector.cpp b/clang/test/Sema/constant-builtins-vector.cpp
index bde5c478b2b6f..5fa0a7d447ebe 100644
--- a/clang/test/Sema/constant-builtins-vector.cpp
+++ b/clang/test/Sema/constant-builtins-vector.cpp
@@ -860,3 +860,25 @@ static_assert(__builtin_elementwise_sub_sat(0U, 1U) == 0U);
 static_assert(__builtin_bit_cast(unsigned, 
__builtin_elementwise_sub_sat((vector4char){5, 4, 3, 2}, (vector4char){1, 1, 1, 
1})) == (LITTLE_END ? 0x01020304 : 0x04030201));
 static_assert(__builtin_bit_cast(unsigned, 
__builtin_elementwise_sub_sat((vector4uchar){5, 4, 3, 2}, (vector4uchar){1, 1, 
1, 1})) == (LITTLE_END ? 0x01020304U : 0x04030201U));
 static_assert(__builtin_bit_cast(unsigned long long, 
__builtin_elementwise_sub_sat((vector4short){(short)0x8000, (short)0x8001, 
(short)0x8002, (short)0x8003}, (vector4short){7, 8, 9, 10}) == (LITTLE_END ? 
0x8000800080008000 : 0x8000800080008000)));
+
+
+// Non-vector floating point types.
+static_assert(__builtin_elementwise_fma(2.0, 3.0, 4.0) == 10.0); // 2 * 3 + 4
+static_assert(__builtin_elementwise_fma(200.0, 300.0, 400.0) == 60400.0); // 200 * 300 + 400
+// Vector type.
+constexpr vector4float fmaFloat1 =
+  __builtin_elementwise_fma((vector4float){1.0, 2.0, 3.0, 4.0},
+                            (vector4float){2.0, 3.0, 4.0, 5.0},
+                            (vector4float){3.0, 4.0, 5.0, 6.0});
+static_assert(fmaFloat1[0] == 5.0); // 1 * 2 + 3
+static_assert(fmaFloat1[1] == 10.0); // 2 * 3 + 4
+static_assert(fmaFloat1[2] == 17.0); // 3 * 4 + 5
+static_assert(fmaFloat1[3] == 26.0); // 4 * 5 + 6
+constexpr vector4double fmaDouble1 =
+  __builtin_elementwise_fma((vector4double){1.0, 2.0, 3.0, 4.0},
+                            (vector4double){2.0, 3.0, 4.0, 5.0},
+                            (vector4double){3.0, 4.0, 5.0, 6.0});
+static_assert(fmaDouble1[0] == 5.0); // same operands as the float case, evaluated in double
+static_assert(fmaDouble1[1] == 10.0);
+static_assert(fmaDouble1[2] == 17.0);
+static_assert(fmaDouble1[3] == 26.0);

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to