https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/150926
>From 6a81001ec131371c981e789f3dd402fb277e3c62 Mon Sep 17 00:00:00 2001 From: Sergio Afonso <safon...@amd.com> Date: Fri, 4 Jul 2025 16:32:03 +0100 Subject: [PATCH] [OpenMP][OMPIRBuilder] Support parallel in Generic kernels This patch introduces codegen logic to produce a wrapper function argument for the `__kmpc_parallel_51` DeviceRTL function needed to handle arguments passed using device shared memory in Generic mode. --- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 100 ++++++++++++++++-- .../LLVMIR/omptarget-parallel-llvm.mlir | 25 ++++- 2 files changed, 116 insertions(+), 9 deletions(-) diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index a913958c0de9a..0005a72e86324 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -1323,6 +1323,86 @@ Error OpenMPIRBuilder::emitCancelationCheckImpl( return Error::success(); } +// Create wrapper function used to gather the outlined function's argument +// structure from a shared buffer and to forward them to it when running in +// Generic mode. +// +// The outlined function is expected to receive 2 integer arguments followed by +// an optional pointer argument to an argument structure holding the rest. +static Function *createTargetParallelWrapper(OpenMPIRBuilder *OMPIRBuilder, + Function &OutlinedFn) { + size_t NumArgs = OutlinedFn.arg_size(); + assert((NumArgs == 2 || NumArgs == 3) && + "expected a 2-3 argument parallel outlined function"); + bool UseArgStruct = NumArgs == 3; + + IRBuilder<> &Builder = OMPIRBuilder->Builder; + IRBuilder<>::InsertPointGuard IPG(Builder); + auto *FnTy = FunctionType::get(Builder.getVoidTy(), + {Builder.getInt16Ty(), Builder.getInt32Ty()}, + /*isVarArg=*/false); + auto *WrapperFn = + Function::Create(FnTy, GlobalValue::InternalLinkage, + OutlinedFn.getName() + ".wrapper", OMPIRBuilder->M); + + WrapperFn->addParamAttr(0, Attribute::NoUndef); + WrapperFn->addParamAttr(0, Attribute::ZExt); + WrapperFn->addParamAttr(1, Attribute::NoUndef); + + BasicBlock *EntryBB = + BasicBlock::Create(OMPIRBuilder->M.getContext(), "entry", WrapperFn); + Builder.SetInsertPoint(EntryBB); + + // Allocation. + Value *AddrAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), + /*ArraySize=*/nullptr, "addr"); + AddrAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast( + AddrAlloca, Builder.getPtrTy(/*AddrSpace=*/0), + AddrAlloca->getName() + ".ascast"); + + Value *ZeroAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), + /*ArraySize=*/nullptr, "zero"); + ZeroAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast( + ZeroAlloca, Builder.getPtrTy(/*AddrSpace=*/0), + ZeroAlloca->getName() + ".ascast"); + + Value *ArgsAlloca = nullptr; + if (UseArgStruct) { + ArgsAlloca = Builder.CreateAlloca(Builder.getPtrTy(), + /*ArraySize=*/nullptr, "global_args"); + ArgsAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast( + ArgsAlloca, Builder.getPtrTy(/*AddrSpace=*/0), + ArgsAlloca->getName() + ".ascast"); + } + + // Initialization. + Builder.CreateStore(WrapperFn->getArg(1), AddrAlloca); + Builder.CreateStore(Builder.getInt32(0), ZeroAlloca); + if (UseArgStruct) { + Builder.CreateCall( + OMPIRBuilder->getOrCreateRuntimeFunctionPtr( + llvm::omp::RuntimeFunction::OMPRTL___kmpc_get_shared_variables), + {ArgsAlloca}); + } + + SmallVector<Value *, 3> Args{AddrAlloca, ZeroAlloca}; + + // Load structArg from global_args. + if (UseArgStruct) { + Value *StructArg = Builder.CreateLoad(Builder.getPtrTy(), ArgsAlloca); + StructArg = Builder.CreateInBoundsGEP(Builder.getPtrTy(), StructArg, + {Builder.getInt64(0)}); + StructArg = Builder.CreateLoad(Builder.getPtrTy(), StructArg, "structArg"); + Args.push_back(StructArg); + } + + // Call the outlined function holding the parallel body. + Builder.CreateCall(&OutlinedFn, Args); + Builder.CreateRetVoid(); + + return WrapperFn; +} + // Callback used to create OpenMP runtime calls to support // omp parallel clause for the device. // We need to use this callback to replace call to the OutlinedFn in OuterFn @@ -1332,6 +1412,10 @@ static void targetParallelCallback( BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition, Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr, Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) { + assert(OutlinedFn.arg_size() >= 2 && + "Expected at least tid and bounded tid as arguments"); + unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2; + // Add some known attributes. IRBuilder<> &Builder = OMPIRBuilder->Builder; OutlinedFn.addParamAttr(0, Attribute::NoAlias); @@ -1340,17 +1424,12 @@ static void targetParallelCallback( OutlinedFn.addParamAttr(1, Attribute::NoUndef); OutlinedFn.addFnAttr(Attribute::NoUnwind); - assert(OutlinedFn.arg_size() >= 2 && - "Expected at least tid and bounded tid as arguments"); - unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2; - CallInst *CI = cast<CallInst>(OutlinedFn.user_back()); assert(CI && "Expected call instruction to outlined function"); CI->getParent()->setName("omp_parallel"); Builder.SetInsertPoint(CI); Type *PtrTy = OMPIRBuilder->VoidPtr; - Value *NullPtrValue = Constant::getNullValue(PtrTy); // Add alloca for kernel args OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP(); @@ -1376,6 +1455,15 @@ static void targetParallelCallback( IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32) : Builder.getInt32(1); + // If this is not a Generic kernel, we can skip generating the wrapper. + std::optional<omp::OMPTgtExecModeFlags> ExecMode = + getTargetKernelExecMode(*OuterFn); + Value *WrapperFn; + if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC) + WrapperFn = createTargetParallelWrapper(OMPIRBuilder, OutlinedFn); + else + WrapperFn = Constant::getNullValue(PtrTy); + // Build kmpc_parallel_51 call Value *Parallel51CallArgs[] = { /* identifier*/ Ident, @@ -1384,7 +1472,7 @@ static void targetParallelCallback( /* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1), /* Proc bind */ Builder.getInt32(-1), /* outlined function */ &OutlinedFn, - /* wrapper function */ NullPtrValue, + /* wrapper function */ WrapperFn, /* arguments of the outlined funciton*/ Args, /* number of arguments */ Builder.getInt64(NumCapturedVars)}; diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir index 504e39c96f008..ca998b4672ba0 100644 --- a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir @@ -69,7 +69,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // CHECK: store ptr %[[TMP6]], ptr %[[GEP_]], align 8 // CHECK: %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0 // CHECK: store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8 -// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1) +// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr @[[FUNC1_WRAPPER:.*]], ptr %[[TMP2]], i64 1) // CHECK: call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8) // CHECK: call void @__kmpc_target_deinit() @@ -84,7 +84,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // CHECK: call void @__kmpc_parallel_51(ptr addrspacecast ( // CHECK-SAME: ptr addrspace(1) @[[NUM_THREADS_GLOB:[0-9]+]] to ptr), // CHECK-SAME: i32 [[NUM_THREADS_TMP0:%.*]], i32 1, i32 156, -// CHECK-SAME: i32 -1, ptr [[FUNC_NUM_THREADS1:@.*]], ptr null, ptr [[NUM_THREADS_TMP1:%.*]], i64 1) +// CHECK-SAME: i32 -1, ptr @[[FUNC_NUM_THREADS1:.*]], ptr @[[FUNC2_WRAPPER:.*]], ptr [[NUM_THREADS_TMP1:%.*]], i64 1) // One of the arguments of kmpc_parallel_51 function is responsible for handling if clause // of omp parallel construct for target region. If this argument is nonzero, @@ -105,4 +105,23 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // CHECK: call void @__kmpc_parallel_51(ptr addrspacecast ( // CHECK-SAME: ptr addrspace(1) {{.*}} to ptr), // CHECK-SAME: i32 {{.*}}, i32 %[[IFCOND_TMP4]], i32 -1, -// CHECK-SAME: i32 -1, ptr {{.*}}, ptr null, ptr {{.*}}, i64 1) +// CHECK-SAME: i32 -1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i64 1) + +// CHECK: define internal void @[[FUNC1_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %[[ADDR:.*]]) +// CHECK: %[[ADDR_ALLOCA:.*]] = alloca i32, align 4, addrspace(5) +// CHECK: %[[ADDR_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ADDR_ALLOCA]] to ptr +// CHECK: %[[ZERO_ALLOCA:.*]] = alloca i32, align 4, addrspace(5) +// CHECK: %[[ZERO_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ZERO_ALLOCA]] to ptr +// CHECK: %[[ARGS_ALLOCA:.*]] = alloca ptr, align 8, addrspace(5) +// CHECK: %[[ARGS_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ARGS_ALLOCA]] to ptr +// CHECK: store i32 %[[ADDR]], ptr %[[ADDR_ASCAST]] +// CHECK: store i32 0, ptr %[[ZERO_ASCAST]] +// CHECK: call void @__kmpc_get_shared_variables(ptr %[[ARGS_ASCAST]]) +// CHECK: %[[LOAD_ARGS:.*]] = load ptr, ptr %[[ARGS_ASCAST]], align 8 +// CHECK: %[[FIRST_ARG:.*]] = getelementptr inbounds ptr, ptr %[[LOAD_ARGS]], i64 0 +// CHECK: %[[STRUCTARG:.*]] = load ptr, ptr %[[FIRST_ARG]], align 8 +// CHECK: call void @[[FUNC1]](ptr %[[ADDR_ASCAST]], ptr %[[ZERO_ASCAST]], ptr %[[STRUCTARG]]) + +// CHECK: define internal void @[[FUNC2_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %{{.*}}) +// CHECK-NOT: define +// CHECK: call void @[[FUNC_NUM_THREADS1]]({{.*}}) _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits