https://github.com/ssahasra created 
https://github.com/llvm/llvm-project/pull/121738

None

>From e1611a9dbfe7a8239b93b84fa7682e68dc727f0f Mon Sep 17 00:00:00 2001
From: Sameer Sahasrabuddhe <sameer.sahasrabud...@amd.com>
Date: Mon, 6 Jan 2025 14:01:49 +0530
Subject: [PATCH] [clang][NFC] clean up the handling of convergence control
 tokens

---
 clang/lib/CodeGen/CGCall.cpp        |  4 +--
 clang/lib/CodeGen/CGStmt.cpp        | 46 +++++++++++++----------------
 clang/lib/CodeGen/CodeGenFunction.h | 23 +++++----------
 3 files changed, 29 insertions(+), 44 deletions(-)

diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index f139c30f3dfd44..89e2eace9120bf 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4871,7 +4871,7 @@ llvm::CallInst *CodeGenFunction::EmitRuntimeCall(llvm::FunctionCallee callee,
   call->setCallingConv(getRuntimeCC());
 
   if (CGM.shouldEmitConvergenceTokens() && call->isConvergent())
-    return addControlledConvergenceToken(call);
+    return cast<llvm::CallInst>(addConvergenceControlToken(call));
   return call;
 }
 
@@ -5787,7 +5787,7 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
     CI->setName("call");
 
   if (CGM.shouldEmitConvergenceTokens() && CI->isConvergent())
-    CI = addControlledConvergenceToken(CI);
+    CI = addConvergenceControlToken(CI);
 
   // Update largest vector width from the return type.
   LargestVectorWidth =
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 3974739d2abb47..7904e17dbebb81 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -1024,8 +1024,8 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   EmitBlock(LoopHeader.getBlock());
 
   if (CGM.shouldEmitConvergenceTokens())
-    ConvergenceTokenStack.push_back(emitConvergenceLoopToken(
-        LoopHeader.getBlock(), ConvergenceTokenStack.back()));
+    ConvergenceTokenStack.push_back(
+        emitConvergenceLoopToken(LoopHeader.getBlock()));
 
   // Create an exit block for when the condition fails, which will
   // also become the break target.
@@ -1152,8 +1152,7 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
     EmitBlockWithFallThrough(LoopBody, &S);
 
   if (CGM.shouldEmitConvergenceTokens())
-    ConvergenceTokenStack.push_back(
-        emitConvergenceLoopToken(LoopBody, ConvergenceTokenStack.back()));
+    ConvergenceTokenStack.push_back(emitConvergenceLoopToken(LoopBody));
 
   {
     RunCleanupsScope BodyScope(*this);
@@ -1231,8 +1230,7 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
   EmitBlock(CondBlock);
 
   if (CGM.shouldEmitConvergenceTokens())
-    ConvergenceTokenStack.push_back(
-        emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
+    ConvergenceTokenStack.push_back(emitConvergenceLoopToken(CondBlock));
 
   const SourceRange &R = S.getSourceRange();
   LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
@@ -1369,8 +1367,7 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
   EmitBlock(CondBlock);
 
   if (CGM.shouldEmitConvergenceTokens())
-    ConvergenceTokenStack.push_back(
-        emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
+    ConvergenceTokenStack.push_back(emitConvergenceLoopToken(CondBlock));
 
   const SourceRange &R = S.getSourceRange();
   LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
@@ -3245,35 +3242,32 @@ CodeGenFunction::GenerateCapturedStmtFunction(const CapturedStmt &S) {
   return F;
 }
 
-namespace {
 // Returns the first convergence entry/loop/anchor instruction found in |BB|.
 // std::nullptr otherwise.
-llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
+static llvm::ConvergenceControlInst *getConvergenceToken(llvm::BasicBlock *BB) {
   for (auto &I : *BB) {
-    auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
-    if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID()))
-      return II;
+    if (auto *CI = dyn_cast<llvm::ConvergenceControlInst>(&I))
+      return CI;
   }
   return nullptr;
 }
 
-} // namespace
-
 llvm::CallBase *
-CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
-                                            llvm::Value *ParentToken) {
+CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input) {
+  llvm::ConvergenceControlInst *ParentToken = ConvergenceTokenStack.back();
+  assert(ParentToken);
+
   llvm::Value *bundleArgs[] = {ParentToken};
   llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
-  auto Output = llvm::CallBase::addOperandBundle(
+  auto *Output = llvm::CallBase::addOperandBundle(
       Input, llvm::LLVMContext::OB_convergencectrl, OB, Input->getIterator());
   Input->replaceAllUsesWith(Output);
   Input->eraseFromParent();
   return Output;
 }
 
-llvm::IntrinsicInst *
-CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
-                                          llvm::Value *ParentToken) {
+llvm::ConvergenceControlInst *
+CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB) {
   CGBuilderTy::InsertPoint IP = Builder.saveIP();
   if (BB->empty())
     Builder.SetInsertPoint(BB);
@@ -3284,14 +3278,14 @@ CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
       llvm::Intrinsic::experimental_convergence_loop, {}, {});
   Builder.restoreIP(IP);
 
-  llvm::CallBase *I = addConvergenceControlToken(CB, ParentToken);
-  return cast<llvm::IntrinsicInst>(I);
+  CB = addConvergenceControlToken(CB);
+  return cast<llvm::ConvergenceControlInst>(CB);
 }
 
-llvm::IntrinsicInst *
+llvm::ConvergenceControlInst *
 CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
   llvm::BasicBlock *BB = &F->getEntryBlock();
-  llvm::IntrinsicInst *Token = getConvergenceToken(BB);
+  llvm::ConvergenceControlInst *Token = getConvergenceToken(BB);
   if (Token)
     return Token;
 
@@ -3306,5 +3300,5 @@ CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
   assert(isa<llvm::IntrinsicInst>(I));
   Builder.restoreIP(IP);
 
-  return cast<llvm::IntrinsicInst>(I);
+  return cast<llvm::ConvergenceControlInst>(I);
 }
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 1a5c42f8f974d0..46f267983edad1 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -315,7 +315,7 @@ class CodeGenFunction : public CodeGenTypeCache {
   SmallVector<const BinaryOperator *, 16> MCDCLogOpStack;
 
   /// Stack to track the controlled convergence tokens.
-  SmallVector<llvm::IntrinsicInst *, 4> ConvergenceTokenStack;
+  SmallVector<llvm::ConvergenceControlInst *, 4> ConvergenceTokenStack;
 
   /// Number of nested loop to be consumed by the last surrounding
   /// loop-associated directive.
@@ -5234,29 +5234,20 @@ class CodeGenFunction : public CodeGenTypeCache {
   llvm::Value *emitBoolVecConversion(llvm::Value *SrcVec,
                                      unsigned NumElementsDst,
                                      const llvm::Twine &Name = "");
-  // Adds a convergence_ctrl token to |Input| and emits the required parent
-  // convergence instructions.
-  template <typename CallType>
-  CallType *addControlledConvergenceToken(CallType *Input) {
-    return cast<CallType>(
-        addConvergenceControlToken(Input, ConvergenceTokenStack.back()));
-  }
 
 private:
   // Emits a convergence_loop instruction for the given |BB|, with |ParentToken|
   // as it's parent convergence instr.
-  llvm::IntrinsicInst *emitConvergenceLoopToken(llvm::BasicBlock *BB,
-                                                llvm::Value *ParentToken);
+  llvm::ConvergenceControlInst *emitConvergenceLoopToken(llvm::BasicBlock *BB);
+
   // Adds a convergence_ctrl token with |ParentToken| as parent convergence
   // instr to the call |Input|.
-  llvm::CallBase *addConvergenceControlToken(llvm::CallBase *Input,
-                                             llvm::Value *ParentToken);
+  llvm::CallBase *addConvergenceControlToken(llvm::CallBase *Input);
+
   // Find the convergence_entry instruction |F|, or emits ones if none exists.
   // Returns the convergence instruction.
-  llvm::IntrinsicInst *getOrEmitConvergenceEntryToken(llvm::Function *F);
-  // Find the convergence_loop instruction for the loop defined by |LI|, or
-  // emits one if none exists. Returns the convergence instruction.
-  llvm::IntrinsicInst *getOrEmitConvergenceLoopToken(const LoopInfo *LI);
+  llvm::ConvergenceControlInst *
+  getOrEmitConvergenceEntryToken(llvm::Function *F);
 
 private:
   llvm::MDNode *getRangeForLoadFromType(QualType Ty);

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to