Author: Matt Arsenault Date: 2023-02-24T21:55:08-04:00 New Revision: 8709bcacfb3a06847b47bb6b47e8556db43f3a43
URL: https://github.com/llvm/llvm-project/commit/8709bcacfb3a06847b47bb6b47e8556db43f3a43 DIFF: https://github.com/llvm/llvm-project/commit/8709bcacfb3a06847b47bb6b47e8556db43f3a43.diff LOG: clang: Add __builtin_elementwise_fma I didn't understand why the other builtins have promotion logic, or how it would apply for a ternary operation. Implicit conversions are evil to begin with, and even more so when the purpose is to get an exact IR intrinsic. This checks all the arguments have the same type. Added: Modified: clang/docs/LanguageExtensions.rst clang/include/clang/Basic/Builtins.def clang/include/clang/Sema/Sema.h clang/lib/CodeGen/CGBuiltin.cpp clang/lib/Sema/SemaChecking.cpp clang/test/CodeGen/builtins-elementwise-math.c clang/test/Sema/builtins-elementwise-math.c Removed: ################################################################################ diff --git a/clang/docs/LanguageExtensions.rst b/clang/docs/LanguageExtensions.rst index 2595b5de8f265..c0ea8afad6cb2 100644 --- a/clang/docs/LanguageExtensions.rst +++ b/clang/docs/LanguageExtensions.rst @@ -631,6 +631,7 @@ Unless specified otherwise operation(±0) = ±0 and operation(±infinity) = ±in =========================================== ================================================================ ========================================= T __builtin_elementwise_abs(T x) return the absolute value of a number x; the absolute value of signed integer and floating point types the most negative integer remains the most negative integer + T __builtin_elementwise_fma(T x, T y, T z) fused multiply add, (x * y) + z. 
floating point types T __builtin_elementwise_ceil(T x) return the smallest integral value greater than or equal to x floating point types T __builtin_elementwise_sin(T x) return the sine of x interpreted as an angle in radians floating point types T __builtin_elementwise_cos(T x) return the cosine of x interpreted as an angle in radians floating point types diff --git a/clang/include/clang/Basic/Builtins.def b/clang/include/clang/Basic/Builtins.def index 41288410786b0..6db599a3de116 100644 --- a/clang/include/clang/Basic/Builtins.def +++ b/clang/include/clang/Basic/Builtins.def @@ -671,6 +671,7 @@ BUILTIN(__builtin_elementwise_sin, "v.", "nct") BUILTIN(__builtin_elementwise_trunc, "v.", "nct") BUILTIN(__builtin_elementwise_canonicalize, "v.", "nct") BUILTIN(__builtin_elementwise_copysign, "v.", "nct") +BUILTIN(__builtin_elementwise_fma, "v.", "nct") BUILTIN(__builtin_elementwise_add_sat, "v.", "nct") BUILTIN(__builtin_elementwise_sub_sat, "v.", "nct") BUILTIN(__builtin_reduce_max, "v.", "nct") diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index 41691eab4972b..0c6a3887e4151 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -13531,6 +13531,7 @@ class Sema final { bool CheckPPCMMAType(QualType Type, SourceLocation TypeLoc); bool SemaBuiltinElementwiseMath(CallExpr *TheCall); + bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall); bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall); bool PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall); diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index 52ec6e092c449..1535b14c7fb40 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -3118,6 +3118,8 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID, emitUnaryBuiltin(*this, E, llvm::Intrinsic::canonicalize, "elt.trunc")); case Builtin::BI__builtin_elementwise_copysign: return 
RValue::get(emitBinaryBuiltin(*this, E, llvm::Intrinsic::copysign)); + case Builtin::BI__builtin_elementwise_fma: + return RValue::get(emitTernaryBuiltin(*this, E, llvm::Intrinsic::fma)); case Builtin::BI__builtin_elementwise_add_sat: case Builtin::BI__builtin_elementwise_sub_sat: { Value *Op0 = EmitScalarExpr(E->getArg(0)); diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp index eded6061c77eb..485351f157fde 100644 --- a/clang/lib/Sema/SemaChecking.cpp +++ b/clang/lib/Sema/SemaChecking.cpp @@ -2626,20 +2626,16 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID, return ExprError(); QualType ArgTy = TheCall->getArg(0)->getType(); - QualType EltTy = ArgTy; - - if (auto *VecTy = EltTy->getAs<VectorType>()) - EltTy = VecTy->getElementType(); - if (!EltTy->isFloatingType()) { - Diag(TheCall->getArg(0)->getBeginLoc(), - diag::err_builtin_invalid_arg_type) - << 1 << /* float ty*/ 5 << ArgTy; - + if (checkFPMathBuiltinElementType(*this, TheCall->getArg(0)->getBeginLoc(), + ArgTy, 1)) + return ExprError(); + break; + } + case Builtin::BI__builtin_elementwise_fma: { + if (SemaBuiltinElementwiseTernaryMath(TheCall)) return ExprError(); - } break; } - // These builtins restrict the element type to integer // types only. 
case Builtin::BI__builtin_elementwise_add_sat: @@ -17877,6 +17873,40 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) { return false; } +bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall) { + if (checkArgCount(*this, TheCall, 3)) + return true; + + Expr *Args[3]; + for (int I = 0; I < 3; ++I) { + ExprResult Converted = UsualUnaryConversions(TheCall->getArg(I)); + if (Converted.isInvalid()) + return true; + Args[I] = Converted.get(); + } + + int ArgOrdinal = 1; + for (Expr *Arg : Args) { + if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(), + ArgOrdinal++)) + return true; + } + + for (int I = 1; I < 3; ++I) { + if (Args[0]->getType().getCanonicalType() != + Args[I]->getType().getCanonicalType()) { + return Diag(Args[0]->getBeginLoc(), + diag::err_typecheck_call_different_arg_types) + << Args[0]->getType() << Args[I]->getType(); + } + + TheCall->setArg(I, Args[I]); + } + + TheCall->setType(Args[0]->getType()); + return false; +} + bool Sema::PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall) { if (checkArgCount(*this, TheCall, 1)) return true; diff --git a/clang/test/CodeGen/builtins-elementwise-math.c b/clang/test/CodeGen/builtins-elementwise-math.c index 1571d2bb7f650..1b48a12b92056 100644 --- a/clang/test/CodeGen/builtins-elementwise-math.c +++ b/clang/test/CodeGen/builtins-elementwise-math.c @@ -1,5 +1,9 @@ // RUN: %clang_cc1 -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s +typedef _Float16 half; + +typedef half half2 __attribute__((ext_vector_type(2))); +typedef float float2 __attribute__((ext_vector_type(2))); typedef float float4 __attribute__((ext_vector_type(4))); typedef short int si8 __attribute__((ext_vector_type(8))); typedef unsigned int u4 __attribute__((ext_vector_type(4))); @@ -525,3 +529,77 @@ void test_builtin_elementwise_copysign(float f1, float f2, double d1, double d2, // CHECK-NEXT: call <2 x double> @llvm.copysign.v2f64(<2 x double> <double 1.000000e+00, 
double 1.000000e+00>, <2 x double> [[V2F64]]) v2f64 = __builtin_elementwise_copysign((double2)1.0, v2f64); } + +void test_builtin_elementwise_fma(float f32, double f64, + float2 v2f32, float4 v4f32, + double2 v2f64, double3 v3f64, + const float4 c_v4f32, + half f16, half2 v2f16) { + // CHECK-LABEL: define void @test_builtin_elementwise_fma( + // CHECK: [[F32_0:%.+]] = load float, ptr %f32.addr + // CHECK-NEXT: [[F32_1:%.+]] = load float, ptr %f32.addr + // CHECK-NEXT: [[F32_2:%.+]] = load float, ptr %f32.addr + // CHECK-NEXT: call float @llvm.fma.f32(float [[F32_0]], float [[F32_1]], float [[F32_2]]) + float f2 = __builtin_elementwise_fma(f32, f32, f32); + + // CHECK: [[F64_0:%.+]] = load double, ptr %f64.addr + // CHECK-NEXT: [[F64_1:%.+]] = load double, ptr %f64.addr + // CHECK-NEXT: [[F64_2:%.+]] = load double, ptr %f64.addr + // CHECK-NEXT: call double @llvm.fma.f64(double [[F64_0]], double [[F64_1]], double [[F64_2]]) + double d2 = __builtin_elementwise_fma(f64, f64, f64); + + // CHECK: [[V4F32_0:%.+]] = load <4 x float>, ptr %v4f32.addr + // CHECK-NEXT: [[V4F32_1:%.+]] = load <4 x float>, ptr %v4f32.addr + // CHECK-NEXT: [[V4F32_2:%.+]] = load <4 x float>, ptr %v4f32.addr + // CHECK-NEXT: call <4 x float> @llvm.fma.v4f32(<4 x float> [[V4F32_0]], <4 x float> [[V4F32_1]], <4 x float> [[V4F32_2]]) + float4 tmp_v4f32 = __builtin_elementwise_fma(v4f32, v4f32, v4f32); + + + // FIXME: Are we really still doing the 3 vector load workaround + // CHECK: [[V3F64_LOAD_0:%.+]] = load <4 x double>, ptr %v3f64.addr + // CHECK-NEXT: [[V3F64_0:%.+]] = shufflevector + // CHECK-NEXT: [[V3F64_LOAD_1:%.+]] = load <4 x double>, ptr %v3f64.addr + // CHECK-NEXT: [[V3F64_1:%.+]] = shufflevector + // CHECK-NEXT: [[V3F64_LOAD_2:%.+]] = load <4 x double>, ptr %v3f64.addr + // CHECK-NEXT: [[V3F64_2:%.+]] = shufflevector + // CHECK-NEXT: call <3 x double> @llvm.fma.v3f64(<3 x double> [[V3F64_0]], <3 x double> [[V3F64_1]], <3 x double> [[V3F64_2]]) + v3f64 = 
__builtin_elementwise_fma(v3f64, v3f64, v3f64); + + // CHECK: [[F64_0:%.+]] = load double, ptr %f64.addr + // CHECK-NEXT: [[F64_1:%.+]] = load double, ptr %f64.addr + // CHECK-NEXT: [[F64_2:%.+]] = load double, ptr %f64.addr + // CHECK-NEXT: call double @llvm.fma.f64(double [[F64_0]], double [[F64_1]], double [[F64_2]]) + v2f64 = __builtin_elementwise_fma(f64, f64, f64); + + // CHECK: [[V4F32_0:%.+]] = load <4 x float>, ptr %c_v4f32.addr + // CHECK-NEXT: [[V4F32_1:%.+]] = load <4 x float>, ptr %c_v4f32.addr + // CHECK-NEXT: [[V4F32_2:%.+]] = load <4 x float>, ptr %c_v4f32.addr + // CHECK-NEXT: call <4 x float> @llvm.fma.v4f32(<4 x float> [[V4F32_0]], <4 x float> [[V4F32_1]], <4 x float> [[V4F32_2]]) + v4f32 = __builtin_elementwise_fma(c_v4f32, c_v4f32, c_v4f32); + + // CHECK: [[F16_0:%.+]] = load half, ptr %f16.addr + // CHECK-NEXT: [[F16_1:%.+]] = load half, ptr %f16.addr + // CHECK-NEXT: [[F16_2:%.+]] = load half, ptr %f16.addr + // CHECK-NEXT: call half @llvm.fma.f16(half [[F16_0]], half [[F16_1]], half [[F16_2]]) + half tmp_f16 = __builtin_elementwise_fma(f16, f16, f16); + + // CHECK: [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: [[V2F16_2:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> [[V2F16_2]]) + half2 tmp0_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, v2f16); + + // CHECK: [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: [[F16_2:%.+]] = load half, ptr %f16.addr + // CHECK-NEXT: [[V2F16_2_INSERT:%.+]] = insertelement + // CHECK-NEXT: [[V2F16_2:%.+]] = shufflevector <2 x half> [[V2F16_2_INSERT]], <2 x half> poison, <2 x i32> zeroinitializer + // CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> [[V2F16_2]]) + half2 tmp1_v2f16 = 
__builtin_elementwise_fma(v2f16, v2f16, (half2)f16); + + // CHECK: [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> <half 0xH4400, half 0xH4400>) + half2 tmp2_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, (half2)4.0); + +} diff --git a/clang/test/Sema/builtins-elementwise-math.c b/clang/test/Sema/builtins-elementwise-math.c index cb8b79739b6d2..c803fcea0d505 100644 --- a/clang/test/Sema/builtins-elementwise-math.c +++ b/clang/test/Sema/builtins-elementwise-math.c @@ -4,6 +4,8 @@ typedef double double2 __attribute__((ext_vector_type(2))); typedef double double4 __attribute__((ext_vector_type(4))); typedef float float2 __attribute__((ext_vector_type(2))); typedef float float4 __attribute__((ext_vector_type(4))); + +typedef int int2 __attribute__((ext_vector_type(2))); typedef int int3 __attribute__((ext_vector_type(3))); typedef unsigned unsigned3 __attribute__((ext_vector_type(3))); typedef unsigned unsigned4 __attribute__((ext_vector_type(4))); @@ -572,3 +574,84 @@ void test_builtin_elementwise_copysign(int i, short s, double d, float f, float4 float2 tmp9 = __builtin_elementwise_copysign(v4f32, v4f32); // expected-error@-1 {{initializing 'float2' (vector of 2 'float' values) with an expression of incompatible type 'float4' (vector of 4 'float' values)}} } + +void test_builtin_elementwise_fma(int i32, int2 v2i32, short i16, + double f64, double2 v2f64, double2 v3f64, + float f32, float2 v2f32, float v3f32, float4 v4f32, + const float4 c_v4f32, + int3 v3i32, int *ptr) { + + f32 = __builtin_elementwise_fma(); + // expected-error@-1 {{too few arguments to function call, expected 3, have 0}} + + f32 = __builtin_elementwise_fma(f32); + // expected-error@-1 {{too few arguments to function call, expected 3, have 1}} + + f32 = __builtin_elementwise_fma(f32, f32); + // expected-error@-1 {{too few 
arguments to function call, expected 3, have 2}} + + f32 = __builtin_elementwise_fma(f32, f32, f32, f32); + // expected-error@-1 {{too many arguments to function call, expected 3, have 4}} + + f32 = __builtin_elementwise_fma(f64, f32, f32); + // expected-error@-1 {{arguments are of different types ('double' vs 'float')}} + + f32 = __builtin_elementwise_fma(f32, f64, f32); + // expected-error@-1 {{arguments are of different types ('float' vs 'double')}} + + f32 = __builtin_elementwise_fma(f32, f32, f64); + // expected-error@-1 {{arguments are of different types ('float' vs 'double')}} + + f32 = __builtin_elementwise_fma(f32, f32, f64); + // expected-error@-1 {{arguments are of different types ('float' vs 'double')}} + + f64 = __builtin_elementwise_fma(f64, f32, f32); + // expected-error@-1 {{arguments are of different types ('double' vs 'float')}} + + f64 = __builtin_elementwise_fma(f64, f64, f32); + // expected-error@-1 {{arguments are of different types ('double' vs 'float')}} + + f64 = __builtin_elementwise_fma(f64, f32, f64); + // expected-error@-1 {{arguments are of different types ('double' vs 'float')}} + + v2f64 = __builtin_elementwise_fma(v2f32, f64, f64); + // expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double'}} + + v2f64 = __builtin_elementwise_fma(v2f32, v2f64, f64); + // expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double2' (vector of 2 'double' values)}} + + v2f64 = __builtin_elementwise_fma(v2f32, f64, v2f64); + // expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double'}} + + v2f64 = __builtin_elementwise_fma(f64, v2f32, v2f64); + // expected-error@-1 {{arguments are of different types ('double' vs 'float2' (vector of 2 'float' values)}} + + v2f64 = __builtin_elementwise_fma(f64, v2f64, v2f64); + // expected-error@-1 {{arguments are of different types ('double' vs 'double2' (vector of 2 
'double' values)}} + + i32 = __builtin_elementwise_fma(i32, i32, i32); + // expected-error@-1 {{1st argument must be a floating point type (was 'int')}} + + v2i32 = __builtin_elementwise_fma(v2i32, v2i32, v2i32); + // expected-error@-1 {{1st argument must be a floating point type (was 'int2' (vector of 2 'int' values))}} + + f32 = __builtin_elementwise_fma(f32, f32, i32); + // expected-error@-1 {{3rd argument must be a floating point type (was 'int')}} + + f32 = __builtin_elementwise_fma(f32, i32, f32); + // expected-error@-1 {{2nd argument must be a floating point type (was 'int')}} + + f32 = __builtin_elementwise_fma(f32, f32, i32); + // expected-error@-1 {{3rd argument must be a floating point type (was 'int')}} + + + _Complex float c1, c2, c3; + c1 = __builtin_elementwise_fma(c1, f32, f32); + // expected-error@-1 {{1st argument must be a floating point type (was '_Complex float')}} + + c2 = __builtin_elementwise_fma(f32, c2, f32); + // expected-error@-1 {{2nd argument must be a floating point type (was '_Complex float')}} + + c3 = __builtin_elementwise_fma(f32, f32, c3); + // expected-error@-1 {{3rd argument must be a floating point type (was '_Complex float')}} +} _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits