================ @@ -3688,27 +3708,95 @@ static Function *getFreshReductionFunc(Module &M) { ".omp.reduction.func", &M); } -OpenMPIRBuilder::InsertPointOrErrorTy -OpenMPIRBuilder::createReductions(const LocationDescription &Loc, - InsertPointTy AllocaIP, - ArrayRef<ReductionInfo> ReductionInfos, - ArrayRef<bool> IsByRef, bool IsNoWait) { - assert(ReductionInfos.size() == IsByRef.size()); - for (const ReductionInfo &RI : ReductionInfos) { - (void)RI; - assert(RI.Variable && "expected non-null variable"); - assert(RI.PrivateVariable && "expected non-null private variable"); - assert(RI.ReductionGen && "expected non-null reduction generator callback"); - assert(RI.Variable->getType() == RI.PrivateVariable->getType() && - "expected variables and their private equivalents to have the same " - "type"); - assert(RI.Variable->getType()->isPointerTy() && - "expected variables to be pointers"); +static Error populateReductionFunction( + Function *ReductionFunc, + ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos, + IRBuilder<> &Builder, ArrayRef<bool> IsByRef, bool IsGPU) { + Module *Module = ReductionFunc->getParent(); + BasicBlock *ReductionFuncBlock = + BasicBlock::Create(Module->getContext(), "", ReductionFunc); + Builder.SetInsertPoint(ReductionFuncBlock); + Value *LHSArrayPtr = nullptr; + Value *RHSArrayPtr = nullptr; + if (IsGPU) { + // Need to alloca memory here and deal with the pointers before getting + // LHS/RHS pointers out + // + Argument *Arg0 = ReductionFunc->getArg(0); + Argument *Arg1 = ReductionFunc->getArg(1); + Type *Arg0Type = Arg0->getType(); + Type *Arg1Type = Arg1->getType(); + + Value *LHSAlloca = + Builder.CreateAlloca(Arg0Type, nullptr, Arg0->getName() + ".addr"); + Value *RHSAlloca = + Builder.CreateAlloca(Arg1Type, nullptr, Arg1->getName() + ".addr"); + Value *LHSAddrCast = + Builder.CreatePointerBitCastOrAddrSpaceCast(LHSAlloca, Arg0Type); + Value *RHSAddrCast = + Builder.CreatePointerBitCastOrAddrSpaceCast(RHSAlloca, Arg1Type); + 
Builder.CreateStore(Arg0, LHSAddrCast); + Builder.CreateStore(Arg1, RHSAddrCast); + LHSArrayPtr = Builder.CreateLoad(Arg0Type, LHSAddrCast); + RHSArrayPtr = Builder.CreateLoad(Arg1Type, RHSAddrCast); + } else { + LHSArrayPtr = ReductionFunc->getArg(0); + RHSArrayPtr = ReductionFunc->getArg(1); } + unsigned NumReductions = ReductionInfos.size(); + Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions); + + for (auto En : enumerate(ReductionInfos)) { + const OpenMPIRBuilder::ReductionInfo &RI = En.value(); + Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64( + RedArrayTy, LHSArrayPtr, 0, En.index()); + Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr); + Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast( + LHSI8Ptr, RI.Variable->getType()); + Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr); + Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64( + RedArrayTy, RHSArrayPtr, 0, En.index()); + Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr); + Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast( + RHSI8Ptr, RI.PrivateVariable->getType()); + Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr); + Value *Reduced; + OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = + RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced); + if (!AfterIP) + return AfterIP.takeError(); + + Builder.restoreIP(*AfterIP); + // TODO: Consider flagging an error. + if (!Builder.GetInsertBlock()) + return Error::success(); ---------------- jsjodin wrote:
I'm going to skip this for now. I have some thoughts on how to improve the insertion point (IP) handling, but we can discuss that separately. https://github.com/llvm/llvm-project/pull/133310 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits