================ @@ -5212,6 +5273,78 @@ static Function *createOutlinedFunction( return Func; } +// Create an entry point for a target task with the following. +// It'll have the following signature +// void @.omp_target_task_proxy_func(i32 %thread.id, ptr %task) +// This function is called from emitTargetTask once the +// code to launch the target kernel has been outlined already. +static Function *emitProxyTaskFunction(OpenMPIRBuilder &OMPBuilder, + IRBuilderBase &Builder, + CallInst *StaleCI) { + Module &M = OMPBuilder.M; + // CalledFunction is the target launch function, i.e. + // the function that sets up kernel arguments and calls + // __tgt_target_kernel to launch the kernel on the device. + Function *CalledFunction = StaleCI->getCalledFunction(); + OpenMPIRBuilder::InsertPointTy IP(StaleCI->getParent(), + StaleCI->getIterator()); + LLVMContext &Ctx = StaleCI->getParent()->getContext(); + Type *ThreadIDTy = Type::getInt32Ty(Ctx); + Type *TaskPtrTy = OMPBuilder.TaskPtr; + Type *TaskTy = OMPBuilder.Task; + auto ProxyFnTy = + FunctionType::get(Builder.getVoidTy(), {ThreadIDTy, TaskPtrTy}, + /* isVarArg */ false); + auto ProxyFn = Function::Create(ProxyFnTy, GlobalValue::InternalLinkage, + ".omp_target_task_proxy_func", + Builder.GetInsertBlock()->getModule()); + + BasicBlock *EntryBB = + BasicBlock::Create(Builder.getContext(), "entry", ProxyFn); + Builder.SetInsertPoint(EntryBB); + + bool HasShareds = StaleCI->arg_size() > 1; + // TODO: This is a temporary assert to prove to ourselves that + // the outlined target launch function is always going to have + // atmost two arguments if there is any data shared between + // host and device. 
+ assert((!HasShareds || (StaleCI->arg_size() == 2)) && + "StaleCI with shareds should have exactly two arguments."); + if (HasShareds) { + AllocaInst *ArgStructAlloca = + dyn_cast<AllocaInst>(StaleCI->getArgOperand(1)); + assert(ArgStructAlloca && + "Unable to find the alloca instruction corresponding to arguments " + "for extracted function"); + StructType *ArgStructType = + dyn_cast<StructType>(ArgStructAlloca->getAllocatedType()); + LLVM_DEBUG(dbgs() << "ArgStructType = " << *ArgStructType << "\n"); + + AllocaInst *NewArgStructAlloca = + Builder.CreateAlloca(ArgStructType, nullptr, "structArg"); + Value *TaskT = ProxyFn->getArg(1); + Value *ThreadId = ProxyFn->getArg(0); + LLVM_DEBUG(dbgs() << "TaskT = " << *TaskT << "\n"); + Value *SharedsSize = + Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType)); + + Value *Shareds = Builder.CreateStructGEP(TaskTy, TaskT, 0); + LoadInst *LoadShared = + Builder.CreateLoad(PointerType::getUnqual(Ctx), Shareds); + + // TODO: Are these alignment values correct? + Builder.CreateMemCpy( + NewArgStructAlloca, + NewArgStructAlloca->getPointerAlignment(M.getDataLayout()), LoadShared, + LoadShared->getPointerAlignment(M.getDataLayout()), SharedsSize); + + Builder.CreateCall(CalledFunction, {ThreadId, NewArgStructAlloca}); + } + ProxyFn->getArg(0)->setName("thread.id"); + ProxyFn->getArg(1)->setName("task"); ---------------- ergawy wrote:
nit: could these `setName` calls be moved closer to where `ProxyFn` is created? https://github.com/llvm/llvm-project/pull/93977 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits