https://github.com/alexander-shaposhnikov updated 
https://github.com/llvm/llvm-project/pull/131546

>From 86eefd7db18252d74f7b5891e7490653b6378eb0 Mon Sep 17 00:00:00 2001
From: Alexander Shaposhnikov <ashaposhni...@google.com>
Date: Mon, 17 Mar 2025 00:39:24 +0000
Subject: [PATCH] [CudaSPIRV] Allow using integral non-type template parameters
 as reqd_work_group_size arguments

---
 .../altera/SingleWorkItemBarrierCheck.cpp     |  8 +-
 clang/include/clang/Basic/Attr.td             |  7 +-
 clang/lib/CodeGen/CodeGenFunction.cpp         | 18 +++--
 clang/lib/CodeGen/Targets/AMDGPU.cpp          | 14 +++-
 clang/lib/CodeGen/Targets/TCE.cpp             | 18 ++---
 clang/lib/Sema/SemaDeclAttr.cpp               | 79 ++++++++++++++++---
 .../lib/Sema/SemaTemplateInstantiateDecl.cpp  | 32 ++++++++
 clang/test/SemaCUDA/spirv-attrs-diag.cu       | 34 ++++++++
 clang/test/SemaCUDA/spirv-attrs.cu            | 14 ++++
 9 files changed, 190 insertions(+), 34 deletions(-)
 create mode 100644 clang/test/SemaCUDA/spirv-attrs-diag.cu

diff --git a/clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp 
b/clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp
index df21c425ea956..c5da66a1f28b6 100644
--- a/clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp
+++ b/clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp
@@ -54,8 +54,12 @@ void SingleWorkItemBarrierCheck::check(const 
MatchFinder::MatchResult &Result) {
     bool IsNDRange = false;
     if (MatchedDecl->hasAttr<ReqdWorkGroupSizeAttr>()) {
       const auto *Attribute = MatchedDecl->getAttr<ReqdWorkGroupSizeAttr>();
-      if (Attribute->getXDim() > 1 || Attribute->getYDim() > 1 ||
-          Attribute->getZDim() > 1)
+      auto Eval = [&](Expr *E) {
+        return E->EvaluateKnownConstInt(MatchedDecl->getASTContext())
+            .getExtValue();
+      };
+      if (Eval(Attribute->getXDim()) > 1 || Eval(Attribute->getYDim()) > 1 ||
+          Eval(Attribute->getZDim()) > 1)
         IsNDRange = true;
     }
     if (IsNDRange) // No warning if kernel is treated as an NDRange.
diff --git a/clang/include/clang/Basic/Attr.td 
b/clang/include/clang/Basic/Attr.td
index 4d34346460561..cceb4085d523d 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -3044,8 +3044,7 @@ def NoDeref : TypeAttr {
 def ReqdWorkGroupSize : InheritableAttr {
   // Does not have a [[]] spelling because it is an OpenCL-related attribute.
   let Spellings = [GNU<"reqd_work_group_size">];
-  let Args = [UnsignedArgument<"XDim">, UnsignedArgument<"YDim">,
-              UnsignedArgument<"ZDim">];
+  let Args = [ExprArgument<"XDim">, ExprArgument<"YDim">, 
ExprArgument<"ZDim">];
   let Subjects = SubjectList<[Function], ErrorDiag>;
   let Documentation = [Undocumented];
 }
@@ -3053,9 +3052,7 @@ def ReqdWorkGroupSize : InheritableAttr {
 def WorkGroupSizeHint :  InheritableAttr {
   // Does not have a [[]] spelling because it is an OpenCL-related attribute.
   let Spellings = [GNU<"work_group_size_hint">];
-  let Args = [UnsignedArgument<"XDim">,
-              UnsignedArgument<"YDim">,
-              UnsignedArgument<"ZDim">];
+  let Args = [ExprArgument<"XDim">, ExprArgument<"YDim">, 
ExprArgument<"ZDim">];
   let Subjects = SubjectList<[Function], ErrorDiag>;
   let Documentation = [Undocumented];
 }
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp 
b/clang/lib/CodeGen/CodeGenFunction.cpp
index 447192bc7f60c..de1e04a7982fa 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -649,18 +649,24 @@ void CodeGenFunction::EmitKernelMetadata(const 
FunctionDecl *FD,
   }
 
   if (const WorkGroupSizeHintAttr *A = FD->getAttr<WorkGroupSizeHintAttr>()) {
+    auto Eval = [&](Expr *E) {
+      return E->EvaluateKnownConstInt(FD->getASTContext()).getExtValue();
+    };
     llvm::Metadata *AttrMDArgs[] = {
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getXDim())),
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getYDim())),
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getZDim()))};
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getXDim()))),
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getYDim()))),
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getZDim())))};
     Fn->setMetadata("work_group_size_hint", llvm::MDNode::get(Context, 
AttrMDArgs));
   }
 
   if (const ReqdWorkGroupSizeAttr *A = FD->getAttr<ReqdWorkGroupSizeAttr>()) {
+    auto Eval = [&](Expr *E) {
+      return E->EvaluateKnownConstInt(FD->getASTContext()).getExtValue();
+    };
     llvm::Metadata *AttrMDArgs[] = {
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getXDim())),
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getYDim())),
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getZDim()))};
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getXDim()))),
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getYDim()))),
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getZDim())))};
     Fn->setMetadata("reqd_work_group_size", llvm::MDNode::get(Context, 
AttrMDArgs));
   }
 
diff --git a/clang/lib/CodeGen/Targets/AMDGPU.cpp 
b/clang/lib/CodeGen/Targets/AMDGPU.cpp
index 7d84310ba0bf5..8f999ab1130f4 100644
--- a/clang/lib/CodeGen/Targets/AMDGPU.cpp
+++ b/clang/lib/CodeGen/Targets/AMDGPU.cpp
@@ -753,12 +753,20 @@ void CodeGenModule::handleAMDGPUFlatWorkGroupSizeAttr(
     int32_t *MaxThreadsVal) {
   unsigned Min = 0;
   unsigned Max = 0;
+  auto Eval = [&](Expr *E) {
+    return E->EvaluateKnownConstInt(getContext()).getExtValue();
+  };
   if (FlatWGS) {
-    Min = FlatWGS->getMin()->EvaluateKnownConstInt(getContext()).getExtValue();
-    Max = FlatWGS->getMax()->EvaluateKnownConstInt(getContext()).getExtValue();
+    Min = Eval(
+        FlatWGS
+            ->getMin());
+    Max = Eval(
+        FlatWGS
+            ->getMax());
   }
   if (ReqdWGS && Min == 0 && Max == 0)
-    Min = Max = ReqdWGS->getXDim() * ReqdWGS->getYDim() * ReqdWGS->getZDim();
+    Min = Max = Eval(ReqdWGS->getXDim()) * Eval(ReqdWGS->getYDim()) *
+                Eval(ReqdWGS->getZDim());
 
   if (Min != 0) {
     assert(Min <= Max && "Min must be less than or equal Max");
diff --git a/clang/lib/CodeGen/Targets/TCE.cpp 
b/clang/lib/CodeGen/Targets/TCE.cpp
index d7178b4b8a949..bfce262ea965c 100644
--- a/clang/lib/CodeGen/Targets/TCE.cpp
+++ b/clang/lib/CodeGen/Targets/TCE.cpp
@@ -53,15 +53,15 @@ void TCETargetCodeGenInfo::setTargetAttributes(
         SmallVector<llvm::Metadata *, 5> Operands;
         Operands.push_back(llvm::ConstantAsMetadata::get(F));
 
-        Operands.push_back(
-            llvm::ConstantAsMetadata::get(llvm::Constant::getIntegerValue(
-                M.Int32Ty, llvm::APInt(32, Attr->getXDim()))));
-        Operands.push_back(
-            llvm::ConstantAsMetadata::get(llvm::Constant::getIntegerValue(
-                M.Int32Ty, llvm::APInt(32, Attr->getYDim()))));
-        Operands.push_back(
-            llvm::ConstantAsMetadata::get(llvm::Constant::getIntegerValue(
-                M.Int32Ty, llvm::APInt(32, Attr->getZDim()))));
+        auto Eval = [&](Expr *E) {
+          return E->EvaluateKnownConstInt(FD->getASTContext());
+        };
+        Operands.push_back(llvm::ConstantAsMetadata::get(
+            llvm::Constant::getIntegerValue(M.Int32Ty, 
Eval(Attr->getXDim()))));
+        Operands.push_back(llvm::ConstantAsMetadata::get(
+            llvm::Constant::getIntegerValue(M.Int32Ty, 
Eval(Attr->getYDim()))));
+        Operands.push_back(llvm::ConstantAsMetadata::get(
+            llvm::Constant::getIntegerValue(M.Int32Ty, 
Eval(Attr->getZDim()))));
 
         // Add a boolean constant operand for "required" (true) or "hint"
         // (false) for implementing the work_group_size_hint attr later.
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 769f5316ed934..0a8a3e1c49414 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -2914,21 +2914,70 @@ static void handleWeakImportAttr(Sema &S, Decl *D, 
const ParsedAttr &AL) {
   D->addAttr(::new (S.Context) WeakImportAttr(S.Context, AL));
 }
 
+// Validates one reqd_work_group_size / work_group_size_hint argument.
+// Value-dependent expressions are returned unchanged, since they cannot be
+// evaluated yet. Otherwise the argument must be a non-negative integer
+// constant expression that fits in 32 bits; it is converted to 'const int'.
+// Returns nullptr, normally after emitting a diagnostic, on failure.
+template <class Attribute>
+static Expr *makeAttributeArgExpr(Sema &S, Expr *E, const Attribute &Attr,
+                                  const unsigned Idx) {
+  if (S.DiagnoseUnexpandedParameterPack(E))
+    return nullptr;
+  if (E->isValueDependent())
+    return E;
+
+  const std::optional<llvm::APSInt> I = E->getIntegerConstantExpr(S.Context);
+  if (!I) {
+    S.Diag(E->getExprLoc(), diag::err_attribute_argument_n_type)
+        << &Attr << Idx << AANT_ArgumentIntegerConstant << E->getSourceRange();
+    return nullptr;
+  }
+  // Make sure we can fit it in 32 bits.
+  if (!I->isIntN(32)) {
+    S.Diag(E->getExprLoc(), diag::err_ice_too_large)
+        << toString(*I, 10, false) << 32 << /* Unsigned */ 1;
+    return nullptr;
+  }
+  // Negative sizes are rejected rather than attached to the declaration.
+  if (*I < 0) {
+    S.Diag(E->getExprLoc(), diag::err_attribute_requires_positive_integer)
+        << &Attr << /*non-negative*/ 1 << E->getSourceRange();
+    return nullptr;
+  }
+
+  // We may need to perform implicit conversion of the argument.
+  InitializedEntity Entity = InitializedEntity::InitializeParameter(
+      S.Context, S.Context.getConstType(S.Context.IntTy), /*consume*/ false);
+  ExprResult ValArg = S.PerformCopyInitialization(Entity, SourceLocation(), E);
+  assert(!ValArg.isInvalid() &&
+         "Unexpected PerformCopyInitialization() failure.");
+  return ValArg.getAs<Expr>();
+}
+
 // Handles reqd_work_group_size and work_group_size_hint.
 template <typename WorkGroupAttr>
 static void handleWorkGroupSize(Sema &S, Decl *D, const ParsedAttr &AL) {
-  uint32_t WGSize[3];
+  Expr *WGSize[3];
   for (unsigned i = 0; i < 3; ++i) {
-    const Expr *E = AL.getArgAsExpr(i);
-    if (!S.checkUInt32Argument(AL, E, WGSize[i], i,
-                               /*StrictlyUnsigned=*/true))
+    if (Expr *E = makeAttributeArgExpr(S, AL.getArgAsExpr(i), AL, i))
+      WGSize[i] = E;
+    else
       return;
   }
 
-  if (!llvm::all_of(WGSize, [](uint32_t Size) { return Size == 0; })) {
+  auto IsZero = [&](Expr *E) {
+    if (E->isValueDependent())
+      return false;
+    std::optional<llvm::APSInt> I = E->getIntegerConstantExpr(S.Context);
+    assert(I && "Non-integer constant expr");
+    return I->isZero();
+  };
+
+  if (!llvm::all_of(WGSize, IsZero)) {
     for (unsigned i = 0; i < 3; ++i) {
       const Expr *E = AL.getArgAsExpr(i);
-      if (WGSize[i] == 0) {
+      if (IsZero(WGSize[i])) {
         S.Diag(AL.getLoc(), diag::err_attribute_argument_is_zero)
             << AL << E->getSourceRange();
         return;
@@ -2936,10 +2985,22 @@ static void handleWorkGroupSize(Sema &S, Decl *D, const 
ParsedAttr &AL) {
     }
   }
 
+  // Compares two dimensions by value; dependent ones are assumed to match.
+  auto Equal = [&](Expr *LHS, Expr *RHS) {
+    if (LHS->isValueDependent() || RHS->isValueDependent())
+      return true;
+    std::optional<llvm::APSInt> L = LHS->getIntegerConstantExpr(S.Context);
+    std::optional<llvm::APSInt> R = RHS->getIntegerConstantExpr(S.Context);
+    assert(L && R && "Non-integer constant expr");
+    return llvm::APSInt::isSameValue(*L, *R);
+  };
+
   WorkGroupAttr *Existing = D->getAttr<WorkGroupAttr>();
-  if (Existing && !(Existing->getXDim() == WGSize[0] &&
-                    Existing->getYDim() == WGSize[1] &&
-                    Existing->getZDim() == WGSize[2]))
+  if (Existing &&
+      !llvm::equal(std::initializer_list<Expr *>{Existing->getXDim(),
+                                                 Existing->getYDim(),
+                                                 Existing->getZDim()},
+                   WGSize, Equal))
     S.Diag(AL.getLoc(), diag::warn_duplicate_attribute) << AL;
 
   D->addAttr(::new (S.Context)
diff --git a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp 
b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
index 170c0b0d39f86..d5330430a3b73 100644
--- a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
@@ -572,6 +572,32 @@ static void 
instantiateDependentAMDGPUFlatWorkGroupSizeAttr(
   S.AMDGPU().addAMDGPUFlatWorkGroupSizeAttr(New, Attr, MinExpr, MaxExpr);
 }
 
+// Instantiates a reqd_work_group_size attribute for New by substituting
+// TemplateArgs into its X, Y and Z dimension expressions. Nothing is added
+// to New if any of the substitutions fails.
+static void instantiateDependentReqdWorkGroupSizeAttr(
+    Sema &S, const MultiLevelTemplateArgumentList &TemplateArgs,
+    const ReqdWorkGroupSizeAttr &Attr, Decl *New) {
+  // Substitute the dimensions in a constant-evaluated context.
+  EnterExpressionEvaluationContext ConstantEvaluated(
+      S, Sema::ExpressionEvaluationContext::ConstantEvaluated);
+
+  Expr *const Dims[] = {Attr.getXDim(), Attr.getYDim(), Attr.getZDim()};
+  Expr *Substituted[3] = {nullptr, nullptr, nullptr};
+  for (unsigned Idx = 0; Idx != 3; ++Idx) {
+    ExprResult Result = S.SubstExpr(Dims[Idx], TemplateArgs);
+    if (Result.isInvalid())
+      return;
+    Substituted[Idx] = Result.getAs<Expr>();
+  }
+
+  // Build the instantiated attribute from the substituted dimensions.
+  ASTContext &Context = S.getASTContext();
+  New->addAttr(::new (Context)
+                   ReqdWorkGroupSizeAttr(Context, Attr, Substituted[0],
+                                         Substituted[1], Substituted[2]));
+}
+
 ExplicitSpecifier Sema::instantiateExplicitSpecifier(
     const MultiLevelTemplateArgumentList &TemplateArgs, ExplicitSpecifier ES) {
   if (!ES.getExpr())
@@ -812,6 +838,12 @@ void Sema::InstantiateAttrs(const 
MultiLevelTemplateArgumentList &TemplateArgs,
       continue;
     }
 
+    if (const auto *ReqdWorkGroupSize =
+            dyn_cast<ReqdWorkGroupSizeAttr>(TmplAttr)) {
+      instantiateDependentReqdWorkGroupSizeAttr(*this, TemplateArgs,
+                                                *ReqdWorkGroupSize, New);
+    }
+
     if (const auto *AMDGPUFlatWorkGroupSize =
             dyn_cast<AMDGPUFlatWorkGroupSizeAttr>(TmplAttr)) {
       instantiateDependentAMDGPUFlatWorkGroupSizeAttr(
diff --git a/clang/test/SemaCUDA/spirv-attrs-diag.cu 
b/clang/test/SemaCUDA/spirv-attrs-diag.cu
new file mode 100644
index 0000000000000..033272a21bc2b
--- /dev/null
+++ b/clang/test/SemaCUDA/spirv-attrs-diag.cu
@@ -0,0 +1,34 @@
+// RUN: %clang_cc1 -triple spirv64 -aux-triple x86_64-unknown-linux-gnu \
+// RUN:   -fcuda-is-device -verify -fsyntax-only %s
+
+#include "Inputs/cuda.h"
+
+__attribute__((reqd_work_group_size(0x100000000, 1, 1))) // expected-error {{integer constant expression evaluates to value 4294967296 that cannot be represented in a 32-bit unsigned integer type}}
+__global__ void TestTooBigArg1(void);
+
+__attribute__((work_group_size_hint(0x100000000, 1, 1))) // expected-error {{integer constant expression evaluates to value 4294967296 that cannot be represented in a 32-bit unsigned integer type}}
+__global__ void TestTooBigArg2(void);
+
+template <int... Args>
+__attribute__((reqd_work_group_size(Args))) // expected-error {{expression contains unexpanded parameter pack 'Args'}}
+__global__ void TestTemplateVariadicArgs1(void) {}
+
+template <int... Args>
+__attribute__((work_group_size_hint(Args))) // expected-error {{expression contains unexpanded parameter pack 'Args'}}
+__global__ void TestTemplateVariadicArgs2(void) {}
+
+template <class a> // expected-note {{declared here}}
+__attribute__((reqd_work_group_size(a, 1, 1))) // expected-error {{'a' does not refer to a value}}
+__global__ void TestTemplateArgClass1(void) {}
+
+template <class a> // expected-note {{declared here}}
+__attribute__((work_group_size_hint(a, 1, 1))) // expected-error {{'a' does not refer to a value}}
+__global__ void TestTemplateArgClass2(void) {}
+
+constexpr int A = 512;
+
+__attribute__((reqd_work_group_size(A, A, A)))
+__global__ void TestConstIntArg1(void) {}
+
+__attribute__((work_group_size_hint(A, A, A)))
+__global__ void TestConstIntArg2(void) {}
diff --git a/clang/test/SemaCUDA/spirv-attrs.cu 
b/clang/test/SemaCUDA/spirv-attrs.cu
index 6539421423ee1..355ed43550c16 100644
--- a/clang/test/SemaCUDA/spirv-attrs.cu
+++ b/clang/test/SemaCUDA/spirv-attrs.cu
@@ -8,9 +8,23 @@
 __attribute__((reqd_work_group_size(128, 1, 1)))
 __global__ void reqd_work_group_size_128_1_1() {}
 
+template <unsigned a, unsigned b, unsigned c>
+__attribute__((reqd_work_group_size(a, b, c)))
+__global__ void reqd_work_group_size_a_b_c() {}
+
+template <>
+__global__ void reqd_work_group_size_a_b_c<128,1,1>(void);
+
 __attribute__((work_group_size_hint(2, 2, 2)))
 __global__ void work_group_size_hint_2_2_2() {}
 
+template <unsigned a, unsigned b, unsigned c>
+__attribute__((work_group_size_hint(a, b, c)))
+__global__ void work_group_size_hint_a_b_c() {}
+
+template <>
+__global__ void work_group_size_hint_a_b_c<128,1,1>(void);
+
 __attribute__((vec_type_hint(int)))
 __global__ void vec_type_hint_int() {}
 

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to