yaxunl updated this revision to Diff 256903.
yaxunl retitled this revision from "[CUDA][HIP] Fix overload resolution issue 
for device host functions" to "[CUDA][HIP] Fix host/device based overload 
resolution".
yaxunl added a comment.

Revised based on John's review comments.


CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D77954/new/

https://reviews.llvm.org/D77954

Files:
  clang/lib/Sema/SemaOverload.cpp
  clang/test/SemaCUDA/function-overload.cu

Index: clang/test/SemaCUDA/function-overload.cu
===================================================================
--- clang/test/SemaCUDA/function-overload.cu
+++ clang/test/SemaCUDA/function-overload.cu
@@ -331,9 +331,6 @@
 // If we have a mix of HD and H-only or D-only candidates in the overload set,
 // normal C++ overload resolution rules apply first.
 template <typename T> TemplateReturnTy template_vs_hd_function(T arg)
-#ifdef __CUDA_ARCH__
-//expected-note@-2 {{declared here}}
-#endif
 {
   return TemplateReturnTy();
 }
@@ -342,11 +339,13 @@
 }
 
 __host__ __device__ void test_host_device_calls_hd_template() {
-  HostDeviceReturnTy ret1 = template_vs_hd_function(1.0f);
-  TemplateReturnTy ret2 = template_vs_hd_function(1);
 #ifdef __CUDA_ARCH__
-  // expected-error@-2 {{reference to __host__ function 'template_vs_hd_function<int>' in __host__ __device__ function}}
+  typedef HostDeviceReturnTy ExpectedReturnTy;
+#else
+  typedef TemplateReturnTy ExpectedReturnTy;
 #endif
+  HostDeviceReturnTy ret1 = template_vs_hd_function(1.0f);
+  ExpectedReturnTy ret2 = template_vs_hd_function(1);
 }
 
 __host__ void test_host_calls_hd_template() {
Index: clang/lib/Sema/SemaOverload.cpp
===================================================================
--- clang/lib/Sema/SemaOverload.cpp
+++ clang/lib/Sema/SemaOverload.cpp
@@ -9475,6 +9475,35 @@
   else if (!Cand1.Viable)
     return false;
 
+  // If Cand1 can be emitted and Cand2 cannot be emitted in the current context,
+  // Cand1 is better than Cand2. If Cand1 can not be emitted and Cand2 can be
+  // emitted, Cand1 is not better than Cand2. This rule should have precedence
+  // over other rules.
+  //
+  // If both Cand1 and Cand2 can be emitted, or neither can be emitted, then
+  // other rules should be used to determine which is better.
+  //
+  // If other rules cannot determine which is better, CUDA preference will be
+  // used again to determine which is better.
+  //
+  // TODO: Currently IdentifyCUDAPreference does not return correct values
+  // for functions called in global variable initializers due to missing
+  // correct context about device/host. Therefore we can only enforce this
+  // rule when there is a caller. We should enforce this rule for functions
+  // in global variable initializers once proper context is added.
+  if (S.getLangOpts().CUDA && Cand1.Function && Cand2.Function) {
+    if (FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext)) {
+      auto Cand1Emittable = S.IdentifyCUDAPreference(Caller, Cand1.Function) >
+                            Sema::CFP_WrongSide;
+      auto Cand2Emittable = S.IdentifyCUDAPreference(Caller, Cand2.Function) >
+                            Sema::CFP_WrongSide;
+      if (Cand1Emittable && !Cand2Emittable)
+        return true;
+      if (!Cand1Emittable && Cand2Emittable)
+        return false;
+    }
+  }
+
   // C++ [over.match.best]p1:
   //
   //   -- if F is a static member function, ICS1(F) is defined such
@@ -9709,12 +9738,6 @@
       return Cmp == Comparison::Better;
   }
 
-  if (S.getLangOpts().CUDA && Cand1.Function && Cand2.Function) {
-    FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext);
-    return S.IdentifyCUDAPreference(Caller, Cand1.Function) >
-           S.IdentifyCUDAPreference(Caller, Cand2.Function);
-  }
-
   bool HasPS1 = Cand1.Function != nullptr &&
                 functionHasPassObjectSizeParams(Cand1.Function);
   bool HasPS2 = Cand2.Function != nullptr &&
@@ -9722,7 +9745,19 @@
   if (HasPS1 != HasPS2 && HasPS1)
     return true;
 
-  return isBetterMultiversionCandidate(Cand1, Cand2);
+  if (isBetterMultiversionCandidate(Cand1, Cand2))
+    return true;
+
+  // If other rules cannot determine which is better, CUDA preference is used
+  // to determine which is better.
+  if (S.getLangOpts().CUDA && Cand1.Function && Cand2.Function) {
+    if (FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext)) {
+      return S.IdentifyCUDAPreference(Caller, Cand1.Function) >
+             S.IdentifyCUDAPreference(Caller, Cand2.Function);
+    }
+  }
+
+  return false;
 }
 
 /// Determine whether two declarations are "equivalent" for the purposes of
@@ -9808,33 +9843,6 @@
   std::transform(begin(), end(), std::back_inserter(Candidates),
                  [](OverloadCandidate &Cand) { return &Cand; });
 
-  // [CUDA] HD->H or HD->D calls are technically not allowed by CUDA but
-  // are accepted by both clang and NVCC. However, during a particular
-  // compilation mode only one call variant is viable. We need to
-  // exclude non-viable overload candidates from consideration based
-  // only on their host/device attributes. Specifically, if one
-  // candidate call is WrongSide and the other is SameSide, we ignore
-  // the WrongSide candidate.
-  if (S.getLangOpts().CUDA) {
-    const FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext);
-    bool ContainsSameSideCandidate =
-        llvm::any_of(Candidates, [&](OverloadCandidate *Cand) {
-          // Check viable function only.
-          return Cand->Viable && Cand->Function &&
-                 S.IdentifyCUDAPreference(Caller, Cand->Function) ==
-                     Sema::CFP_SameSide;
-        });
-    if (ContainsSameSideCandidate) {
-      auto IsWrongSideCandidate = [&](OverloadCandidate *Cand) {
-        // Check viable function only to avoid unnecessary data copying/moving.
-        return Cand->Viable && Cand->Function &&
-               S.IdentifyCUDAPreference(Caller, Cand->Function) ==
-                   Sema::CFP_WrongSide;
-      };
-      llvm::erase_if(Candidates, IsWrongSideCandidate);
-    }
-  }
-
   // Find the best viable function.
   Best = end();
   for (auto *Cand : Candidates) {
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to