https://github.com/ckoparkar updated https://github.com/llvm/llvm-project/pull/152919
>From 6d3acba2796345c56bc3df71d2bc6f6f131395af Mon Sep 17 00:00:00 2001 From: Chaitanya Koparkar <ckopar...@gmail.com> Date: Wed, 20 Aug 2025 08:48:00 -0400 Subject: [PATCH] [clang] Enable constexpr handling for __builtin_elementwise_fma --- clang/docs/LanguageExtensions.rst | 8 +-- clang/include/clang/Basic/Builtins.td | 2 +- clang/lib/AST/ByteCode/InterpBuiltin.cpp | 58 ++++++++++++++++++++ clang/lib/AST/ExprConstant.cpp | 37 +++++++++++++ clang/test/CodeGen/rounding-math.cpp | 52 ++++++++++++++++++ clang/test/Sema/constant-builtins-vector.cpp | 21 +++++++ 6 files changed, 173 insertions(+), 5 deletions(-) diff --git a/clang/docs/LanguageExtensions.rst b/clang/docs/LanguageExtensions.rst index 66ac22349367a..79d66e47cccad 100644 --- a/clang/docs/LanguageExtensions.rst +++ b/clang/docs/LanguageExtensions.rst @@ -757,12 +757,12 @@ elementwise to the input. Unless specified otherwise operation(±0) = ±0 and operation(±infinity) = ±infinity -The integer elementwise intrinsics, including ``__builtin_elementwise_popcount``, +The elementwise intrinsics ``__builtin_elementwise_popcount``, ``__builtin_elementwise_bitreverse``, ``__builtin_elementwise_add_sat``, ``__builtin_elementwise_sub_sat``, ``__builtin_elementwise_max``, ``__builtin_elementwise_min``, ``__builtin_elementwise_abs``, -``__builtin_elementwise_ctlz``, and ``__builtin_elementwise_cttz`` can be -called in a ``constexpr`` context. +``__builtin_elementwise_ctlz``, ``__builtin_elementwise_cttz``, and +``__builtin_elementwise_fma`` can be called in a ``constexpr`` context. No implicit promotion of integer types takes place. The mixing of integer types of different sizes and signs is forbidden in binary and ternary builtins. @@ -4379,7 +4379,7 @@ fall into one of the specified floating-point classes. if (__builtin_isfpclass(x, 448)) { // `x` is positive finite value - ... + ... } **Description**: diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td index ad340e2ed0eec..332f369a9032f 100644 --- a/clang/include/clang/Basic/Builtins.td +++ b/clang/include/clang/Basic/Builtins.td @@ -1498,7 +1498,7 @@ def ElementwiseCopysign : Builtin { def ElementwiseFma : Builtin { let Spellings = ["__builtin_elementwise_fma"]; - let Attributes = [NoThrow, Const, CustomTypeChecking]; + let Attributes = [NoThrow, Const, CustomTypeChecking, Constexpr]; let Prototype = "void(...)"; } diff --git a/clang/lib/AST/ByteCode/InterpBuiltin.cpp b/clang/lib/AST/ByteCode/InterpBuiltin.cpp index fd8c70c392dcb..5de5091178b8f 100644 --- a/clang/lib/AST/ByteCode/InterpBuiltin.cpp +++ b/clang/lib/AST/ByteCode/InterpBuiltin.cpp @@ -2714,6 +2714,62 @@ static bool interp__builtin_ia32_pmul(InterpState &S, CodePtr OpPC, return true; } +static bool interp__builtin_elementwise_fma(InterpState &S, CodePtr OpPC, + const CallExpr *Call) { + assert(Call->getNumArgs() == 3); + + FPOptions FPO = Call->getFPFeaturesInEffect(S.Ctx.getLangOpts()); + llvm::RoundingMode RM = getRoundingMode(FPO); + const QualType Arg1Type = Call->getArg(0)->getType(); + const QualType Arg2Type = Call->getArg(1)->getType(); + const QualType Arg3Type = Call->getArg(2)->getType(); + + // Non-vector floating point types. + if (!Arg1Type->isVectorType()) { + assert(!Arg2Type->isVectorType()); + assert(!Arg3Type->isVectorType()); + + const Floating &Z = S.Stk.pop<Floating>(); + const Floating &Y = S.Stk.pop<Floating>(); + const Floating &X = S.Stk.pop<Floating>(); + APFloat F = X.getAPFloat(); + F.fusedMultiplyAdd(Y.getAPFloat(), Z.getAPFloat(), RM); + Floating Result = S.allocFloat(X.getSemantics()); + Result.copy(F); + S.Stk.push<Floating>(Result); + return true; + } + + // Vector type. + assert(Arg1Type->isVectorType() && Arg2Type->isVectorType() && + Arg3Type->isVectorType()); + + const VectorType *VecT = Arg1Type->castAs<VectorType>(); + const QualType ElemT = VecT->getElementType(); + unsigned NumElems = VecT->getNumElements(); + + assert(ElemT == Arg2Type->castAs<VectorType>()->getElementType() && + ElemT == Arg3Type->castAs<VectorType>()->getElementType()); + assert(NumElems == Arg2Type->castAs<VectorType>()->getNumElements() && + NumElems == Arg3Type->castAs<VectorType>()->getNumElements()); + assert(ElemT->isRealFloatingType()); + + const Pointer &VZ = S.Stk.pop<Pointer>(); + const Pointer &VY = S.Stk.pop<Pointer>(); + const Pointer &VX = S.Stk.pop<Pointer>(); + const Pointer &Dst = S.Stk.peek<Pointer>(); + for (unsigned I = 0; I != NumElems; ++I) { + using T = PrimConv<PT_Float>::T; + APFloat X = VX.elem<T>(I).getAPFloat(); + APFloat Y = VY.elem<T>(I).getAPFloat(); + APFloat Z = VZ.elem<T>(I).getAPFloat(); + (void)X.fusedMultiplyAdd(Y, Z, RM); + Dst.elem<Floating>(I) = Floating(X); + } + Dst.initializeAllElements(); + return true; +} + bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call, uint32_t BuiltinID) { if (!S.getASTContext().BuiltinInfo.isConstantEvaluated(BuiltinID)) @@ -3145,6 +3201,8 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call, case clang::X86::BI__builtin_ia32_pmuludq128: case clang::X86::BI__builtin_ia32_pmuludq256: return interp__builtin_ia32_pmul(S, OpPC, Call, BuiltinID); + case Builtin::BI__builtin_elementwise_fma: + return interp__builtin_elementwise_fma(S, OpPC, Call); default: S.FFDiag(S.Current->getLocation(OpPC), diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp index 9c87a88899647..6da88e52fb4f3 100644 --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -11874,6 +11874,28 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) { return Success(APValue(ResultElements.data(), ResultElements.size()), E); } + + case Builtin::BI__builtin_elementwise_fma: { + APValue SourceX, SourceY, SourceZ; + if (!EvaluateAsRValue(Info, E->getArg(0), SourceX) || + !EvaluateAsRValue(Info, E->getArg(1), SourceY) || + !EvaluateAsRValue(Info, E->getArg(2), SourceZ)) + return false; + + unsigned SourceLen = SourceX.getVectorLength(); + SmallVector<APValue> ResultElements; + ResultElements.reserve(SourceLen); + llvm::RoundingMode RM = getActiveRoundingMode(getEvalInfo(), E); + for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) { + const APFloat &X = SourceX.getVectorElt(EltNum).getFloat(); + const APFloat &Y = SourceY.getVectorElt(EltNum).getFloat(); + const APFloat &Z = SourceZ.getVectorElt(EltNum).getFloat(); + APFloat Result(X); + (void)Result.fusedMultiplyAdd(Y, Z, RM); + ResultElements.push_back(APValue(Result)); + } + return Success(APValue(ResultElements.data(), ResultElements.size()), E); + } } } @@ -16139,6 +16161,21 @@ bool FloatExprEvaluator::VisitCallExpr(const CallExpr *E) { Result = minimumnum(Result, RHS); return true; } + + case Builtin::BI__builtin_elementwise_fma: { + if(!E->getArg(0)->isPRValue() || !E->getArg(1)->isPRValue() || + !E->getArg(2)->isPRValue()) { + return false; + } + APFloat SourceY(0.), SourceZ(0.); + if (!EvaluateFloat(E->getArg(0), Result, Info) || + !EvaluateFloat(E->getArg(1), SourceY, Info) || + !EvaluateFloat(E->getArg(2), SourceZ, Info)) + return false; + llvm::RoundingMode RM = getActiveRoundingMode(getEvalInfo(), E); + (void)Result.fusedMultiplyAdd(SourceY, SourceZ, RM); + return true; + } } } diff --git a/clang/test/CodeGen/rounding-math.cpp b/clang/test/CodeGen/rounding-math.cpp index 264031dc9daa9..5c44fd31242c6 100644 --- a/clang/test/CodeGen/rounding-math.cpp +++ b/clang/test/CodeGen/rounding-math.cpp @@ -11,3 +11,55 @@ float V3 = func_01(1.0F, 2.0F); // CHECK: @V1 = {{.*}}global float 1.000000e+00, align 4 // CHECK: @V2 = {{.*}}global float 1.000000e+00, align 4 // CHECK: @V3 = {{.*}}global float 3.000000e+00, align 4 + +void test_builtin_elementwise_fma_round_upward() { + #pragma STDC FENV_ACCESS ON + #pragma STDC FENV_ROUND FE_UPWARD + + // CHECK: store float 0x4018000100000000, ptr %f1 + // CHECK: store float 0x4018000100000000, ptr %f2 + constexpr float f1 = __builtin_elementwise_fma(2.0F, 3.000001F, 0.000001F); + constexpr float f2 = 2.0F * 3.000001F + 0.000001F; + static_assert(f1 == f2); + static_assert(f1 == 6.00000381F); + // CHECK: store double 0x40180000C9539B89, ptr %d1 + // CHECK: store double 0x40180000C9539B89, ptr %d2 + constexpr double d1 = __builtin_elementwise_fma(2.0, 3.000001, 0.000001); + constexpr double d2 = 2.0 * 3.000001 + 0.000001; + static_assert(d1 == d2); + static_assert(d1 == 6.0000030000000004); +} + +void test_builtin_elementwise_fma_round_downward() { + #pragma STDC FENV_ACCESS ON + #pragma STDC FENV_ROUND FE_DOWNWARD + + // CHECK: store float 0x40180000C0000000, ptr %f3 + // CHECK: store float 0x40180000C0000000, ptr %f4 + constexpr float f3 = __builtin_elementwise_fma(2.0F, 3.000001F, 0.000001F); + constexpr float f4 = 2.0F * 3.000001F + 0.000001F; + static_assert(f3 == f4); + // CHECK: store double 0x40180000C9539B87, ptr %d3 + // CHECK: store double 0x40180000C9539B87, ptr %d4 + constexpr double d3 = __builtin_elementwise_fma(2.0, 3.000001, 0.000001); + constexpr double d4 = 2.0 * 3.000001 + 0.000001; + static_assert(d3 == d4); +} + +void test_builtin_elementwise_fma_round_nearest() { + #pragma STDC FENV_ACCESS ON + #pragma STDC FENV_ROUND FE_TONEAREST + + // CHECK: store float 0x40180000C0000000, ptr %f5 + // CHECK: store float 0x40180000C0000000, ptr %f6 + constexpr float f5 = __builtin_elementwise_fma(2.0F, 3.000001F, 0.000001F); + constexpr float f6 = 2.0F * 3.000001F + 0.000001F; + static_assert(f5 == f6); + static_assert(f5 == 6.00000286F); + // CHECK: store double 0x40180000C9539B89, ptr %d5 + // CHECK: store double 0x40180000C9539B89, ptr %d6 + constexpr double d5 = __builtin_elementwise_fma(2.0, 3.000001, 0.000001); + constexpr double d6 = 2.0 * 3.000001 + 0.000001; + static_assert(d5 == d6); + static_assert(d5 == 6.0000030000000004); +} diff --git a/clang/test/Sema/constant-builtins-vector.cpp b/clang/test/Sema/constant-builtins-vector.cpp index 7f882f9ee76eb..9c52a2ab20c7e 100644 --- a/clang/test/Sema/constant-builtins-vector.cpp +++ b/clang/test/Sema/constant-builtins-vector.cpp @@ -936,3 +936,24 @@ constexpr vector4char ctz1 = __builtin_elementwise_cttz((vector4char){1, 0, 3, 4 // expected-note@-1 {{evaluation of __builtin_elementwise_cttz with a zero value is undefined}} static_assert(__builtin_bit_cast(unsigned, __builtin_elementwise_cttz((vector4char){8, 0, 127, 0}, (vector4char){9, -1, 9, -2})) == (LITTLE_END ? 0xFE00FF03 : 0x03FF00FE)); static_assert(__builtin_bit_cast(unsigned, __builtin_elementwise_cttz((vector4char){0, 0, 0, 0}, (vector4char){0, 0, 0, 0})) == 0); + +// Non-vector floating point types. +static_assert(__builtin_elementwise_fma(2.0, 3.0, 4.0) == 10.0); +static_assert(__builtin_elementwise_fma(200.0, 300.0, 400.0) == 60400.0); +// Vector type. +constexpr vector4float fmaFloat1 = + __builtin_elementwise_fma((vector4float){1.0, 2.0, 3.0, 4.0}, + (vector4float){2.0, 3.0, 4.0, 5.0}, + (vector4float){3.0, 4.0, 5.0, 6.0}); +static_assert(fmaFloat1[0] == 5.0); +static_assert(fmaFloat1[1] == 10.0); +static_assert(fmaFloat1[2] == 17.0); +static_assert(fmaFloat1[3] == 26.0); +constexpr vector4double fmaDouble1 = + __builtin_elementwise_fma((vector4double){1.0, 2.0, 3.0, 4.0}, + (vector4double){2.0, 3.0, 4.0, 5.0}, + (vector4double){3.0, 4.0, 5.0, 6.0}); +static_assert(fmaDouble1[0] == 5.0); +static_assert(fmaDouble1[1] == 10.0); +static_assert(fmaDouble1[2] == 17.0); +static_assert(fmaDouble1[3] == 26.0); _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits