yaxunl created this revision.
yaxunl added a reviewer: tra.
Herald added subscribers: jansvoboda11, dexonsmith, dang.
yaxunl requested review of this revision.
GCC and Clang currently do not have a consistent ABI
for half-precision types. Passing aggregate arguments containing half-precision
types between code compiled by Clang and code compiled by GCC can cause
undefined behavior (UB).

This patch adds an option -fhip-allow-half-arg. When it is off, Clang
will diagnose host function parameter and return types that are half-precision
types or aggregate types containing half-precision types.


https://reviews.llvm.org/D98143

Files:
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Basic/LangOptions.def
  clang/include/clang/Driver/Options.td
  clang/lib/Headers/__clang_hip_cmath.h
  clang/lib/Sema/SemaDecl.cpp
  clang/test/SemaCUDA/half-arg.cu

Index: clang/test/SemaCUDA/half-arg.cu
===================================================================
--- /dev/null
+++ clang/test/SemaCUDA/half-arg.cu
@@ -0,0 +1,136 @@
+// RUN: %clang_cc1 -std=c++11 -fsyntax-only -verify -x hip %s
+// RUN: %clang_cc1 -std=c++11 -fcuda-is-device -fsyntax-only -verify -x hip %s
+// RUN: %clang_cc1 -std=c++11 -fsyntax-only -verify=allow -fhip-allow-half-arg -x hip %s
+
+// allow-no-diagnostics
+
+#include "Inputs/cuda.h"
+
+// Check that _Float16/__fp16 or structs containing them are not allowed as
+// function parameters in HIP host functions.
+
+typedef _Float16 half;
+
+typedef _Float16 half2 __attribute__((ext_vector_type(2)));
+
+struct A { // expected-note 4{{within field or base class of type 'A' declared here}}
+  _Float16 x; // expected-note 7{{field of illegal type '_Float16' declared here}}
+};
+
+struct B { // expected-note {{within field or base class of type 'B' declared here}}
+  _Float16 x[2]; // expected-note {{field of illegal type '_Float16 [2]' declared here}}
+};
+
+struct C { // expected-note {{within field or base class of type 'C' declared here}}
+  _Float16 x[2][2]; // expected-note {{field of illegal type '_Float16 [2][2]' declared here}}
+};
+
+struct D { // expected-note {{within field or base class of type 'D' declared here}}
+  A x; // expected-note {{within field or base class of type 'A' declared here}}
+};
+
+struct E : public A { // expected-note {{within field or base class of type 'E' declared here}}
+};
+
+struct F : virtual public A { // expected-note {{within field or base class of type 'F' declared here}}
+};
+
+struct G { // expected-note {{within field or base class of type 'G' declared here}}
+  __fp16 x; // expected-note {{field of illegal type '__fp16' declared here}}
+};
+
+struct H {
+  void f(A x);
+  // expected-error@-1 {{Invalid function parameter type: 'A'}}
+};
+
+template<typename T>
+struct I {
+  T x;
+  void f(T x);
+  // expected-error@-1 {{Invalid function parameter type: 'A'}}
+};
+
+struct J { // expected-note {{within field or base class of type 'J' declared here}}
+  half2 v; // expected-note {{field of illegal type 'half2' (vector of 2 '_Float16' values) declared here}}
+};
+
+struct empty {};
+
+struct K : public empty {
+  int x;
+};
+
+struct undefined;
+
+void fa1(_Float16 x);
+// expected-error@-1 {{Invalid function parameter type: '_Float16'}}
+
+void fa2(A x);
+// expected-error@-1 {{Invalid function parameter type: 'A'}}
+
+void fa3(B x);
+// expected-error@-1 {{Invalid function parameter type: 'B'}}
+
+void fa4(C x);
+// expected-error@-1 {{Invalid function parameter type: 'C'}}
+
+void fa5(D x);
+// expected-error@-1 {{Invalid function parameter type: 'D'}}
+
+void fa6(E x);
+// expected-error@-1 {{Invalid function parameter type: 'E'}}
+
+void fa7(F x);
+// expected-error@-1 {{Invalid function parameter type: 'F'}}
+
+void fa8(G x);
+// expected-error@-1 {{Invalid function parameter type: 'G'}}
+
+template<typename T> void fa9(T x);
+// expected-error@-1 {{Invalid function parameter type: 'A'}}
+// expected-note@-2 {{candidate template ignored: substitution failure [with T = A]}}
+void fa9_caller() {
+  A x;
+  fa9(x);
+  // expected-error@-1 {{no matching function for call to 'fa9'}}
+  // expected-note@-2 {{in instantiation of function template specialization 'fa9<A>' requested here}}
+}
+
+void fa10() {
+  I<A> x;
+  // expected-note@-1 {{in instantiation of template class 'I<A>' requested here}}
+}
+
+void fa11(half x);
+// expected-error@-1 {{Invalid function parameter type: 'half' (aka '_Float16')}}
+
+void fa12(half2 x);
+// expected-error@-1 {{Invalid function parameter type: 'half2' (vector of 2 '_Float16' values)}}
+
+void fa13(J x);
+// expected-error@-1 {{Invalid function parameter type: 'J'}}
+
+void fa14(int x, _Float16 y);
+// expected-error@-1 {{Invalid function parameter type: '_Float16'}}
+
+_Float16 fa15();
+// expected-error@-1 {{Invalid function return type: '_Float16'}}
+
+void fa16(K x);
+
+undefined fa17();
+
+// Check that references or pointers to _Float16/__fp16 or structs containing
+// them are allowed as function parameters in HIP host functions.
+
+void fb1(_Float16 &x);
+void fb2(_Float16 *x);
+void fb3(A &x);
+void fb4(A *x);
+
+// Check that device functions can use _Float16/__fp16 or structs containing
+// them as parameter types.
+__device__ void fc1(A x);
+__global__ void fc2(A x);
+__host__ __device__ void fc3(A x);
Index: clang/lib/Sema/SemaDecl.cpp
===================================================================
--- clang/lib/Sema/SemaDecl.cpp
+++ clang/lib/Sema/SemaDecl.cpp
@@ -8591,6 +8591,129 @@
   }
 }
 
+// Result type returned by the functor checking a struct field.
+enum class CheckFieldResult {
+  Valid,   // The field is valid
+  Recurse, // The field is a struct which needs to be checked recursively
+  Invalid, // The field is invalid
+};
+
+// Check whether struct or array type contains invalid fields or elements by
+// recursively visiting fields of the structs with the functor CheckFieldType.
+// Returns true if the type is valid. CheckFieldType returns Valid if the field
+// is valid, Invalid if the field is invalid, and Recurse if the field is a
+// struct which needs further check. DiagInvalidParam is called to emit the
+// diagnostic for an invalid field. ValidTypes contains known valid types.
+static bool
+checkStructOrArrayType(Sema &S, QualType PT,
+                       llvm::SmallPtrSetImpl<const Type *> &ValidTypes,
+                       std::function<CheckFieldResult(QualType)> CheckFieldType,
+                       std::function<void(QualType)> DiagInvalidParam) {
+  // Track nested structs we will inspect
+  SmallVector<const Decl *, 4> VisitStack;
+
+  // Track where we are in the nested structs. Items will migrate from
+  // VisitStack to HistoryStack as we do the DFS for bad field.
+  SmallVector<const FieldDecl *, 4> HistoryStack;
+  HistoryStack.push_back(nullptr);
+
+  // At this point we already handled everything except of a RecordType or
+  // an ArrayType of a RecordType.
+  assert((PT->isArrayType() || PT->isRecordType()) && "Unexpected type.");
+  const RecordType *RecTy =
+      PT->getPointeeOrArrayElementType()->getAs<RecordType>();
+  const RecordDecl *OrigRecDecl = RecTy->getDecl();
+
+  VisitStack.push_back(RecTy->getDecl());
+  assert(VisitStack.back() && "First decl null?");
+
+  do {
+    const Decl *Next = VisitStack.pop_back_val();
+    if (!Next) {
+      // HistoryStack is empty if a struct has no fields or base.
+      if (HistoryStack.empty())
+        continue;
+      // Found a marker, we have gone up a level
+      if (const FieldDecl *Hist = HistoryStack.pop_back_val())
+        ValidTypes.insert(Hist->getType().getTypePtr());
+
+      continue;
+    }
+
+    // Adds everything except the original parameter declaration (which is not a
+    // field itself) to the history stack.
+    const RecordDecl *RD;
+    if (const FieldDecl *Field = dyn_cast<FieldDecl>(Next)) {
+      HistoryStack.push_back(Field);
+
+      QualType FieldTy = Field->getType();
+      // Other field types (known to be valid or invalid) are handled while we
+      // walk around RecordDecl::fields().
+      assert((FieldTy->isArrayType() || FieldTy->isRecordType()) &&
+             "Unexpected type.");
+      const Type *FieldRecTy = FieldTy->getPointeeOrArrayElementType();
+
+      RD = FieldRecTy->castAs<RecordType>()->getDecl();
+    } else {
+      RD = cast<RecordDecl>(Next);
+    }
+
+    RD = RD->getDefinition();
+    // A struct return type can be undefined.
+    if (!RD)
+      continue;
+
+    // Add a null marker so we know when we've gone back up a level
+    VisitStack.push_back(nullptr);
+
+    if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD))
+      for (auto Base : CXXRD->bases()) {
+        // Skip non-record type, e.g. TemplateSpecializationType
+        if (auto *RT = Base.getType().getCanonicalType()->getAs<RecordType>()) {
+          VisitStack.push_back(RT->getDecl());
+        }
+      }
+
+    for (const auto *FD : RD->fields()) {
+      QualType QT = FD->getType();
+
+      if (ValidTypes.count(QT.getTypePtr()))
+        continue;
+
+      auto Result = CheckFieldType(QT);
+
+      if (Result == CheckFieldResult::Valid)
+        continue;
+
+      if (Result == CheckFieldResult::Recurse) {
+        VisitStack.push_back(FD);
+        continue;
+      }
+
+      assert(Result == CheckFieldResult::Invalid);
+      DiagInvalidParam(QT);
+      S.Diag(OrigRecDecl->getLocation(), diag::note_within_field_of_type)
+          << OrigRecDecl->getDeclName() << S.getLangOpts().CPlusPlus;
+
+      // We have an error, now let's go back up through history and show where
+      // the offending field came from
+      for (ArrayRef<const FieldDecl *>::const_iterator
+               I = HistoryStack.begin() + 1,
+               E = HistoryStack.end();
+           I != E; ++I) {
+        const FieldDecl *OuterField = *I;
+        S.Diag(OuterField->getLocation(), diag::note_within_field_of_type)
+            << OuterField->getType() << S.getLangOpts().CPlusPlus;
+      }
+
+      S.Diag(FD->getLocation(), diag::note_illegal_field_declared_here)
+          << QT->isPointerType() << QT;
+      return false;
+    }
+  } while (!VisitStack.empty());
+  return true;
+}
+
 enum OpenCLParamType {
   ValidKernelParam,
   PtrPtrKernelParam,
@@ -8752,106 +8875,109 @@
     break;
   }
 
-  // Track nested structs we will inspect
-  SmallVector<const Decl *, 4> VisitStack;
-
-  // Track where we are in the nested structs. Items will migrate from
-  // VisitStack to HistoryStack as we do the DFS for bad field.
-  SmallVector<const FieldDecl *, 4> HistoryStack;
-  HistoryStack.push_back(nullptr);
-
-  // At this point we already handled everything except of a RecordType or
-  // an ArrayType of a RecordType.
-  assert((PT->isArrayType() || PT->isRecordType()) && "Unexpected type.");
-  const RecordType *RecTy =
-      PT->getPointeeOrArrayElementType()->getAs<RecordType>();
-  const RecordDecl *OrigRecDecl = RecTy->getDecl();
-
-  VisitStack.push_back(RecTy->getDecl());
-  assert(VisitStack.back() && "First decl null?");
-
-  do {
-    const Decl *Next = VisitStack.pop_back_val();
-    if (!Next) {
-      assert(!HistoryStack.empty());
-      // Found a marker, we have gone up a level
-      if (const FieldDecl *Hist = HistoryStack.pop_back_val())
-        ValidTypes.insert(Hist->getType().getTypePtr());
-
-      continue;
-    }
-
-    // Adds everything except the original parameter declaration (which is not a
-    // field itself) to the history stack.
-    const RecordDecl *RD;
-    if (const FieldDecl *Field = dyn_cast<FieldDecl>(Next)) {
-      HistoryStack.push_back(Field);
-
-      QualType FieldTy = Field->getType();
-      // Other field types (known to be valid or invalid) are handled while we
-      // walk around RecordDecl::fields().
-      assert((FieldTy->isArrayType() || FieldTy->isRecordType()) &&
-             "Unexpected type.");
-      const Type *FieldRecTy = FieldTy->getPointeeOrArrayElementType();
-
-      RD = FieldRecTy->castAs<RecordType>()->getDecl();
+  auto DiagInvalidParam = [&](QualType ParamTy) {
+    OpenCLParamType ParamType = getOpenCLKernelParameterType(S, ParamTy);
+    // OpenCL v1.2 s6.9.p:
+    // Arguments to kernel functions that are declared to be a struct or union
+    // do not allow OpenCL objects to be passed as elements of the struct or
+    // union.
+    if (ParamType == PtrKernelParam || ParamType == PtrPtrKernelParam ||
+        ParamType == InvalidAddrSpacePtrKernelParam) {
+      S.Diag(Param->getLocation(), diag::err_record_with_pointers_kernel_param)
+          << PT->isUnionType() << PT;
     } else {
-      RD = cast<RecordDecl>(Next);
+      S.Diag(Param->getLocation(), diag::err_bad_kernel_param_type) << PT;
     }
+  };
+  auto CheckFieldType = [&](QualType QT) {
+    OpenCLParamType ParamType = getOpenCLKernelParameterType(S, QT);
+    if (ParamType == ValidKernelParam)
+      return CheckFieldResult::Valid;
 
-    // Add a null marker so we know when we've gone back up a level
-    VisitStack.push_back(nullptr);
+    if (ParamType == RecordKernelParam) {
+      return CheckFieldResult::Recurse;
+    }
 
-    for (const auto *FD : RD->fields()) {
-      QualType QT = FD->getType();
+    return CheckFieldResult::Invalid;
+  };
+  if (!checkStructOrArrayType(S, PT, ValidTypes, CheckFieldType,
+                              DiagInvalidParam))
+    D.setInvalidType();
+}
 
-      if (ValidTypes.count(QT.getTypePtr()))
-        continue;
+// Check whether HIP host function has parameter or return types that are half
+// precision types or struct types containing half precision types and
+// diagnose them. This is because gcc and clang do not have a consistent ABI
+// for half precision types for now.
+// TODO: disable the diagnostics once gcc and clang have a consistent ABI
+// about half precision types.
+static void checkHIPFunctionParameters(Sema &S, FunctionDecl *FD) {
+  if (S.getLangOpts().HIPAllowHalfArg || FD->hasAttr<CUDADeviceAttr>() ||
+      FD->hasAttr<CUDAGlobalAttr>())
+    return;
 
-      OpenCLParamType ParamType = getOpenCLKernelParameterType(S, QT);
-      if (ParamType == ValidKernelParam)
-        continue;
+  auto IsInvalidType = [](QualType T) {
+    if (T->isArrayType())
+      T = QualType(T->getPointeeOrArrayElementType(), 0);
+    if (T->isVectorType())
+      T = T->getAs<VectorType>()->getElementType();
+    return T->isFloat16Type() || T->isHalfType();
+  };
 
-      if (ParamType == RecordKernelParam) {
-        VisitStack.push_back(FD);
-        continue;
-      }
+  // Check field type.
+  auto CheckFieldType = [&](QualType FT) {
+    if (IsInvalidType(FT)) {
+      return CheckFieldResult::Invalid;
+    }
+    if (FT->isRecordType())
+      return CheckFieldResult::Recurse;
+    return CheckFieldResult::Valid;
+  };
 
-      // OpenCL v1.2 s6.9.p:
-      // Arguments to kernel functions that are declared to be a struct or union
-      // do not allow OpenCL objects to be passed as elements of the struct or
-      // union.
-      if (ParamType == PtrKernelParam || ParamType == PtrPtrKernelParam ||
-          ParamType == InvalidAddrSpacePtrKernelParam) {
-        S.Diag(Param->getLocation(),
-               diag::err_record_with_pointers_kernel_param)
-          << PT->isUnionType()
-          << PT;
-      } else {
-        S.Diag(Param->getLocation(), diag::err_bad_kernel_param_type) << PT;
-      }
+  // Cache for known valid types to avoid repeated check.
+  llvm::SmallPtrSet<const Type *, 16> ValidTypes;
 
-      S.Diag(OrigRecDecl->getLocation(), diag::note_within_field_of_type)
-          << OrigRecDecl->getDeclName();
+  // Information about parameter or return types to be checked.
+  struct TypeCheckInfo {
+    QualType Ty;
+    SourceLocation Loc;
+    bool IsRet; // Whether it is return type
+    TypeCheckInfo(QualType T, SourceLocation L, bool _IsRet)
+        : Ty(T), Loc(L), IsRet(_IsRet) {}
+  };
+  llvm::SmallVector<TypeCheckInfo, 8> TCInfo;
+  for (auto ParmVar : FD->parameters())
+    TCInfo.emplace_back(
+        TypeCheckInfo{ParmVar->getType(), ParmVar->getLocation(), false});
+  TCInfo.emplace_back(
+      TypeCheckInfo{FD->getReturnType(), FD->getLocation(), true});
+
+  for (auto Info : TCInfo) {
+    QualType T = Info.Ty;
+
+    // Diagnose invalid parameter type for the current parameter.
+    auto DiagInvalidType = [&](QualType Ty) {
+      unsigned DiagID = S.getDiagnostics().getCustomDiagID(
+          DiagnosticsEngine::Error,
+          "Invalid function %select{parameter|return}0 type: %1");
+      S.Diag(Info.Loc, DiagID) << Info.IsRet << T;
+    };
+
+    if (IsInvalidType(T)) {
+      DiagInvalidType(T);
+      FD->setInvalidDecl();
+      continue;
+    }
 
-      // We have an error, now let's go back up through history and show where
-      // the offending field came from
-      for (ArrayRef<const FieldDecl *>::const_iterator
-               I = HistoryStack.begin() + 1,
-               E = HistoryStack.end();
-           I != E; ++I) {
-        const FieldDecl *OuterField = *I;
-        S.Diag(OuterField->getLocation(), diag::note_within_field_of_type)
-          << OuterField->getType();
-      }
+    if (!T->isRecordType() && !T->isArrayType())
+      continue;
 
-      S.Diag(FD->getLocation(), diag::note_illegal_field_declared_here)
-        << QT->isPointerType()
-        << QT;
-      D.setInvalidType();
+    if (!checkStructOrArrayType(S, T, ValidTypes, CheckFieldType,
+                                DiagInvalidType)) {
+      FD->setInvalidDecl();
       return;
     }
-  } while (!VisitStack.empty());
+  }
 }
 
 /// Find the DeclContext in which a tag is implicitly declared if we see an
@@ -10866,6 +10992,10 @@
   if (LangOpts.OpenMP)
     ActOnFinishedFunctionDefinitionInOpenMPAssumeScope(NewFD);
 
+  // Check HIP host function parameter types.
+  if (getLangOpts().HIP)
+    checkHIPFunctionParameters(*this, NewFD);
+
   // Semantic checking for this function declaration (in isolation).
 
   if (getLangOpts().CPlusPlus) {
Index: clang/lib/Headers/__clang_hip_cmath.h
===================================================================
--- clang/lib/Headers/__clang_hip_cmath.h
+++ clang/lib/Headers/__clang_hip_cmath.h
@@ -225,7 +225,9 @@
 
 template <class _Tp> struct __numeric_type {
   static void __test(...);
-  static _Float16 __test(_Float16);
+  // _Float16 is not allowed as a host function argument until the ABI
+  // compatibility issue with gcc is resolved.
+  static __device__ _Float16 __test(_Float16);
   static float __test(float);
   static double __test(char);
   static double __test(int);
Index: clang/include/clang/Driver/Options.td
===================================================================
--- clang/include/clang/Driver/Options.td
+++ clang/include/clang/Driver/Options.td
@@ -921,6 +921,12 @@
   LangOpts<"HIPUseNewLaunchAPI">, DefaultFalse,
   PosFlag<SetTrue, [CC1Option], "Use">, NegFlag<SetFalse, [], "Don't use">,
   BothFlags<[], " new kernel launching API for HIP">>;
+defm hip_allow_half_arg : BoolFOption<"hip-allow-half-arg",
+  LangOpts<"HIPAllowHalfArg">, DefaultTrue,
+  PosFlag<SetTrue, [CC1Option], "Allow">, NegFlag<SetFalse, [], "Don't allow">,
+  BothFlags<[], " half precision types or aggregate types containing half "
+  "precision types as host function parameter type or return type">>,
+  ShouldParseIf<hip.KeyPath>;
 defm gpu_allow_device_init : BoolFOption<"gpu-allow-device-init",
   LangOpts<"GPUAllowDeviceInit">, DefaultFalse,
   PosFlag<SetTrue, [CC1Option], "Allow">, NegFlag<SetFalse, [], "Don't allow">,
Index: clang/include/clang/Basic/LangOptions.def
===================================================================
--- clang/include/clang/Basic/LangOptions.def
+++ clang/include/clang/Basic/LangOptions.def
@@ -253,6 +253,9 @@
 ENUM_LANGOPT(SYCLVersion  , SYCLMajorVersion, 1, SYCL_None, "Version of the SYCL standard used")
 
 LANGOPT(HIPUseNewLaunchAPI, 1, 0, "Use new kernel launching API for HIP")
+LANGOPT(HIPAllowHalfArg, 1, 1, "Allow half precision types or aggregate types "
+                               "containing half precision types as host "
+                               "function parameter and return types for HIP")
 
 LANGOPT(SizedDeallocation , 1, 0, "sized deallocation")
 LANGOPT(AlignedAllocation , 1, 0, "aligned allocation")
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -9858,7 +9858,7 @@
 def err_record_with_pointers_kernel_param : Error<
   "%select{struct|union}0 kernel parameters may not contain pointers">;
 def note_within_field_of_type : Note<
-  "within field of type %0 declared here">;
+  "within field %select{|or base class }1of type %0 declared here">;
 def note_illegal_field_declared_here : Note<
   "field of illegal %select{type|pointer type}0 %1 declared here">;
 def err_opencl_type_struct_or_union_field : Error<
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
  • [PATCH] D98143: [HIP] Diagnose a... Yaxun Liu via Phabricator via cfe-commits

Reply via email to