yaxunl updated this revision to Diff 276417.
yaxunl added a comment.

Revised per Artem's comments.


CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D78655/new/

https://reviews.llvm.org/D78655

Files:
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Sema/Sema.h
  clang/lib/Sema/SemaCUDA.cpp
  clang/lib/Sema/SemaLambda.cpp
  clang/test/CodeGenCUDA/lambda.cu
  clang/test/SemaCUDA/Inputs/cuda.h
  clang/test/SemaCUDA/lambda.cu

Index: clang/test/SemaCUDA/lambda.cu
===================================================================
--- /dev/null
+++ clang/test/SemaCUDA/lambda.cu
@@ -0,0 +1,73 @@
+// RUN: %clang_cc1 -std=c++17 -fsyntax-only -verify=com %s
+// RUN: %clang_cc1 -std=c++17 -fsyntax-only -fcuda-is-device -verify=com,dev %s
+
+#include "Inputs/cuda.h"
+
+auto global_lambda = [] () { return 123; };
+
+template<class F>
+__global__ void kernel(F f) { f(); }
+// dev-note@-1 7{{called by 'kernel<(lambda}}
+
+__host__ __device__ void hd(int x);
+
+class A {
+  int b;
+public:
+  void test() {
+    [=](){ hd(b); }();
+
+    [&](){ hd(b); }();
+
+    kernel<<<1,1>>>([](){ hd(0); });
+
+    kernel<<<1,1>>>([=](){ hd(b); });
+    // dev-error@-1 {{capture host side class data member by this pointer in device or host device lambda function}}
+
+    kernel<<<1,1>>>([&](){ hd(b); });
+    // dev-error@-1 {{capture host side class data member by this pointer in device or host device lambda function}}
+
+    kernel<<<1,1>>>([&] __device__ (){ hd(b); });
+    // dev-error@-1 {{capture host side class data member by this pointer in device or host device lambda function}}
+
+    kernel<<<1,1>>>([&](){
+      auto f = [&]{ hd(b); };
+      // dev-error@-1 {{capture host side class data member by this pointer in device or host device lambda function}}
+      f();
+    });
+  }
+};
+
+int main(void) {
+  auto lambda_kernel = [&]__global__(){};
+  // com-error@-1 {{kernel function 'operator()' must be a free function or static member function}}
+
+  int b;
+  [&](){ hd(b); }();
+
+  [=, &b](){ hd(b); }();
+
+  kernel<<<1,1>>>(global_lambda);
+
+  kernel<<<1,1>>>([](){ hd(0); });
+
+  kernel<<<1,1>>>([=](){ hd(b); });
+
+  kernel<<<1,1>>>([b](){ hd(b); });
+
+  kernel<<<1,1>>>([&](){ hd(b); });
+  // dev-error@-1 {{capture host variable 'b' by reference in device or host device lambda function}}
+
+  kernel<<<1,1>>>([=, &b](){ hd(b); });
+  // dev-error@-1 {{capture host variable 'b' by reference in device or host device lambda function}}
+
+  kernel<<<1,1>>>([&, b](){ hd(b); });
+
+  kernel<<<1,1>>>([&](){
+      auto f = [&]{ hd(b); };
+      // dev-error@-1 {{capture host variable 'b' by reference in device or host device lambda function}}
+      f();
+  });
+
+  return 0;
+}
Index: clang/test/SemaCUDA/Inputs/cuda.h
===================================================================
--- clang/test/SemaCUDA/Inputs/cuda.h
+++ clang/test/SemaCUDA/Inputs/cuda.h
@@ -17,6 +17,19 @@
   __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {}
 };
 
+#ifdef __HIP__
+typedef struct hipStream *hipStream_t;
+typedef enum hipError {} hipError_t;
+int hipConfigureCall(dim3 gridSize, dim3 blockSize, size_t sharedSize = 0,
+                     hipStream_t stream = 0);
+extern "C" hipError_t __hipPushCallConfiguration(dim3 gridSize, dim3 blockSize,
+                                                 size_t sharedSize = 0,
+                                                 hipStream_t stream = 0);
+extern "C" hipError_t hipLaunchKernel(const void *func, dim3 gridDim,
+                                      dim3 blockDim, void **args,
+                                      size_t sharedMem,
+                                      hipStream_t stream);
+#else
 typedef struct cudaStream *cudaStream_t;
 typedef enum cudaError {} cudaError_t;
 
@@ -29,6 +42,7 @@
 extern "C" cudaError_t cudaLaunchKernel(const void *func, dim3 gridDim,
                                         dim3 blockDim, void **args,
                                         size_t sharedMem, cudaStream_t stream);
+#endif
 
 // Host- and device-side placement new overloads.
 void *operator new(__SIZE_TYPE__, void *p) { return p; }
Index: clang/test/CodeGenCUDA/lambda.cu
===================================================================
--- /dev/null
+++ clang/test/CodeGenCUDA/lambda.cu
@@ -0,0 +1,85 @@
+// RUN: %clang_cc1 -x hip -emit-llvm -std=c++11 %s -o - \
+// RUN:   -triple x86_64-linux-gnu \
+// RUN:   | FileCheck -check-prefix=HOST %s
+// RUN: %clang_cc1 -x hip -emit-llvm -std=c++11 %s -o - \
+// RUN:   -triple amdgcn-amd-amdhsa -fcuda-is-device \
+// RUN:   | FileCheck -check-prefix=DEV %s
+
+#include "Inputs/cuda.h"
+
+// Device side kernel name.
+// HOST: @[[KERN_CAPTURE:[0-9]+]] = {{.*}} c"_Z1gIZ12test_capturevEUlvE_EvT_\00"
+// HOST: @[[KERN_RESOLVE:[0-9]+]] = {{.*}} c"_Z1gIZ12test_resolvevEUlvE_EvT_\00"
+
+// Check functions emitted for test_capture in host compilation.
+// Check lambda is not emitted in host compilation.
+// HOST-LABEL: define void @_Z12test_capturev
+// HOST:  call void @_Z19test_capture_helperIZ12test_capturevEUlvE_EvT_
+// HOST-LABEL: define internal void @_Z19test_capture_helperIZ12test_capturevEUlvE_EvT_
+// HOST:  call void @_Z16__device_stub__gIZ12test_capturevEUlvE_EvT_
+// HOST-NOT: define{{.*}}@_ZZ12test_capturevENKUlvE_clEv
+
+// Check functions emitted for test_resolve in host compilation.
+// Check host version of template function 'overloaded' is emitted and called
+// by the lambda function.
+// HOST-LABEL: define void @_Z12test_resolvev
+// HOST:  call void @_Z19test_resolve_helperIZ12test_resolvevEUlvE_EvT_()
+// HOST-LABEL: define internal void @_Z19test_resolve_helperIZ12test_resolvevEUlvE_EvT_
+// HOST:  call void @_Z16__device_stub__gIZ12test_resolvevEUlvE_EvT_
+// HOST:  call void @_ZZ12test_resolvevENKUlvE_clEv
+// HOST-LABEL: define internal void @_ZZ12test_resolvevENKUlvE_clEv
+// HOST:  call i32 @_Z10overloadedIiET_v
+// HOST-LABEL: define linkonce_odr i32 @_Z10overloadedIiET_v
+// HOST:  ret i32 2
+
+// Check kernel is registered with correct device side kernel name.
+// HOST: @__hipRegisterFunction({{.*}}@[[KERN_CAPTURE]]
+// HOST: @__hipRegisterFunction({{.*}}@[[KERN_RESOLVE]]
+
+// DEV: @a = addrspace(1) externally_initialized global i32 0
+
+// Check functions emitted for test_capture in device compilation.
+// Check lambda is emitted in device compilation and accessing device variable.
+// DEV-LABEL: define amdgpu_kernel void @_Z1gIZ12test_capturevEUlvE_EvT_
+// DEV:  call void @_ZZ12test_capturevENKUlvE_clEv
+// DEV-LABEL: define internal void @_ZZ12test_capturevENKUlvE_clEv
+// DEV:  store i32 1, i32* addrspacecast (i32 addrspace(1)* @a to i32*)
+
+// Check functions emitted for test_resolve in device compilation.
+// Check device version of template function 'overloaded' is emitted and called
+// by the lambda function.
+// DEV-LABEL: define amdgpu_kernel void @_Z1gIZ12test_resolvevEUlvE_EvT_
+// DEV:  call void @_ZZ12test_resolvevENKUlvE_clEv
+// DEV-LABEL: define internal void @_ZZ12test_resolvevENKUlvE_clEv
+// DEV:  call i32 @_Z10overloadedIiET_v
+// DEV-LABEL: define linkonce_odr i32 @_Z10overloadedIiET_v
+// DEV:  ret i32 1
+
+__device__ int a;
+
+template<class T>
+__device__ T overloaded() { return 1; }
+
+template<class T>
+__host__ T overloaded() { return 2; }
+
+template<class F>
+__global__ void g(F f) { f(); }
+
+template<class F>
+void test_capture_helper(F f) { g<<<1,1>>>(f); }
+
+template<class F>
+void test_resolve_helper(F f) { g<<<1,1>>>(f); f(); }
+
+// Test capture of device variable in lambda function.
+void test_capture(void) {
+  test_capture_helper([](){ a = 1;});
+}
+
+// Test resolving host/device function in lambda function.
+// Callee should resolve to correct host/device function based on where
+// the lambda function is called, not where it is defined.
+void test_resolve(void) {
+  test_resolve_helper([](){ overloaded<int>();});
+}
Index: clang/lib/Sema/SemaLambda.cpp
===================================================================
--- clang/lib/Sema/SemaLambda.cpp
+++ clang/lib/Sema/SemaLambda.cpp
@@ -990,8 +990,7 @@
   // Attributes on the lambda apply to the method.
   ProcessDeclAttributes(CurScope, Method, ParamInfo);
 
-  // CUDA lambdas get implicit attributes based on the scope in which they're
-  // declared.
+  // CUDA lambdas get implicit host and device attributes.
   if (getLangOpts().CUDA)
     CUDASetLambdaAttrs(Method);
 
@@ -1780,6 +1779,9 @@
       BuildCaptureField(Class, From);
       Captures.push_back(Capture);
       CaptureInits.push_back(Init.get());
+
+      if (LangOpts.CUDA)
+        CUDACheckLambdaCapture(CallOperator, From);
     }
 
     Class->setCaptures(Captures);
Index: clang/lib/Sema/SemaCUDA.cpp
===================================================================
--- clang/lib/Sema/SemaCUDA.cpp
+++ clang/lib/Sema/SemaCUDA.cpp
@@ -17,6 +17,7 @@
 #include "clang/Basic/TargetInfo.h"
 #include "clang/Lex/Preprocessor.h"
 #include "clang/Sema/Lookup.h"
+#include "clang/Sema/ScopeInfo.h"
 #include "clang/Sema/Sema.h"
 #include "clang/Sema/SemaDiagnostic.h"
 #include "clang/Sema/SemaInternal.h"
@@ -746,20 +747,58 @@
          DiagKind != DeviceDiagBuilder::K_ImmediateWithCallStack;
 }
 
+// Check the wrong-sided reference capture of lambda for CUDA/HIP.
+// A lambda function may capture a stack variable by reference when it is
+// defined and uses the capture by reference when the lambda is called. When
+// the capture and use happen on different sides, the capture is invalid and
+// should be diagnosed.
+void Sema::CUDACheckLambdaCapture(CXXMethodDecl *Callee,
+                                  const sema::Capture &Capture) {
+  // In host compilation we only need to check lambda functions emitted on host
+  // side. In such lambda functions, a reference capture is invalid only
+  // if the lambda structure is populated by a device function or kernel then
+  // is passed to and called by a host function. However that is impossible,
+  // since a device function or kernel can only call a device function, also a
+  // kernel cannot pass a lambda back to a host function since we cannot
+  // define a kernel argument type which can hold the lambda before the lambda
+  // itself is defined.
+  if (!LangOpts.CUDAIsDevice)
+    return;
+
+  // File-scope lambda can only do init captures for global variables, which
+  // results in passing by value for these global variables.
+  FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext);
+  if (!Caller)
+    return;
+
+  // In device compilation, we only need to check lambda functions which are
+  // emitted on device side. For such lambdas, a reference capture is invalid
+  // only if the lambda structure is populated by a host function then passed
+  // to and called in a device function or kernel.
+  bool CalleeIsDevice = Callee->hasAttr<CUDADeviceAttr>();
+  bool CallerIsHost =
+      !Caller->hasAttr<CUDAGlobalAttr>() && !Caller->hasAttr<CUDADeviceAttr>();
+  bool ShouldCheck = CalleeIsDevice && CallerIsHost;
+  if (!ShouldCheck || !Capture.isReferenceCapture())
+    return;
+  auto DiagKind = DeviceDiagBuilder::K_Deferred;
+  if (Capture.isVariableCapture()) {
+    DeviceDiagBuilder(DiagKind, Capture.getLocation(),
+                      diag::err_capture_bad_target, Callee, *this)
+        << Capture.getVariable();
+  } else if (Capture.isThisCapture()) {
+    DeviceDiagBuilder(DiagKind, Capture.getLocation(),
+                      diag::err_capture_bad_target_this_ptr, Callee, *this);
+  }
+  return;
+}
+
 void Sema::CUDASetLambdaAttrs(CXXMethodDecl *Method) {
   assert(getLangOpts().CUDA && "Should only be called during CUDA compilation");
   if (Method->hasAttr<CUDAHostAttr>() || Method->hasAttr<CUDADeviceAttr>())
     return;
-  FunctionDecl *CurFn = dyn_cast<FunctionDecl>(CurContext);
-  if (!CurFn)
-    return;
-  CUDAFunctionTarget Target = IdentifyCUDATarget(CurFn);
-  if (Target == CFT_Global || Target == CFT_Device) {
-    Method->addAttr(CUDADeviceAttr::CreateImplicit(Context));
-  } else if (Target == CFT_HostDevice) {
-    Method->addAttr(CUDADeviceAttr::CreateImplicit(Context));
-    Method->addAttr(CUDAHostAttr::CreateImplicit(Context));
-  }
+  Method->addAttr(CUDADeviceAttr::CreateImplicit(Context));
+  Method->addAttr(CUDAHostAttr::CreateImplicit(Context));
 }
 
 void Sema::checkCUDATargetOverload(FunctionDecl *NewFD,
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -11781,12 +11781,13 @@
   /// - Otherwise, returns true without emitting any diagnostics.
   bool CheckCUDACall(SourceLocation Loc, FunctionDecl *Callee);
 
+  void CUDACheckLambdaCapture(CXXMethodDecl *D, const sema::Capture &Capture);
+
   /// Set __device__ or __host__ __device__ attributes on the given lambda
   /// operator() method.
   ///
-  /// CUDA lambdas declared inside __device__ or __global__ functions inherit
-  /// the __device__ attribute.  Similarly, lambdas inside __host__ __device__
-  /// functions become __host__ __device__ themselves.
+  /// CUDA lambdas are host device functions by default unless they have an
+  /// explicit host or device attribute.
   void CUDASetLambdaAttrs(CXXMethodDecl *Method);
 
   /// Finds a function in \p Matches with highest calling priority
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -7993,6 +7993,10 @@
 def err_ref_bad_target_global_initializer : Error<
   "reference to %select{__device__|__global__|__host__|__host__ __device__}0 "
   "function %1 in global initializer">;
+def err_capture_bad_target : Error<
+  "capture host variable %0 by reference in device or host device lambda function">;
+def err_capture_bad_target_this_ptr : Error<
+  "capture host side class data member by this pointer in device or host device lambda function">;
 def warn_kern_is_method : Extension<
   "kernel function %0 is a member function; this may not be accepted by nvcc">,
   InGroup<CudaCompat>;
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to