This revision was landed with ongoing or failed builds.
This revision was automatically updated to reflect the committed changes.
yaxunl marked an inline comment as done.
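Closed by commit rGd41445113bcc: [CUDA][HIP] Fix hostness check with -fopenmp (authored by yaxunl).
The CUDA/HIP host/device checks previously took the caller from `dyn_cast<FunctionDecl>(CurContext)`. Inside an OpenMP region such as `#pragma omp parallel`, that yields no function. The patch uses `getCurFunctionDecl(/*AllowLambda=*/true)` instead. This walks out of captured/block contexts to the enclosing function, stopping at a lambda's call operator.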
Herald added a project: clang.

Changed prior to commit:
  https://reviews.llvm.org/D121765?vs=416974&id=418017#toc

Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D121765/new/

https://reviews.llvm.org/D121765

Files:
  clang/include/clang/Sema/Sema.h
  clang/lib/Sema/Sema.cpp
  clang/lib/Sema/SemaCUDA.cpp
  clang/lib/Sema/SemaOverload.cpp
  clang/test/CodeGenCUDA/openmp-parallel.cu
  clang/test/SemaCUDA/openmp-parallel.cu

Index: clang/test/SemaCUDA/openmp-parallel.cu
===================================================================
--- /dev/null
+++ clang/test/SemaCUDA/openmp-parallel.cu
@@ -0,0 +1,19 @@
+// RUN: %clang_cc1 -fopenmp -fsyntax-only -verify %s
+
+#include "Inputs/cuda.h"
+
+__device__ void foo(int) {} // expected-note {{candidate function not viable: call to __device__ function from __host__ function}}
+// expected-note@-1 {{'foo' declared here}}
+
+int main() {
+  #pragma omp parallel
+  for (int i = 0; i < 100; i++)
+    foo(1); // expected-error {{no matching function for call to 'foo'}}
+  
+  auto Lambda = []() {
+    #pragma omp parallel
+    for (int i = 0; i < 100; i++)
+      foo(1); // expected-error {{reference to __device__ function 'foo' in __host__ __device__ function}}
+    };
+  Lambda(); // expected-note {{called by 'main'}}
+}
Index: clang/test/CodeGenCUDA/openmp-parallel.cu
===================================================================
--- /dev/null
+++ clang/test/CodeGenCUDA/openmp-parallel.cu
@@ -0,0 +1,28 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu \
+// RUN:   -fopenmp -emit-llvm -o -  -x hip %s | FileCheck %s
+
+#include "Inputs/cuda.h"
+
+void foo(double) {}
+__device__ void foo(int) {}
+
+// Check foo resolves to the host function.
+// CHECK-LABEL: define {{.*}}@_Z5test1v
+// CHECK: call void @_Z3food(double noundef 1.000000e+00)
+void test1() {
+  #pragma omp parallel
+  for (int i = 0; i < 100; i++)
+    foo(1);
+}
+
+// Check foo resolves to the host function.
+// CHECK-LABEL: define {{.*}}@_Z5test2v
+// CHECK: call void @_Z3food(double noundef 1.000000e+00)
+void test2() {
+  auto Lambda = []() {
+    #pragma omp parallel
+    for (int i = 0; i < 100; i++)
+      foo(1);
+  };
+  Lambda();
+}
Index: clang/lib/Sema/SemaOverload.cpp
===================================================================
--- clang/lib/Sema/SemaOverload.cpp
+++ clang/lib/Sema/SemaOverload.cpp
@@ -6473,7 +6473,7 @@
 
   // (CUDA B.1): Check for invalid calls between targets.
   if (getLangOpts().CUDA)
-    if (const FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext))
+    if (const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true))
       // Skip the check for callers that are implicit members, because in this
       // case we may not yet know what the member's target is; the target is
       // inferred for the member automatically, based on the bases and fields of
@@ -6983,7 +6983,7 @@
 
   // (CUDA B.1): Check for invalid calls between targets.
   if (getLangOpts().CUDA)
-    if (const FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext))
+    if (const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true))
       if (!IsAllowedCUDACall(Caller, Method)) {
         Candidate.Viable = false;
         Candidate.FailureKind = ovl_fail_bad_target;
@@ -9639,7 +9639,7 @@
   // overloading resolution diagnostics.
   if (S.getLangOpts().CUDA && Cand1.Function && Cand2.Function &&
       S.getLangOpts().GPUExcludeWrongSideOverloads) {
-    if (FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext)) {
+    if (FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true)) {
       bool IsCallerImplicitHD = Sema::isCUDAImplicitHostDeviceFunction(Caller);
       bool IsCand1ImplicitHD =
           Sema::isCUDAImplicitHostDeviceFunction(Cand1.Function);
@@ -9922,7 +9922,7 @@
   // If other rules cannot determine which is better, CUDA preference is used
   // to determine which is better.
   if (S.getLangOpts().CUDA && Cand1.Function && Cand2.Function) {
-    FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext);
+    FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true);
     return S.IdentifyCUDAPreference(Caller, Cand1.Function) >
            S.IdentifyCUDAPreference(Caller, Cand2.Function);
   }
@@ -10043,7 +10043,7 @@
   // -fgpu-exclude-wrong-side-overloads is on, all candidates are compared
   // uniformly in isBetterOverloadCandidate.
   if (S.getLangOpts().CUDA && !S.getLangOpts().GPUExcludeWrongSideOverloads) {
-    const FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext);
+    const FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true);
     bool ContainsSameSideCandidate =
         llvm::any_of(Candidates, [&](OverloadCandidate *Cand) {
           // Check viable function only.
@@ -11077,7 +11077,7 @@
 
 /// CUDA: diagnose an invalid call across targets.
 static void DiagnoseBadTarget(Sema &S, OverloadCandidate *Cand) {
-  FunctionDecl *Caller = cast<FunctionDecl>(S.CurContext);
+  FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true);
   FunctionDecl *Callee = Cand->Function;
 
   Sema::CUDAFunctionTarget CallerTarget = S.IdentifyCUDATarget(Caller),
@@ -12136,7 +12136,7 @@
 
     if (FunctionDecl *FunDecl = dyn_cast<FunctionDecl>(Fn)) {
       if (S.getLangOpts().CUDA)
-        if (FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext))
+        if (FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true))
           if (!Caller->isImplicit() && !S.IsAllowedCUDACall(Caller, FunDecl))
             return false;
       if (FunDecl->isMultiVersion()) {
@@ -12253,7 +12253,8 @@
   }
 
   void EliminateSuboptimalCudaMatches() {
-    S.EraseUnwantedCUDAMatches(dyn_cast<FunctionDecl>(S.CurContext), Matches);
+    S.EraseUnwantedCUDAMatches(S.getCurFunctionDecl(/*AllowLambda=*/true),
+                               Matches);
   }
 
 public:
Index: clang/lib/Sema/SemaCUDA.cpp
===================================================================
--- clang/lib/Sema/SemaCUDA.cpp
+++ clang/lib/Sema/SemaCUDA.cpp
@@ -728,8 +728,9 @@
 Sema::SemaDiagnosticBuilder Sema::CUDADiagIfDeviceCode(SourceLocation Loc,
                                                        unsigned DiagID) {
   assert(getLangOpts().CUDA && "Should only be called during CUDA compilation");
+  FunctionDecl *CurFunContext = getCurFunctionDecl(/*AllowLambda=*/true);
   SemaDiagnosticBuilder::Kind DiagKind = [&] {
-    if (!isa<FunctionDecl>(CurContext))
+    if (!CurFunContext)
       return SemaDiagnosticBuilder::K_Nop;
     switch (CurrentCUDATarget()) {
     case CFT_Global:
@@ -743,7 +744,7 @@
         return SemaDiagnosticBuilder::K_Nop;
       if (IsLastErrorImmediate && Diags.getDiagnosticIDs()->isBuiltinNote(DiagID))
         return SemaDiagnosticBuilder::K_Immediate;
-      return (getEmissionStatus(cast<FunctionDecl>(CurContext)) ==
+      return (getEmissionStatus(CurFunContext) ==
               FunctionEmissionStatus::Emitted)
                  ? SemaDiagnosticBuilder::K_ImmediateWithCallStack
                  : SemaDiagnosticBuilder::K_Deferred;
@@ -751,15 +752,15 @@
       return SemaDiagnosticBuilder::K_Nop;
     }
   }();
-  return SemaDiagnosticBuilder(DiagKind, Loc, DiagID,
-                               dyn_cast<FunctionDecl>(CurContext), *this);
+  return SemaDiagnosticBuilder(DiagKind, Loc, DiagID, CurFunContext, *this);
 }
 
 Sema::SemaDiagnosticBuilder Sema::CUDADiagIfHostCode(SourceLocation Loc,
                                                      unsigned DiagID) {
   assert(getLangOpts().CUDA && "Should only be called during CUDA compilation");
+  FunctionDecl *CurFunContext = getCurFunctionDecl(/*AllowLambda=*/true);
   SemaDiagnosticBuilder::Kind DiagKind = [&] {
-    if (!isa<FunctionDecl>(CurContext))
+    if (!CurFunContext)
       return SemaDiagnosticBuilder::K_Nop;
     switch (CurrentCUDATarget()) {
     case CFT_Host:
@@ -772,7 +773,7 @@
         return SemaDiagnosticBuilder::K_Nop;
       if (IsLastErrorImmediate && Diags.getDiagnosticIDs()->isBuiltinNote(DiagID))
         return SemaDiagnosticBuilder::K_Immediate;
-      return (getEmissionStatus(cast<FunctionDecl>(CurContext)) ==
+      return (getEmissionStatus(CurFunContext) ==
               FunctionEmissionStatus::Emitted)
                  ? SemaDiagnosticBuilder::K_ImmediateWithCallStack
                  : SemaDiagnosticBuilder::K_Deferred;
@@ -780,8 +781,7 @@
       return SemaDiagnosticBuilder::K_Nop;
     }
   }();
-  return SemaDiagnosticBuilder(DiagKind, Loc, DiagID,
-                               dyn_cast<FunctionDecl>(CurContext), *this);
+  return SemaDiagnosticBuilder(DiagKind, Loc, DiagID, CurFunContext, *this);
 }
 
 bool Sema::CheckCUDACall(SourceLocation Loc, FunctionDecl *Callee) {
@@ -794,7 +794,7 @@
 
   // FIXME: Is bailing out early correct here?  Should we instead assume that
   // the caller is a global initializer?
-  FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext);
+  FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true);
   if (!Caller)
     return true;
 
@@ -860,7 +860,7 @@
 
   // File-scope lambda can only do init captures for global variables, which
   // results in passing by value for these global variables.
-  FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext);
+  FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true);
   if (!Caller)
     return;
 
Index: clang/lib/Sema/Sema.cpp
===================================================================
--- clang/lib/Sema/Sema.cpp
+++ clang/lib/Sema/Sema.cpp
@@ -1421,19 +1421,18 @@
 // Helper functions.
 //===----------------------------------------------------------------------===//
 
-DeclContext *Sema::getFunctionLevelDeclContext() {
+DeclContext *Sema::getFunctionLevelDeclContext(bool AllowLambda) {
   DeclContext *DC = CurContext;
 
   while (true) {
     if (isa<BlockDecl>(DC) || isa<EnumDecl>(DC) || isa<CapturedDecl>(DC) ||
         isa<RequiresExprBodyDecl>(DC)) {
       DC = DC->getParent();
-    } else if (isa<CXXMethodDecl>(DC) &&
+    } else if (!AllowLambda && isa<CXXMethodDecl>(DC) &&
                cast<CXXMethodDecl>(DC)->getOverloadedOperator() == OO_Call &&
                cast<CXXRecordDecl>(DC->getParent())->isLambda()) {
       DC = DC->getParent()->getParent();
-    }
-    else break;
+    } else break;
   }
 
   return DC;
@@ -1442,8 +1441,8 @@
 /// getCurFunctionDecl - If inside of a function body, this returns a pointer
 /// to the function decl for the function being parsed.  If we're currently
 /// in a 'block', this returns the containing context.
-FunctionDecl *Sema::getCurFunctionDecl() {
-  DeclContext *DC = getFunctionLevelDeclContext();
+FunctionDecl *Sema::getCurFunctionDecl(bool AllowLambda) {
+  DeclContext *DC = getFunctionLevelDeclContext(AllowLambda);
   return dyn_cast<FunctionDecl>(DC);
 }
 
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -3318,12 +3318,14 @@
   void ActOnReenterFunctionContext(Scope* S, Decl* D);
   void ActOnExitFunctionContext();
 
-  DeclContext *getFunctionLevelDeclContext();
-
-  /// getCurFunctionDecl - If inside of a function body, this returns a pointer
-  /// to the function decl for the function being parsed.  If we're currently
-  /// in a 'block', this returns the containing context.
-  FunctionDecl *getCurFunctionDecl();
+  /// If \p AllowLambda is true, treat lambda as function.
+  DeclContext *getFunctionLevelDeclContext(bool AllowLambda = false);
+
+  /// Returns a pointer to the innermost enclosing function, or nullptr if the
+  /// current context is not inside a function. If \p AllowLambda is true,
+  /// this can return the call operator of an enclosing lambda, otherwise
+  /// lambdas are skipped when looking for an enclosing function.
+  FunctionDecl *getCurFunctionDecl(bool AllowLambda = false);
 
   /// getCurMethodDecl - If inside of a method body, this returns a pointer to
   /// the method decl for the method being parsed.  If we're currently
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to