llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

<details>
<summary>Changes</summary>

Currently, constraints are checked in Sema::FinishTemplateArgumentDeduction, 
where the current function in ASTContext is set to the instantiated template 
function. When resolving functions for the constraints, clang assumes the 
caller is the current function. This is incompatible with nvcc, and also 
causes problems for constexpr template functions in C++.

Clang caches the result of constraint checking for each concept/type pair, 
which assumes that the result does not depend on the instantiation context. 
With the caller-dependent resolution described above, that assumption does 
not hold.

This patch gives constraint checking its own host/device context, which 
defaults to host for compatibility with C++. This makes constraint checking 
independent of the caller and makes the caching valid.

In the future, we may introduce device-side constraints by other means, e.g. 
by adding a __device__ attribute to individual function calls in constraints.

Fixes: https://github.com/llvm/llvm-project/issues/67507

---
Full diff: https://github.com/llvm/llvm-project/pull/67721.diff


5 Files Affected:

- (modified) clang/docs/HIPSupport.rst (+31) 
- (modified) clang/include/clang/Sema/Sema.h (+7-2) 
- (modified) clang/lib/Sema/SemaCUDA.cpp (+21-13) 
- (modified) clang/lib/Sema/SemaConcept.cpp (+2) 
- (added) clang/test/SemaCUDA/concept.cu (+23) 


``````````diff
diff --git a/clang/docs/HIPSupport.rst b/clang/docs/HIPSupport.rst
index 8b4649733a9c777..ea7eed0fe7ce1eb 100644
--- a/clang/docs/HIPSupport.rst
+++ b/clang/docs/HIPSupport.rst
@@ -176,3 +176,34 @@ Predefined Macros
    * - ``HIP_API_PER_THREAD_DEFAULT_STREAM``
      - Alias to ``__HIP_API_PER_THREAD_DEFAULT_STREAM__``. Deprecated.
 
+C++20 Concepts with HIP and CUDA
+--------------------------------
+
+In Clang, when working with HIP or CUDA, it's important to note that all 
constraints in C++20 concepts are assumed to be for the host side only. This 
behavior is consistent across both programming models, and developers should be 
aware of this assumption when writing code that uses C++20 concepts. For 
example:
+
+.. code-block:: c++
+
+   template <class T>
+   concept MyConcept = requires(T& obj) {
+     my_function(obj);  // Assumed to be a host-side requirement
+   };
+
+   template <MyConcept T>
+   __global__ void kernel() {
+      // Kernel code
+   }
+
+   struct MyType {};
+
+   inline void my_function(MyType& obj) {}
+
+   int main() {
+      kernel<MyType><<<1,1>>>();
+      return 0;
+   }
+
+In the above example, ``my_function`` is a host-only function, yet 
``MyConcept`` is satisfied when instantiating the device kernel 
``kernel<MyType>``, because Clang checks the constraint as a host-side 
requirement. Developers should structure their code so that constraints are 
satisfiable in this host context.
+
+This assumption makes the result of checking a template constraint 
independent of the calling context, so the result can be cached consistently. 
It also simplifies the compilation model, because Clang does not need to 
differentiate between host-side and device-side requirements.
+
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 712db0a3dd895d5..9b1545b634177d4 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -13312,6 +13312,7 @@ class Sema final {
     CTCK_Unknown,       /// Unknown context
     CTCK_InitGlobalVar, /// Function called during global variable
                         /// initialization
+    CTCK_Constraint,    /// Function called for constraint checking
   };
 
   /// Define the current global CUDA host/device context where a function may 
be
@@ -13319,13 +13320,17 @@ class Sema final {
   struct CUDATargetContext {
     CUDAFunctionTarget Target = CFT_HostDevice;
     CUDATargetContextKind Kind = CTCK_Unknown;
-    Decl *D = nullptr;
+    const Decl *D = nullptr;
+    const Expr *E = nullptr;
+    /// Whether this context's target should be used instead of that of \p D.
+    bool shouldOverride(const Decl *D) const;
   } CurCUDATargetCtx;
 
   struct CUDATargetContextRAII {
     Sema &S;
     CUDATargetContext SavedCtx;
-    CUDATargetContextRAII(Sema &S_, CUDATargetContextKind K, Decl *D);
+    CUDATargetContextRAII(Sema &S_, CUDATargetContextKind K, const Decl *D,
+                          const Expr *E = nullptr);
     ~CUDATargetContextRAII() { S.CurCUDATargetCtx = SavedCtx; }
   };
 
diff --git a/clang/lib/Sema/SemaCUDA.cpp b/clang/lib/Sema/SemaCUDA.cpp
index 88f5484575db17a..e72c42e672167d9 100644
--- a/clang/lib/Sema/SemaCUDA.cpp
+++ b/clang/lib/Sema/SemaCUDA.cpp
@@ -114,27 +114,35 @@ static bool hasAttr(const Decl *D, bool 
IgnoreImplicitAttr) {
 
 Sema::CUDATargetContextRAII::CUDATargetContextRAII(Sema &S_,
                                                    CUDATargetContextKind K,
-                                                   Decl *D)
+                                                   const Decl *D,
+                                                   const Expr *E)
     : S(S_) {
   SavedCtx = S.CurCUDATargetCtx;
-  assert(K == CTCK_InitGlobalVar);
-  auto *VD = dyn_cast_or_null<VarDecl>(D);
-  if (VD && VD->hasGlobalStorage() && !VD->isStaticLocal()) {
-    auto Target = CFT_Host;
-    if ((hasAttr<CUDADeviceAttr>(VD, /*IgnoreImplicit=*/true) &&
-         !hasAttr<CUDAHostAttr>(VD, /*IgnoreImplicit=*/true)) ||
-        hasAttr<CUDASharedAttr>(VD, /*IgnoreImplicit=*/true) ||
-        hasAttr<CUDAConstantAttr>(VD, /*IgnoreImplicit=*/true))
-      Target = CFT_Device;
-    S.CurCUDATargetCtx = {Target, K, VD};
+  auto Target = CFT_Host;
+  if (K == CTCK_InitGlobalVar) {
+    auto *VD = dyn_cast_or_null<VarDecl>(D);
+    if (VD && VD->hasGlobalStorage() && !VD->isStaticLocal()) {
+      if ((hasAttr<CUDADeviceAttr>(VD, /*IgnoreImplicit=*/true) &&
+           !hasAttr<CUDAHostAttr>(VD, /*IgnoreImplicit=*/true)) ||
+          hasAttr<CUDASharedAttr>(VD, /*IgnoreImplicit=*/true) ||
+          hasAttr<CUDAConstantAttr>(VD, /*IgnoreImplicit=*/true))
+        Target = CFT_Device;
+      S.CurCUDATargetCtx = {Target, K, D, E};
+    }
+    return;
   }
+  assert(K == CTCK_Constraint);
+  S.CurCUDATargetCtx = {Target, K, D, E};
+}
+
+bool Sema::CUDATargetContext::shouldOverride(const Decl *D)const {
+  return Kind == CTCK_Constraint || D == nullptr;
 }
 
 /// IdentifyCUDATarget - Determine the CUDA compilation target for this 
function
 Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D,
                                                   bool IgnoreImplicitHDAttr) {
-  // Code that lives outside a function gets the target from CurCUDATargetCtx.
-  if (D == nullptr)
+  if (CurCUDATargetCtx.shouldOverride(D))
     return CurCUDATargetCtx.Target;
 
   if (D->hasAttr<CUDAInvalidTargetAttr>())
diff --git a/clang/lib/Sema/SemaConcept.cpp b/clang/lib/Sema/SemaConcept.cpp
index 036548b68247bfa..6475f4e3dcde49d 100644
--- a/clang/lib/Sema/SemaConcept.cpp
+++ b/clang/lib/Sema/SemaConcept.cpp
@@ -336,6 +336,8 @@ static ExprResult calculateConstraintSatisfaction(
     Sema &S, const NamedDecl *Template, SourceLocation TemplateNameLoc,
     const MultiLevelTemplateArgumentList &MLTAL, const Expr *ConstraintExpr,
     ConstraintSatisfaction &Satisfaction) {
+  Sema::CUDATargetContextRAII X(S, Sema::CTCK_Constraint,
+      /*Decl=*/nullptr, ConstraintExpr);
   return calculateConstraintSatisfaction(
       S, ConstraintExpr, Satisfaction, [&](const Expr *AtomicExpr) {
         EnterExpressionEvaluationContext ConstantEvaluated(
diff --git a/clang/test/SemaCUDA/concept.cu b/clang/test/SemaCUDA/concept.cu
new file mode 100644
index 000000000000000..1ed906b01a94efa
--- /dev/null
+++ b/clang/test/SemaCUDA/concept.cu
@@ -0,0 +1,23 @@
+// RUN: %clang_cc1 -triple amdgcn-amd-amdhsa -fcuda-is-device -x hip %s \
+// RUN:   -std=c++20 -fsyntax-only -verify
+// RUN: %clang_cc1 -triple x86_64 -x hip %s \
+// RUN:   -std=c++20 -fsyntax-only -verify
+
+// expected-no-diagnostics
+
+#include "Inputs/cuda.h"
+
+template <class T>
+concept C = requires(T x) {
+  func(x);
+};
+
+struct A {};
+void func(A x) {}
+
+template <C T> __attribute__((global)) void kernel(T x) { }
+
+int main() {
+  A a;
+  kernel<<<1,1>>>(a);
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/67721
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to