tra updated this revision to Diff 75652.
tra marked an inline comment as done.
tra added a comment.

Addressed remaining nits.


https://reviews.llvm.org/D25809

Files:
  include/clang/Basic/DiagnosticSemaKinds.td
  include/clang/Sema/Sema.h
  lib/Sema/SemaCUDA.cpp
  lib/Sema/SemaDecl.cpp
  lib/Sema/SemaOverload.cpp
  lib/Sema/SemaTemplate.cpp
  test/CodeGenCUDA/launch-bounds.cu
  test/SemaCUDA/function-overload.cu
  test/SemaCUDA/function-template-overload.cu
  test/SemaCUDA/target_attr_inheritance.cu

Index: test/SemaCUDA/target_attr_inheritance.cu
===================================================================
--- test/SemaCUDA/target_attr_inheritance.cu
+++ /dev/null
@@ -1,29 +0,0 @@
-// Verifies correct inheritance of target attributes during template
-// instantiation and specialization.
-
-// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fsyntax-only -verify %s
-// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fsyntax-only -fcuda-is-device -verify %s
-
-#include "Inputs/cuda.h"
-
-// Function must inherit target attributes during instantiation, but not during
-// specialization.
-template <typename T> __host__ __device__ T function_template(const T &a);
-
-// Specialized functions have their own attributes.
-// expected-note@+1 {{candidate function not viable: call to __host__ function from __device__ function}}
-template <> __host__ float function_template<float>(const float &from);
-
-// expected-note@+1 {{candidate function not viable: call to __device__ function from __host__ function}}
-template <> __device__ double function_template<double>(const double &from);
-
-__host__ void hf() {
-  function_template<float>(1.0f); // OK. Specialization is __host__.
-  function_template<double>(2.0); // expected-error {{no matching function for call to 'function_template'}}
-  function_template(1);           // OK. Instantiated function template is HD.
-}
-__device__ void df() {
-  function_template<float>(3.0f); // expected-error {{no matching function for call to 'function_template'}}
-  function_template<double>(4.0); // OK. Specialization is __device__.
-  function_template(1);           // OK. Instantiated function template is HD.
-}
Index: test/SemaCUDA/function-template-overload.cu
===================================================================
--- /dev/null
+++ test/SemaCUDA/function-template-overload.cu
@@ -0,0 +1,82 @@
+// RUN: %clang_cc1 -std=c++11 -triple x86_64-unknown-linux-gnu -fsyntax-only -verify %s
+// RUN: %clang_cc1 -std=c++11 -triple nvptx64-nvidia-cuda -fsyntax-only -fcuda-is-device -verify %s
+
+#include "Inputs/cuda.h"
+
+struct HType {}; // expected-note-re 6 {{candidate constructor {{.*}} not viable: no known conversion from 'DType'}}
+struct DType {}; // expected-note-re 6 {{candidate constructor {{.*}} not viable: no known conversion from 'HType'}}
+struct HDType {};
+
+template <typename T> __host__ HType overload_h_d(T a) { return HType(); }
+// expected-note@-1 2 {{candidate template ignored: could not match 'HType' against 'DType'}}
+// expected-note@-2 2 {{candidate template ignored: target attributes do not match}}
+template <typename T> __device__ DType overload_h_d(T a) { return DType(); }
+// expected-note@-1 2 {{candidate template ignored: could not match 'DType' against 'HType'}}
+// expected-note@-2 2 {{candidate template ignored: target attributes do not match}}
+
+// Check explicit instantiation.
+template  __device__ __host__ DType overload_h_d(int a); // There's no HD template...
+// expected-error@-1 {{explicit instantiation of 'overload_h_d' does not refer to a function template, variable template, member function, member class, or static data member}}
+template  __device__ __host__ HType overload_h_d(int a); // There's no HD template...
+// expected-error@-1 {{explicit instantiation of 'overload_h_d' does not refer to a function template, variable template, member function, member class, or static data member}}
+template  __device__ DType overload_h_d(int a); // OK. instantiates D
+template  __host__ HType overload_h_d(int a); // OK. instantiates H
+
+// Check explicit specialization.
+template  <> __device__ __host__ DType overload_h_d(long a); // There's no HD template...
+// expected-error@-1 {{no function template matches function template specialization 'overload_h_d'}}
+template  <> __device__ __host__ HType overload_h_d(long a); // There's no HD template...
+// expected-error@-1 {{no function template matches function template specialization 'overload_h_d'}}
+template  <> __device__ DType overload_h_d(long a); // OK. specializes D
+template  <> __host__ HType overload_h_d(long a); // OK. specializes H
+
+
+// Can't overload HD template with H or D template, though functions are OK.
+template <typename T> __host__ __device__ HDType overload_hd(T a) { return HDType(); }
+// expected-note@-1 {{previous declaration is here}}
+// expected-note@-2 2 {{candidate template ignored: could not match 'HDType' against 'HType'}}
+template <typename T> __device__ HDType overload_hd(T a);
+// expected-error@-1 {{__device__ function 'overload_hd' cannot overload __host__ __device__ function 'overload_hd'}}
+__device__ HDType overload_hd(int a); // OK.
+
+// Verify that target attributes are taken into account when we
+// explicitly specialize or instantiate function templates.
+template <> __host__ HType overload_hd(int a);
+// expected-error@-1 {{no function template matches function template specialization 'overload_hd'}}
+template __host__ HType overload_hd(long a);
+// expected-error@-1 {{explicit instantiation of 'overload_hd' does not refer to a function template, variable template, member function, member class, or static data member}}
+__host__ HType overload_hd(int a); // OK
+
+template <typename T> __host__ T overload_h(T a); // expected-note {{previous declaration is here}}
+template <typename T> __host__ __device__ T overload_h(T a);
+// expected-error@-1 {{__host__ __device__ function 'overload_h' cannot overload __host__ function 'overload_h'}}
+template <typename T> __device__ T overload_h(T a); // OK. D can overload H.
+
+template <typename T> __host__ HType overload_h_d2(T a) { return HType(); }
+template <typename T> __host__ __device__ HDType overload_h_d2(T a) { return HDType(); }
+template <typename T1, typename T2 = int> __device__ DType overload_h_d2(T1 a) { T1 x; T2 y; return DType(); }
+
+__host__ void hf() {
+  overload_hd(13);
+
+  HType h = overload_h_d(10);
+  HType h2i = overload_h_d2<int>(11);
+  HType h2ii = overload_h_d2<int>(12);
+
+  // These should be implicitly instantiated from __host__ template returning HType.
+  DType d = overload_h_d(20); // expected-error {{no viable conversion from 'HType' to 'DType'}}
+  DType d2i = overload_h_d2<int>(21); // expected-error {{no viable conversion from 'HType' to 'DType'}}
+  DType d2ii = overload_h_d2<int>(22); // expected-error {{no viable conversion from 'HType' to 'DType'}}
+}
+__device__ void df() {
+  overload_hd(23);
+
+  // These should be implicitly instantiated from __device__ template returning DType.
+  HType h = overload_h_d(10); // expected-error {{no viable conversion from 'DType' to 'HType'}}
+  HType h2i = overload_h_d2<int>(11); // expected-error {{no viable conversion from 'DType' to 'HType'}}
+  HType h2ii = overload_h_d2<int>(12); // expected-error {{no viable conversion from 'DType' to 'HType'}}
+
+  DType d = overload_h_d(20);
+  DType d2i = overload_h_d2<int>(21);
+  DType d2ii = overload_h_d2<int>(22);
+}
Index: test/SemaCUDA/function-overload.cu
===================================================================
--- test/SemaCUDA/function-overload.cu
+++ test/SemaCUDA/function-overload.cu
@@ -40,21 +40,21 @@
 __device__ DeviceReturnTy dh() { return DeviceReturnTy(); }
 
 // H/HD and D/HD are not allowed.
-__host__ __device__ int hdh() { return 0; } // expected-note {{previous definition is here}}
-__host__ int hdh() { return 0; }            // expected-error {{redefinition of 'hdh'}}
+__host__ __device__ int hdh() { return 0; } // expected-note {{previous declaration is here}}
+__host__ int hdh() { return 0; }
+// expected-error@-1 {{__host__ function 'hdh' cannot overload __host__ __device__ function 'hdh'}}
 
-__host__ int hhd() { return 0; }            // expected-note {{previous definition is here}}
-__host__ __device__ int hhd() { return 0; } // expected-error {{redefinition of 'hhd'}}
-// expected-warning@-1 {{attribute declaration must precede definition}}
-// expected-note@-3 {{previous definition is here}}
+__host__ int hhd() { return 0; }            // expected-note {{previous declaration is here}}
+__host__ __device__ int hhd() { return 0; }
+// expected-error@-1 {{__host__ __device__ function 'hhd' cannot overload __host__ function 'hhd'}}
 
-__host__ __device__ int hdd() { return 0; } // expected-note {{previous definition is here}}
-__device__ int hdd() { return 0; }          // expected-error {{redefinition of 'hdd'}}
+__host__ __device__ int hdd() { return 0; } // expected-note {{previous declaration is here}}
+__device__ int hdd() { return 0; }
+// expected-error@-1 {{__device__ function 'hdd' cannot overload __host__ __device__ function 'hdd'}}
 
-__device__ int dhd() { return 0; }          // expected-note {{previous definition is here}}
-__host__ __device__ int dhd() { return 0; } // expected-error {{redefinition of 'dhd'}}
-// expected-warning@-1 {{attribute declaration must precede definition}}
-// expected-note@-3 {{previous definition is here}}
+__device__ int dhd() { return 0; }          // expected-note {{previous declaration is here}}
+__host__ __device__ int dhd() { return 0; }
+// expected-error@-1 {{__host__ __device__ function 'dhd' cannot overload __device__ function 'dhd'}}
 
 // Same tests for extern "C" functions.
 extern "C" __host__ int chh() { return 0; } // expected-note {{previous definition is here}}
@@ -65,13 +65,13 @@
 extern "C" __host__ HostReturnTy cdh() { return HostReturnTy(); }
 
 // H/HD and D/HD overloading is not allowed.
-extern "C" __host__ __device__ int chhd1() { return 0; } // expected-note {{previous definition is here}}
-extern "C" __host__ int chhd1() { return 0; }            // expected-error {{redefinition of 'chhd1'}}
+extern "C" __host__ __device__ int chhd1() { return 0; } // expected-note {{previous declaration is here}}
+extern "C" __host__ int chhd1() { return 0; }
+// expected-error@-1 {{__host__ function 'chhd1' cannot overload __host__ __device__ function 'chhd1'}}
 
-extern "C" __host__ int chhd2() { return 0; }            // expected-note {{previous definition is here}}
-extern "C" __host__ __device__ int chhd2() { return 0; } // expected-error {{redefinition of 'chhd2'}}
-// expected-warning@-1 {{attribute declaration must precede definition}}
-// expected-note@-3 {{previous definition is here}}
+extern "C" __host__ int chhd2() { return 0; } // expected-note {{previous declaration is here}}
+extern "C" __host__ __device__ int chhd2() { return 0; }
+// expected-error@-1 {{__host__ __device__ function 'chhd2' cannot overload __host__ function 'chhd2'}}
 
 // Helper functions to verify calling restrictions.
 __device__ DeviceReturnTy d() { return DeviceReturnTy(); }
@@ -250,33 +250,39 @@
 
 struct m_hhd {
   __host__ void operator delete(void *ptr) {} // expected-note {{previous declaration is here}}
-  __host__ __device__ void operator delete(void *ptr) {} // expected-error {{class member cannot be redeclared}}
+  __host__ __device__ void operator delete(void *ptr) {}
+  // expected-error@-1 {{__host__ __device__ function 'operator delete' cannot overload __host__ function 'operator delete'}}
 };
 
 struct m_hdh {
   __host__ __device__ void operator delete(void *ptr) {} // expected-note {{previous declaration is here}}
-  __host__ void operator delete(void *ptr) {} // expected-error {{class member cannot be redeclared}}
+  __host__ void operator delete(void *ptr) {}
+  // expected-error@-1 {{__host__ function 'operator delete' cannot overload __host__ __device__ function 'operator delete'}}
 };
 
 struct m_dhd {
   __device__ void operator delete(void *ptr) {} // expected-note {{previous declaration is here}}
-  __host__ __device__ void operator delete(void *ptr) {} // expected-error {{class member cannot be redeclared}}
+  __host__ __device__ void operator delete(void *ptr) {}
+  // expected-error@-1 {{__host__ __device__ function 'operator delete' cannot overload __device__ function 'operator delete'}}
 };
 
 struct m_hdd {
   __host__ __device__ void operator delete(void *ptr) {} // expected-note {{previous declaration is here}}
-  __device__ void operator delete(void *ptr) {} // expected-error {{class member cannot be redeclared}}
+  __device__ void operator delete(void *ptr) {}
+  // expected-error@-1 {{__device__ function 'operator delete' cannot overload __host__ __device__ function 'operator delete'}}
 };
 
 // __global__ functions can't be overloaded based on attribute
 // difference.
 struct G {
-  friend void friend_of_g(G &arg);
+  friend void friend_of_g(G &arg); // expected-note {{previous declaration is here}}
 private:
-  int x;
+  int x; // expected-note {{declared private here}}
 };
-__global__ void friend_of_g(G &arg) { int x = arg.x; } // expected-note {{previous definition is here}}
-void friend_of_g(G &arg) { int x = arg.x; } // expected-error {{redefinition of 'friend_of_g'}}
+__global__ void friend_of_g(G &arg) { int x = arg.x; }
+// expected-error@-1 {{__global__ function 'friend_of_g' cannot overload __host__ function 'friend_of_g'}}
+// expected-error@-2 {{'x' is a private member of 'G'}}
+void friend_of_g(G &arg) { int x = arg.x; }
 
 // HD functions are sometimes allowed to call H or D functions -- this
 // is an artifact of the source-to-source splitting performed by nvcc
Index: test/CodeGenCUDA/launch-bounds.cu
===================================================================
--- test/CodeGenCUDA/launch-bounds.cu
+++ test/CodeGenCUDA/launch-bounds.cu
@@ -36,16 +36,16 @@
 {
 }
 
-template void Kernel3<MAX_THREADS_PER_BLOCK>();
+template __global__ void Kernel3<MAX_THREADS_PER_BLOCK>();
 // CHECK: !{{[0-9]+}} = !{void ()* @{{.*}}Kernel3{{.*}}, !"maxntidx", i32 256}
 
 template <int max_threads_per_block, int min_blocks_per_mp>
 __global__ void
 __launch_bounds__(max_threads_per_block, min_blocks_per_mp)
 Kernel4()
 {
 }
-template void Kernel4<MAX_THREADS_PER_BLOCK, MIN_BLOCKS_PER_MP>();
+template __global__ void Kernel4<MAX_THREADS_PER_BLOCK, MIN_BLOCKS_PER_MP>();
 
 // CHECK: !{{[0-9]+}} = !{void ()* @{{.*}}Kernel4{{.*}}, !"maxntidx", i32 256}
 // CHECK: !{{[0-9]+}} = !{void ()* @{{.*}}Kernel4{{.*}}, !"minctasm", i32 2}
@@ -58,7 +58,7 @@
 Kernel5()
 {
 }
-template void Kernel5<MAX_THREADS_PER_BLOCK, MIN_BLOCKS_PER_MP>();
+template __global__ void Kernel5<MAX_THREADS_PER_BLOCK, MIN_BLOCKS_PER_MP>();
 
 // CHECK: !{{[0-9]+}} = !{void ()* @{{.*}}Kernel5{{.*}}, !"maxntidx", i32 356}
 // CHECK: !{{[0-9]+}} = !{void ()* @{{.*}}Kernel5{{.*}}, !"minctasm", i32 258}
Index: lib/Sema/SemaTemplate.cpp
===================================================================
--- lib/Sema/SemaTemplate.cpp
+++ lib/Sema/SemaTemplate.cpp
@@ -7041,6 +7041,19 @@
         continue;
       }
 
+      // Target attributes are part of the function signature during CUDA
+      // compilation, so the deduced template must also have a matching CUDA
+      // target. Given that regular template deduction does not take
+      // target attributes into account, we perform the target match check
+      // here and reject candidates that have a different target.
+      if (LangOpts.CUDA &&
+          IdentifyCUDATarget(Specialization) != IdentifyCUDATarget(FD)) {
+        FailedCandidates.addCandidate().set(
+            I.getPair(), FunTmpl->getTemplatedDecl(),
+            MakeDeductionFailureInfo(Context, TDK_CUDATargetMismatch, Info));
+        continue;
+      }
+
       // Record this candidate.
       if (ExplicitTemplateArgs)
         ConvertedTemplateArgs[Specialization] = std::move(Args);
@@ -8066,6 +8079,7 @@
   //  instantiated from the member definition associated with its class
   //  template.
   UnresolvedSet<8> Matches;
+  AttributeList *Attr = D.getDeclSpec().getAttributes().getList();
   TemplateSpecCandidateSet FailedCandidates(D.getIdentifierLoc());
   for (LookupResult::iterator P = Previous.begin(), PEnd = Previous.end();
        P != PEnd; ++P) {
@@ -8102,6 +8116,26 @@
       continue;
     }
 
+    // Target attributes are part of the function signature during CUDA
+    // compilation, so the deduced template must also have a matching CUDA
+    // target. Given that regular template deduction does not take them
+    // into account, we perform the target match check here and reject
+    // candidates that have a different target.
+    if (LangOpts.CUDA) {
+      CUDAFunctionTarget DeclaratorTarget = IdentifyCUDATarget(Attr);
+      // We need to adjust target when HD is forced by
+      // #pragma clang force_cuda_host_device
+      if (ForceCUDAHostDeviceDepth > 0 &&
+          (DeclaratorTarget == CFT_Device || DeclaratorTarget == CFT_Host))
+        DeclaratorTarget = CFT_HostDevice;
+      if (IdentifyCUDATarget(Specialization) != DeclaratorTarget) {
+        FailedCandidates.addCandidate().set(
+            P.getPair(), FunTmpl->getTemplatedDecl(),
+            MakeDeductionFailureInfo(Context, TDK_CUDATargetMismatch, Info));
+        continue;
+      }
+    }
+
     Matches.addDecl(Specialization, P.getAccess());
   }
 
@@ -8172,7 +8206,6 @@
   }
 
   Specialization->setTemplateSpecializationKind(TSK, D.getIdentifierLoc());
-  AttributeList *Attr = D.getDeclSpec().getAttributes().getList();
   if (Attr)
     ProcessDeclAttributeList(S, Specialization, Attr);
 
Index: lib/Sema/SemaOverload.cpp
===================================================================
--- lib/Sema/SemaOverload.cpp
+++ lib/Sema/SemaOverload.cpp
@@ -580,6 +580,7 @@
   case Sema::TDK_TooManyArguments:
   case Sema::TDK_TooFewArguments:
   case Sema::TDK_MiscellaneousDeductionFailure:
+  case Sema::TDK_CUDATargetMismatch:
     Result.Data = nullptr;
     break;
 
@@ -647,6 +648,7 @@
   case Sema::TDK_TooFewArguments:
   case Sema::TDK_InvalidExplicitArguments:
   case Sema::TDK_FailedOverloadResolution:
+  case Sema::TDK_CUDATargetMismatch:
     break;
 
   case Sema::TDK_Inconsistent:
@@ -689,6 +691,7 @@
   case Sema::TDK_DeducedMismatch:
   case Sema::TDK_NonDeducedMismatch:
   case Sema::TDK_FailedOverloadResolution:
+  case Sema::TDK_CUDATargetMismatch:
     return TemplateParameter();
 
   case Sema::TDK_Incomplete:
@@ -720,6 +723,7 @@
   case Sema::TDK_Underqualified:
   case Sema::TDK_NonDeducedMismatch:
   case Sema::TDK_FailedOverloadResolution:
+  case Sema::TDK_CUDATargetMismatch:
     return nullptr;
 
   case Sema::TDK_DeducedMismatch:
@@ -747,6 +751,7 @@
   case Sema::TDK_InvalidExplicitArguments:
   case Sema::TDK_SubstitutionFailure:
   case Sema::TDK_FailedOverloadResolution:
+  case Sema::TDK_CUDATargetMismatch:
     return nullptr;
 
   case Sema::TDK_Inconsistent:
@@ -774,6 +779,7 @@
   case Sema::TDK_InvalidExplicitArguments:
   case Sema::TDK_SubstitutionFailure:
   case Sema::TDK_FailedOverloadResolution:
+  case Sema::TDK_CUDATargetMismatch:
     return nullptr;
 
   case Sema::TDK_Inconsistent:
@@ -1138,20 +1144,11 @@
 
     CUDAFunctionTarget NewTarget = IdentifyCUDATarget(New),
                        OldTarget = IdentifyCUDATarget(Old);
-    if (NewTarget == CFT_InvalidTarget || NewTarget == CFT_Global)
+    if (NewTarget == CFT_InvalidTarget)
       return false;
 
     assert((OldTarget != CFT_InvalidTarget) && "Unexpected invalid target.");
 
-    // Don't allow HD and global functions to overload other functions with the
-    // same signature.  We allow overloading based on CUDA attributes so that
-    // functions can have different implementations on the host and device, but
-    // HD/global functions "exist" in some sense on both the host and device, so
-    // should have the same implementation on both sides.
-    if ((NewTarget == CFT_HostDevice) || (OldTarget == CFT_HostDevice) ||
-        (NewTarget == CFT_Global) || (OldTarget == CFT_Global))
-      return false;
-
     // Allow overloading of functions with same signature and different CUDA
     // target attributes.
     return NewTarget != OldTarget;
@@ -9711,6 +9708,10 @@
     S.Diag(Templated->getLocation(), diag::note_ovl_candidate_bad_deduction);
     MaybeEmitInheritedConstructorNote(S, Found);
     return;
+  case Sema::TDK_CUDATargetMismatch:
+    S.Diag(Templated->getLocation(),
+           diag::note_cuda_ovl_candidate_target_mismatch);
+    return;
   }
 }
 
@@ -9967,6 +9968,7 @@
   case Sema::TDK_DeducedMismatch:
   case Sema::TDK_NonDeducedMismatch:
   case Sema::TDK_MiscellaneousDeductionFailure:
+  case Sema::TDK_CUDATargetMismatch:
     return 3;
 
   case Sema::TDK_InstantiationDepth:
Index: lib/Sema/SemaDecl.cpp
===================================================================
--- lib/Sema/SemaDecl.cpp
+++ lib/Sema/SemaDecl.cpp
@@ -8996,6 +8996,9 @@
                !R->isObjCObjectPointerType())
         Diag(NewFD->getLocation(), diag::warn_return_value_udt) << NewFD << R;
     }
+
+    if (!Redeclaration && LangOpts.CUDA)
+      checkCUDATargetOverload(NewFD, Previous);
   }
   return Redeclaration;
 }
Index: lib/Sema/SemaCUDA.cpp
===================================================================
--- lib/Sema/SemaCUDA.cpp
+++ lib/Sema/SemaCUDA.cpp
@@ -54,6 +54,45 @@
                        /*IsExecConfig=*/true);
 }
 
+Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const AttributeList *Attr) {
+  bool HasHostAttr = false;
+  bool HasDeviceAttr = false;
+  bool HasGlobalAttr = false;
+  bool HasInvalidTargetAttr = false;
+  while (Attr) {
+    switch(Attr->getKind()){
+    case AttributeList::AT_CUDAGlobal:
+      HasGlobalAttr = true;
+      break;
+    case AttributeList::AT_CUDAHost:
+      HasHostAttr = true;
+      break;
+    case AttributeList::AT_CUDADevice:
+      HasDeviceAttr = true;
+      break;
+    case AttributeList::AT_CUDAInvalidTarget:
+      HasInvalidTargetAttr = true;
+      break;
+    default:
+      break;
+    }
+    Attr = Attr->getNext();
+  }
+  if (HasInvalidTargetAttr)
+    return CFT_InvalidTarget;
+
+  if (HasGlobalAttr)
+    return CFT_Global;
+
+  if (HasHostAttr && HasDeviceAttr)
+    return CFT_HostDevice;
+
+  if (HasDeviceAttr)
+    return CFT_Device;
+
+  return CFT_Host;
+}
+
 /// IdentifyCUDATarget - Determine the CUDA compilation target for this function
 Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D) {
   // Code that lives outside a function is run on the host.
@@ -801,3 +840,29 @@
     Method->addAttr(CUDAHostAttr::CreateImplicit(Context));
   }
 }
+
+void Sema::checkCUDATargetOverload(FunctionDecl *NewFD,
+                                   LookupResult &Previous) {
+  assert(getLangOpts().CUDA && "Should only be called during CUDA compilation");
+  CUDAFunctionTarget NewTarget = IdentifyCUDATarget(NewFD);
+  for (NamedDecl *OldND : Previous) {
+    FunctionDecl *OldFD = OldND->getAsFunction();
+    CUDAFunctionTarget OldTarget = IdentifyCUDATarget(OldFD);
+    // Don't allow HD and global functions to overload other functions with the
+    // same signature.  We allow overloading based on CUDA attributes so that
+    // functions can have different implementations on the host and device, but
+    // HD/global functions "exist" in some sense on both the host and device, so
+    // should have the same implementation on both sides.
+    if (NewTarget != OldTarget &&
+        ((NewTarget == CFT_HostDevice) || (OldTarget == CFT_HostDevice) ||
+         (NewTarget == CFT_Global) || (OldTarget == CFT_Global)) &&
+        !IsOverload(NewFD, OldFD, /* UseMemberUsingDeclRules = */ false,
+                    /* ConsiderCudaAttrs = */ false)) {
+      Diag(NewFD->getLocation(), diag::err_cuda_ovl_target)
+          << NewTarget << NewFD->getDeclName() << OldTarget << OldFD;
+      Diag(OldFD->getLocation(), diag::note_previous_declaration);
+      NewFD->setInvalidDecl();
+      break;
+    }
+  }
+}
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -6556,7 +6556,9 @@
     /// not be resolved to a suitable function.
     TDK_FailedOverloadResolution,
     /// \brief Deduction failed; that's all we know.
-    TDK_MiscellaneousDeductionFailure
+    TDK_MiscellaneousDeductionFailure,
+    /// \brief CUDA Target attributes do not match.
+    TDK_CUDATargetMismatch
   };
 
   TemplateDeductionResult
@@ -9389,6 +9391,7 @@
   /// Use this rather than examining the function's attributes yourself -- you
   /// will get it wrong.  Returns CFT_Host if D is null.
   CUDAFunctionTarget IdentifyCUDATarget(const FunctionDecl *D);
+  CUDAFunctionTarget IdentifyCUDATarget(const AttributeList *Attr);
 
   /// Gets the CUDA target for the current context.
   CUDAFunctionTarget CurrentCUDATarget() {
@@ -9487,6 +9490,10 @@
   bool isEmptyCudaConstructor(SourceLocation Loc, CXXConstructorDecl *CD);
   bool isEmptyCudaDestructor(SourceLocation Loc, CXXDestructorDecl *CD);
 
+  /// Check whether NewFD is a valid overload for CUDA. Emits
+  /// diagnostics and invalidates NewFD if not.
+  void checkCUDATargetOverload(FunctionDecl *NewFD, LookupResult &Previous);
+
   /// \name Code completion
   //@{
   /// \brief Describes the context in which code completion occurs.
Index: include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- include/clang/Basic/DiagnosticSemaKinds.td
+++ include/clang/Basic/DiagnosticSemaKinds.td
@@ -6758,6 +6758,11 @@
     "__shared__ local variables not allowed in "
     "%select{__device__|__global__|__host__|__host__ __device__}0 functions">;
 def err_cuda_nonglobal_constant : Error<"__constant__ variables must be global">;
+def err_cuda_ovl_target : Error<
+  "%select{__device__|__global__|__host__|__host__ __device__}0 function %1 "
+  "cannot overload %select{__device__|__global__|__host__|__host__ __device__}2 function %3">;
+def note_cuda_ovl_candidate_target_mismatch : Note<
+    "candidate template ignored: target attributes do not match">;
 
 def warn_non_pod_vararg_with_format_string : Warning<
   "cannot pass %select{non-POD|non-trivial}0 object of type %1 to variadic "
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to