https://github.com/skatrak updated 
https://github.com/llvm/llvm-project/pull/150926

>From 6a81001ec131371c981e789f3dd402fb277e3c62 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safon...@amd.com>
Date: Fri, 4 Jul 2025 16:32:03 +0100
Subject: [PATCH] [OpenMP][OMPIRBuilder] Support parallel in Generic kernels

This patch introduces codegen logic to produce the wrapper function argument
of the `__kmpc_parallel_51` DeviceRTL function, which is needed in Generic mode
to handle arguments passed through device shared memory.
---
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 100 ++++++++++++++++--
 .../LLVMIR/omptarget-parallel-llvm.mlir       |  25 ++++-
 2 files changed, 116 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index a913958c0de9a..0005a72e86324 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1323,6 +1323,86 @@ Error OpenMPIRBuilder::emitCancelationCheckImpl(
   return Error::success();
 }
 
+// Create the wrapper function passed to `__kmpc_parallel_51` in Generic mode.
+// It gathers the outlined function's argument structure from the shared
+// variables buffer and passes it, along with tid arguments, to that function.
+//
+// The outlined function is expected to receive 2 pointer-to-i32 arguments (tid
+// and bound tid) followed by an optional pointer to the argument structure.
+static Function *createTargetParallelWrapper(OpenMPIRBuilder *OMPIRBuilder,
+                                             Function &OutlinedFn) {
+  size_t NumArgs = OutlinedFn.arg_size();
+  assert((NumArgs == 2 || NumArgs == 3) &&
+         "expected a 2-3 argument parallel outlined function");
+  bool UseArgStruct = NumArgs == 3;
+
+  IRBuilder<> &Builder = OMPIRBuilder->Builder;
+  // Restores the caller's insertion point and debug location on scope exit.
+  IRBuilder<>::InsertPointGuard IPG(Builder);
+  auto *FnTy = FunctionType::get(Builder.getVoidTy(),
+                                 {Builder.getInt16Ty(), Builder.getInt32Ty()},
+                                 /*isVarArg=*/false);
+  auto *WrapperFn =
+      Function::Create(FnTy, GlobalValue::InternalLinkage,
+                       OutlinedFn.getName() + ".wrapper", OMPIRBuilder->M);
+
+  WrapperFn->addParamAttr(0, Attribute::NoUndef);
+  WrapperFn->addParamAttr(0, Attribute::ZExt);
+  WrapperFn->addParamAttr(1, Attribute::NoUndef);
+
+  BasicBlock *EntryBB =
+      BasicBlock::Create(OMPIRBuilder->M.getContext(), "entry", WrapperFn);
+  Builder.SetInsertPoint(EntryBB);
+  // SetInsertPoint(BasicBlock *) keeps any debug location set by the caller,
+  // which may be scoped to another function; keep it off the wrapper's code.
+  Builder.SetCurrentDebugLocation(DebugLoc());
+
+  // Create an alloca and cast it to address space 0, since allocas may be
+  // placed in a target-specific address space (e.g. addrspace(5)).
+  auto CreateAllocaAndCast = [&](Type *Ty, const Twine &Name) -> Value * {
+    Value *Alloca = Builder.CreateAlloca(Ty, /*ArraySize=*/nullptr, Name);
+    return Builder.CreatePointerBitCastOrAddrSpaceCast(
+        Alloca, Builder.getPtrTy(/*AddrSpace=*/0), Name + ".ascast");
+  };
+
+  // Allocation.
+  Value *AddrAlloca = CreateAllocaAndCast(Builder.getInt32Ty(), "addr");
+  Value *ZeroAlloca = CreateAllocaAndCast(Builder.getInt32Ty(), "zero");
+  Value *ArgsAlloca = nullptr;
+  if (UseArgStruct)
+    ArgsAlloca = CreateAllocaAndCast(Builder.getPtrTy(), "global_args");
+
+  // Initialization. The addresses of the wrapper's 2nd argument and of a zero
+  // value become the outlined function's tid and bound tid arguments.
+  Builder.CreateStore(WrapperFn->getArg(1), AddrAlloca);
+  Builder.CreateStore(Builder.getInt32(0), ZeroAlloca);
+  if (UseArgStruct) {
+    // Expected to store into ArgsAlloca the address of an array holding the
+    // pointers passed through the `args` operand of `__kmpc_parallel_51`.
+    Builder.CreateCall(
+        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(
+            llvm::omp::RuntimeFunction::OMPRTL___kmpc_get_shared_variables),
+        {ArgsAlloca});
+  }
+
+  SmallVector<Value *, 3> Args{AddrAlloca, ZeroAlloca};
+
+  // Load structArg from global_args, where it is stored as element 0.
+  if (UseArgStruct) {
+    Value *StructArg = Builder.CreateLoad(Builder.getPtrTy(), ArgsAlloca);
+    StructArg = Builder.CreateInBoundsGEP(Builder.getPtrTy(), StructArg,
+                                          {Builder.getInt64(0)});
+    StructArg = Builder.CreateLoad(Builder.getPtrTy(), StructArg, "structArg");
+    Args.push_back(StructArg);
+  }
+
+  // Call the outlined function holding the parallel body.
+  Builder.CreateCall(&OutlinedFn, Args);
+  Builder.CreateRetVoid();
+
+  return WrapperFn;
+}
+
 // Callback used to create OpenMP runtime calls to support
 // omp parallel clause for the device.
 // We need to use this callback to replace call to the OutlinedFn in OuterFn
@@ -1332,6 +1412,10 @@ static void targetParallelCallback(
     BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
     Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
     Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
+  assert(OutlinedFn.arg_size() >= 2 &&
+         "Expected at least tid and bounded tid as arguments");
+  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
+
   // Add some known attributes.
   IRBuilder<> &Builder = OMPIRBuilder->Builder;
   OutlinedFn.addParamAttr(0, Attribute::NoAlias);
@@ -1340,17 +1424,12 @@ static void targetParallelCallback(
   OutlinedFn.addParamAttr(1, Attribute::NoUndef);
   OutlinedFn.addFnAttr(Attribute::NoUnwind);
 
-  assert(OutlinedFn.arg_size() >= 2 &&
-         "Expected at least tid and bounded tid as arguments");
-  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
-
   CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
   assert(CI && "Expected call instruction to outlined function");
   CI->getParent()->setName("omp_parallel");
 
   Builder.SetInsertPoint(CI);
   Type *PtrTy = OMPIRBuilder->VoidPtr;
-  Value *NullPtrValue = Constant::getNullValue(PtrTy);
 
   // Add alloca for kernel args
   OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
@@ -1376,6 +1455,15 @@ static void targetParallelCallback(
       IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
                   : Builder.getInt32(1);
 
+  // If this is not a Generic kernel, we can skip generating the wrapper.
+  std::optional<omp::OMPTgtExecModeFlags> ExecMode =
+      getTargetKernelExecMode(*OuterFn);
+  Value *WrapperFn;
+  if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC)
+    WrapperFn = createTargetParallelWrapper(OMPIRBuilder, OutlinedFn);
+  else
+    WrapperFn = Constant::getNullValue(PtrTy);
+
   // Build kmpc_parallel_51 call
   Value *Parallel51CallArgs[] = {
       /* identifier*/ Ident,
@@ -1384,7 +1472,7 @@ static void targetParallelCallback(
       /* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
       /* Proc bind */ Builder.getInt32(-1),
       /* outlined function */ &OutlinedFn,
-      /* wrapper function */ NullPtrValue,
+      /* wrapper function */ WrapperFn,
       /* arguments of the outlined funciton*/ Args,
       /* number of arguments */ Builder.getInt64(NumCapturedVars)};
 
diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
index 504e39c96f008..ca998b4672ba0 100644
--- a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
@@ -69,7 +69,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK:         store ptr %[[TMP6]], ptr %[[GEP_]], align 8
 // CHECK:         %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0
 // CHECK:         store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8
-// CHECK:         call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1)
+// CHECK:         call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr @[[FUNC1_WRAPPER:.*]], ptr %[[TMP2]], i64 1)
 // CHECK:         call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8)
 // CHECK:         call void @__kmpc_target_deinit()
 
@@ -84,7 +84,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK:         call void @__kmpc_parallel_51(ptr addrspacecast (
 // CHECK-SAME:  ptr addrspace(1) @[[NUM_THREADS_GLOB:[0-9]+]] to ptr),
 // CHECK-SAME:  i32 [[NUM_THREADS_TMP0:%.*]], i32 1, i32 156,
-// CHECK-SAME:  i32 -1,  ptr [[FUNC_NUM_THREADS1:@.*]], ptr null, ptr [[NUM_THREADS_TMP1:%.*]], i64 1)
+// CHECK-SAME:  i32 -1, ptr @[[FUNC_NUM_THREADS1:.*]], ptr @[[FUNC2_WRAPPER:.*]], ptr [[NUM_THREADS_TMP1:%.*]], i64 1)
 
 // One of the arguments of  kmpc_parallel_51 function is responsible for handling if clause
 // of omp parallel construct for target region. If this  argument is nonzero,
@@ -105,4 +105,23 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK:         call void @__kmpc_parallel_51(ptr addrspacecast (
 // CHECK-SAME:  ptr addrspace(1) {{.*}} to ptr),
 // CHECK-SAME:  i32 {{.*}}, i32 %[[IFCOND_TMP4]], i32 -1,
-// CHECK-SAME:  i32 -1,  ptr {{.*}}, ptr null, ptr {{.*}}, i64 1)
+// CHECK-SAME:  i32 -1,  ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i64 1)
+
+// CHECK: define internal void @[[FUNC1_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %[[ADDR:.*]])
+// CHECK: %[[ADDR_ALLOCA:.*]] = alloca i32, align 4, addrspace(5)
+// CHECK: %[[ADDR_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ADDR_ALLOCA]] to ptr
+// CHECK: %[[ZERO_ALLOCA:.*]] = alloca i32, align 4, addrspace(5)
+// CHECK: %[[ZERO_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ZERO_ALLOCA]] to ptr
+// CHECK: %[[ARGS_ALLOCA:.*]] = alloca ptr, align 8, addrspace(5)
+// CHECK: %[[ARGS_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ARGS_ALLOCA]] to ptr
+// CHECK: store i32 %[[ADDR]], ptr %[[ADDR_ASCAST]]
+// CHECK: store i32 0, ptr %[[ZERO_ASCAST]]
+// CHECK: call void @__kmpc_get_shared_variables(ptr %[[ARGS_ASCAST]])
+// CHECK: %[[LOAD_ARGS:.*]] = load ptr, ptr %[[ARGS_ASCAST]], align 8
+// CHECK: %[[FIRST_ARG:.*]] = getelementptr inbounds ptr, ptr %[[LOAD_ARGS]], i64 0
+// CHECK: %[[STRUCTARG:.*]] = load ptr, ptr %[[FIRST_ARG]], align 8
+// CHECK: call void @[[FUNC1]](ptr %[[ADDR_ASCAST]], ptr %[[ZERO_ASCAST]], ptr %[[STRUCTARG]])
+
+// CHECK: define internal void @[[FUNC2_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %{{.*}})
+// CHECK-NOT: define
+// CHECK: call void @[[FUNC_NUM_THREADS1]]({{.*}})

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to