arsenm updated this revision to Diff 499805.
arsenm added a comment.

Loop merge, documentation


CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D140992/new/

https://reviews.llvm.org/D140992

Files:
  clang/docs/LanguageExtensions.rst
  clang/include/clang/Basic/Builtins.def
  clang/include/clang/Sema/Sema.h
  clang/lib/CodeGen/CGBuiltin.cpp
  clang/lib/Sema/SemaChecking.cpp
  clang/test/CodeGen/builtins-elementwise-math.c
  clang/test/Sema/builtins-elementwise-math.c

Index: clang/test/Sema/builtins-elementwise-math.c
===================================================================
--- clang/test/Sema/builtins-elementwise-math.c
+++ clang/test/Sema/builtins-elementwise-math.c
@@ -4,6 +4,8 @@
 typedef double double4 __attribute__((ext_vector_type(4)));
 typedef float float2 __attribute__((ext_vector_type(2)));
 typedef float float4 __attribute__((ext_vector_type(4)));
+
+typedef int int2 __attribute__((ext_vector_type(2)));
 typedef int int3 __attribute__((ext_vector_type(3)));
 typedef unsigned unsigned3 __attribute__((ext_vector_type(3)));
 typedef unsigned unsigned4 __attribute__((ext_vector_type(4)));
@@ -572,3 +574,84 @@
   float2 tmp9 = __builtin_elementwise_copysign(v4f32, v4f32);
   // expected-error@-1 {{initializing 'float2' (vector of 2 'float' values) with an expression of incompatible type 'float4' (vector of 4 'float' values)}}
 }
+
+void test_builtin_elementwise_fma(int i32, int2 v2i32, short i16,
+                                  double f64, double2 v2f64,
+                                  float f32, float2 v2f32, float4 v4f32,
+                                  const float4 c_v4f32,
+                                  int3 v3i32, int *ptr) {
+  // Diagnostics for bad argument counts, mismatched types and non-FP types.
+  f32 = __builtin_elementwise_fma();
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 0}}
+
+  f32 = __builtin_elementwise_fma(f32);
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 1}}
+
+  f32 = __builtin_elementwise_fma(f32, f32);
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 2}}
+
+  f32 = __builtin_elementwise_fma(f32, f32, f32, f32);
+  // expected-error@-1 {{too many arguments to function call, expected 3, have 4}}
+
+  f32 = __builtin_elementwise_fma(f64, f32, f32);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'float')}}
+
+  f32 = __builtin_elementwise_fma(f32, f64, f32);
+  // expected-error@-1 {{arguments are of different types ('float' vs 'double')}}
+
+  f32 = __builtin_elementwise_fma(f32, f32, f64);
+  // expected-error@-1 {{arguments are of different types ('float' vs 'double')}}
+
+  v2f32 = __builtin_elementwise_fma(v2f32, v2f32, v4f32);
+  // expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'float4' (vector of 4 'float' values))}}
+
+  f64 = __builtin_elementwise_fma(f64, f32, f32);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'float')}}
+
+  f64 = __builtin_elementwise_fma(f64, f64, f32);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'float')}}
+
+  f64 = __builtin_elementwise_fma(f64, f32, f64);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'float')}}
+
+  v2f64 = __builtin_elementwise_fma(v2f32, f64, f64);
+  // expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double'}}
+
+  v2f64 = __builtin_elementwise_fma(v2f32, v2f64, f64);
+  // expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double2' (vector of 2 'double' values)}}
+
+  v2f64 = __builtin_elementwise_fma(v2f32, f64, v2f64);
+  // expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double'}}
+
+  v2f64 = __builtin_elementwise_fma(f64, v2f32, v2f64);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'float2' (vector of 2 'float' values)}}
+
+  v2f64 = __builtin_elementwise_fma(f64, v2f64, v2f64);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'double2' (vector of 2 'double' values)}}
+
+  i32 = __builtin_elementwise_fma(i32, i32, i32);
+  // expected-error@-1 {{1st argument must be a floating point type (was 'int')}}
+
+  v2i32 = __builtin_elementwise_fma(v2i32, v2i32, v2i32);
+  // expected-error@-1 {{1st argument must be a floating point type (was 'int2' (vector of 2 'int' values))}}
+
+  f32 = __builtin_elementwise_fma(f32, f32, i32);
+  // expected-error@-1 {{3rd argument must be a floating point type (was 'int')}}
+
+  f32 = __builtin_elementwise_fma(f32, i32, f32);
+  // expected-error@-1 {{2nd argument must be a floating point type (was 'int')}}
+
+  // const-qualified and unqualified operands of the same type are accepted.
+  v4f32 = __builtin_elementwise_fma(c_v4f32, v4f32, c_v4f32);
+
+  // _Complex operands are rejected as non-floating-point types.
+  _Complex float c1, c2, c3;
+  c1 = __builtin_elementwise_fma(c1, f32, f32);
+  // expected-error@-1 {{1st argument must be a floating point type (was '_Complex float')}}
+
+  c2 = __builtin_elementwise_fma(f32, c2, f32);
+  // expected-error@-1 {{2nd argument must be a floating point type (was '_Complex float')}}
+
+  c3 = __builtin_elementwise_fma(f32, f32, c3);
+  // expected-error@-1 {{3rd argument must be a floating point type (was '_Complex float')}}
+}
Index: clang/test/CodeGen/builtins-elementwise-math.c
===================================================================
--- clang/test/CodeGen/builtins-elementwise-math.c
+++ clang/test/CodeGen/builtins-elementwise-math.c
@@ -1,5 +1,9 @@
 // RUN: %clang_cc1 -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
 
+typedef _Float16 half;
+
+typedef half half2 __attribute__((ext_vector_type(2)));
+typedef float float2 __attribute__((ext_vector_type(2)));
 typedef float float4 __attribute__((ext_vector_type(4)));
 typedef short int si8 __attribute__((ext_vector_type(8)));
 typedef unsigned int u4 __attribute__((ext_vector_type(4)));
@@ -525,3 +529,77 @@
   // CHECK-NEXT: call <2 x double> @llvm.copysign.v2f64(<2 x double> <double 1.000000e+00, double 1.000000e+00>, <2 x double> [[V2F64]])
   v2f64 = __builtin_elementwise_copysign((double2)1.0, v2f64);
 }
+
+void test_builtin_elementwise_fma(float f32, double f64,
+                                  float2 v2f32, float4 v4f32,
+                                  double2 v2f64, double3 v3f64,
+                                  const float4 c_v4f32,
+                                  half f16, half2 v2f16) {
+  // CHECK-LABEL: define void @test_builtin_elementwise_fma(
+  // CHECK:      [[F32_0:%.+]] = load float, ptr %f32.addr
+  // CHECK-NEXT: [[F32_1:%.+]] = load float, ptr %f32.addr
+  // CHECK-NEXT: [[F32_2:%.+]] = load float, ptr %f32.addr
+  // CHECK-NEXT: call float @llvm.fma.f32(float [[F32_0]], float [[F32_1]], float [[F32_2]])
+  float f2 = __builtin_elementwise_fma(f32, f32, f32);
+
+  // CHECK:      [[F64_0:%.+]] = load double, ptr %f64.addr
+  // CHECK-NEXT: [[F64_1:%.+]] = load double, ptr %f64.addr
+  // CHECK-NEXT: [[F64_2:%.+]] = load double, ptr %f64.addr
+  // CHECK-NEXT: call double @llvm.fma.f64(double [[F64_0]], double [[F64_1]], double [[F64_2]])
+  double d2 = __builtin_elementwise_fma(f64, f64, f64);
+
+  // CHECK:      [[V4F32_0:%.+]] = load <4 x float>, ptr %v4f32.addr
+  // CHECK-NEXT: [[V4F32_1:%.+]] = load <4 x float>, ptr %v4f32.addr
+  // CHECK-NEXT: [[V4F32_2:%.+]] = load <4 x float>, ptr %v4f32.addr
+  // CHECK-NEXT: call <4 x float> @llvm.fma.v4f32(<4 x float> [[V4F32_0]], <4 x float> [[V4F32_1]], <4 x float> [[V4F32_2]])
+  float4 tmp_v4f32 = __builtin_elementwise_fma(v4f32, v4f32, v4f32);
+
+  // FIXME: Is the vec3-loaded-as-vec4-then-shuffled workaround still needed?
+  // CHECK:      [[V3F64_LOAD_0:%.+]] = load <4 x double>, ptr %v3f64.addr
+  // CHECK-NEXT: [[V3F64_0:%.+]] = shufflevector
+  // CHECK-NEXT: [[V3F64_LOAD_1:%.+]] = load <4 x double>, ptr %v3f64.addr
+  // CHECK-NEXT: [[V3F64_1:%.+]] = shufflevector
+  // CHECK-NEXT: [[V3F64_LOAD_2:%.+]] = load <4 x double>, ptr %v3f64.addr
+  // CHECK-NEXT: [[V3F64_2:%.+]] = shufflevector
+  // CHECK-NEXT: call <3 x double> @llvm.fma.v3f64(<3 x double> [[V3F64_0]], <3 x double> [[V3F64_1]], <3 x double> [[V3F64_2]])
+  v3f64 = __builtin_elementwise_fma(v3f64, v3f64, v3f64);
+
+  // CHECK:      [[F64_0:%.+]] = load double, ptr %f64.addr
+  // CHECK-NEXT: [[F64_1:%.+]] = load double, ptr %f64.addr
+  // CHECK-NEXT: [[F64_2:%.+]] = load double, ptr %f64.addr
+  // CHECK-NEXT: [[FMA_F64:%.+]] = call double @llvm.fma.f64(double [[F64_0]], double [[F64_1]], double [[F64_2]])
+  // CHECK-NEXT: [[SPLAT_INS:%.+]] = insertelement <2 x double> poison, double [[FMA_F64]]
+  // CHECK-NEXT: shufflevector <2 x double> [[SPLAT_INS]], <2 x double> poison, <2 x i32> zeroinitializer
+  v2f64 = __builtin_elementwise_fma(f64, f64, f64);
+
+  // CHECK:      [[V4F32_0:%.+]] = load <4 x float>, ptr %c_v4f32.addr
+  // CHECK-NEXT: [[V4F32_1:%.+]] = load <4 x float>, ptr %c_v4f32.addr
+  // CHECK-NEXT: [[V4F32_2:%.+]] = load <4 x float>, ptr %c_v4f32.addr
+  // CHECK-NEXT: call <4 x float> @llvm.fma.v4f32(<4 x float> [[V4F32_0]], <4 x float> [[V4F32_1]], <4 x float> [[V4F32_2]])
+  v4f32 = __builtin_elementwise_fma(c_v4f32, c_v4f32, c_v4f32);
+
+  // CHECK:      [[F16_0:%.+]] = load half, ptr %f16.addr
+  // CHECK-NEXT: [[F16_1:%.+]] = load half, ptr %f16.addr
+  // CHECK-NEXT: [[F16_2:%.+]] = load half, ptr %f16.addr
+  // CHECK-NEXT: call half @llvm.fma.f16(half [[F16_0]], half [[F16_1]], half [[F16_2]])
+  half tmp_f16 = __builtin_elementwise_fma(f16, f16, f16);
+
+  // CHECK:      [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: [[V2F16_2:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> [[V2F16_2]])
+  half2 tmp0_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, v2f16);
+
+  // CHECK:      [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: [[F16_2:%.+]] = load half, ptr %f16.addr
+  // CHECK-NEXT: [[V2F16_2_INSERT:%.+]] = insertelement
+  // CHECK-NEXT: [[V2F16_2:%.+]] = shufflevector <2 x half> [[V2F16_2_INSERT]], <2 x half> poison, <2 x i32> zeroinitializer
+  // CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> [[V2F16_2]])
+  half2 tmp1_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, (half2)f16);
+
+  // CHECK:      [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> <half 0xH4400, half 0xH4400>)
+  half2 tmp2_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, (half2)4.0);
+}
Index: clang/lib/Sema/SemaChecking.cpp
===================================================================
--- clang/lib/Sema/SemaChecking.cpp
+++ clang/lib/Sema/SemaChecking.cpp
@@ -2626,20 +2626,16 @@
       return ExprError();
 
     QualType ArgTy = TheCall->getArg(0)->getType();
-    QualType EltTy = ArgTy;
-
-    if (auto *VecTy = EltTy->getAs<VectorType>())
-      EltTy = VecTy->getElementType();
-    if (!EltTy->isFloatingType()) {
-      Diag(TheCall->getArg(0)->getBeginLoc(),
-           diag::err_builtin_invalid_arg_type)
-          << 1 << /* float ty*/ 5 << ArgTy;
-
+    if (checkFPMathBuiltinElementType(*this, TheCall->getArg(0)->getBeginLoc(),
+                                      ArgTy, 1))
+      return ExprError();
+    break;
+  }
+  case Builtin::BI__builtin_elementwise_fma: {
+    if (SemaBuiltinElementwiseTernaryMath(TheCall))
       return ExprError();
-    }
     break;
   }
-
   // These builtins restrict the element type to integer
   // types only.
   case Builtin::BI__builtin_elementwise_add_sat:
@@ -17877,6 +17873,40 @@
   return false;
 }
 
+bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall) {
+  if (checkArgCount(*this, TheCall, 3))
+    return true;
+
+  Expr *Args[3];
+  for (int I = 0; I < 3; ++I) {
+    ExprResult Converted = UsualUnaryConversions(TheCall->getArg(I));
+    if (Converted.isInvalid())
+      return true;
+    Args[I] = Converted.get();
+  }
+
+  int ArgOrdinal = 1;
+  for (Expr *Arg : Args) {
+    if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
+                                      ArgOrdinal++))
+      return true;
+  }
+
+  for (int I = 1; I < 3; ++I) {
+    if (Args[0]->getType().getCanonicalType() !=
+        Args[I]->getType().getCanonicalType()) {
+      return Diag(Args[0]->getBeginLoc(),
+                  diag::err_typecheck_call_different_arg_types)
+             << Args[0]->getType() << Args[I]->getType();
+    }
+
+    TheCall->setArg(I, Args[I]);
+  }
+
+  TheCall->setType(Args[0]->getType());
+  return false;
+}
+
 bool Sema::PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall) {
   if (checkArgCount(*this, TheCall, 1))
     return true;
Index: clang/lib/CodeGen/CGBuiltin.cpp
===================================================================
--- clang/lib/CodeGen/CGBuiltin.cpp
+++ clang/lib/CodeGen/CGBuiltin.cpp
@@ -3118,6 +3118,8 @@
         emitUnaryBuiltin(*this, E, llvm::Intrinsic::canonicalize, "elt.trunc"));
   case Builtin::BI__builtin_elementwise_copysign:
     return RValue::get(emitBinaryBuiltin(*this, E, llvm::Intrinsic::copysign));
+  case Builtin::BI__builtin_elementwise_fma:
+    return RValue::get(emitTernaryBuiltin(*this, E, llvm::Intrinsic::fma));
   case Builtin::BI__builtin_elementwise_add_sat:
   case Builtin::BI__builtin_elementwise_sub_sat: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -13531,6 +13531,7 @@
   bool CheckPPCMMAType(QualType Type, SourceLocation TypeLoc);
 
   bool SemaBuiltinElementwiseMath(CallExpr *TheCall);
+  bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall);
   bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);
   bool PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall);
 
Index: clang/include/clang/Basic/Builtins.def
===================================================================
--- clang/include/clang/Basic/Builtins.def
+++ clang/include/clang/Basic/Builtins.def
@@ -671,6 +671,7 @@
 BUILTIN(__builtin_elementwise_trunc, "v.", "nct")
 BUILTIN(__builtin_elementwise_canonicalize, "v.", "nct")
 BUILTIN(__builtin_elementwise_copysign, "v.", "nct")
+BUILTIN(__builtin_elementwise_fma, "v.", "nct")
 BUILTIN(__builtin_elementwise_add_sat, "v.", "nct")
 BUILTIN(__builtin_elementwise_sub_sat, "v.", "nct")
 BUILTIN(__builtin_reduce_max, "v.", "nct")
Index: clang/docs/LanguageExtensions.rst
===================================================================
--- clang/docs/LanguageExtensions.rst
+++ clang/docs/LanguageExtensions.rst
@@ -631,6 +631,7 @@
 =========================================== ================================================================ =========================================
  T __builtin_elementwise_abs(T x)            return the absolute value of a number x; the absolute value of   signed integer and floating point types
                                              the most negative integer remains the most negative integer
+ T __builtin_elementwise_fma(T x, T y, T z)  return (x * y) + z as a fused multiply-add (single rounding)     floating point types
  T __builtin_elementwise_ceil(T x)           return the smallest integral value greater than or equal to x    floating point types
  T __builtin_elementwise_sin(T x)            return the sine of x interpreted as an angle in radians          floating point types
  T __builtin_elementwise_cos(T x)            return the cosine of x interpreted as an angle in radians        floating point types
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to