================
@@ -3688,27 +3708,95 @@ static Function *getFreshReductionFunc(Module &M) {
                           ".omp.reduction.func", &M);
 }
 
-OpenMPIRBuilder::InsertPointOrErrorTy
-OpenMPIRBuilder::createReductions(const LocationDescription &Loc,
-                                  InsertPointTy AllocaIP,
-                                  ArrayRef<ReductionInfo> ReductionInfos,
-                                  ArrayRef<bool> IsByRef, bool IsNoWait) {
-  assert(ReductionInfos.size() == IsByRef.size());
-  for (const ReductionInfo &RI : ReductionInfos) {
-    (void)RI;
-    assert(RI.Variable && "expected non-null variable");
-    assert(RI.PrivateVariable && "expected non-null private variable");
-    assert(RI.ReductionGen && "expected non-null reduction generator callback");
-    assert(RI.Variable->getType() == RI.PrivateVariable->getType() &&
-           "expected variables and their private equivalents to have the same "
-           "type");
-    assert(RI.Variable->getType()->isPointerTy() &&
-           "expected variables to be pointers");
+static Error populateReductionFunction(
+    Function *ReductionFunc,
+    ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
+    IRBuilder<> &Builder, ArrayRef<bool> IsByRef, bool IsGPU) {
+  Module *Module = ReductionFunc->getParent();
+  BasicBlock *ReductionFuncBlock =
+      BasicBlock::Create(Module->getContext(), "", ReductionFunc);
+  Builder.SetInsertPoint(ReductionFuncBlock);
+  Value *LHSArrayPtr = nullptr;
+  Value *RHSArrayPtr = nullptr;
+  if (IsGPU) {
+    // Need to alloca memory here and deal with the pointers before getting
+    // LHS/RHS pointers out
+    //
+    Argument *Arg0 = ReductionFunc->getArg(0);
+    Argument *Arg1 = ReductionFunc->getArg(1);
+    Type *Arg0Type = Arg0->getType();
+    Type *Arg1Type = Arg1->getType();
+
+    Value *LHSAlloca =
+        Builder.CreateAlloca(Arg0Type, nullptr, Arg0->getName() + ".addr");
+    Value *RHSAlloca =
+        Builder.CreateAlloca(Arg1Type, nullptr, Arg1->getName() + ".addr");
+    Value *LHSAddrCast =
+        Builder.CreatePointerBitCastOrAddrSpaceCast(LHSAlloca, Arg0Type);
+    Value *RHSAddrCast =
+        Builder.CreatePointerBitCastOrAddrSpaceCast(RHSAlloca, Arg1Type);
+    Builder.CreateStore(Arg0, LHSAddrCast);
+    Builder.CreateStore(Arg1, RHSAddrCast);
+    LHSArrayPtr = Builder.CreateLoad(Arg0Type, LHSAddrCast);
+    RHSArrayPtr = Builder.CreateLoad(Arg1Type, RHSAddrCast);
+  } else {
+    LHSArrayPtr = ReductionFunc->getArg(0);
+    RHSArrayPtr = ReductionFunc->getArg(1);
   }
 
+  unsigned NumReductions = ReductionInfos.size();
+  Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions);
+
+  for (auto En : enumerate(ReductionInfos)) {
+    const OpenMPIRBuilder::ReductionInfo &RI = En.value();
+    Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
+        RedArrayTy, LHSArrayPtr, 0, En.index());
+    Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
+    Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        LHSI8Ptr, RI.Variable->getType());
+    Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
+    Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
+        RedArrayTy, RHSArrayPtr, 0, En.index());
+    Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
+    Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        RHSI8Ptr, RI.PrivateVariable->getType());
+    Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
+    Value *Reduced;
+    OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
+        RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
+    if (!AfterIP)
+      return AfterIP.takeError();
+
+    Builder.restoreIP(*AfterIP);
+    // TODO: Consider flagging an error.
+    if (!Builder.GetInsertBlock())
+      return Error::success();
----------------
jsjodin wrote:

I'm going to skip this for now. I have some thoughts on how to improve the IP 
handling, but we can discuss that separately.

https://github.com/llvm/llvm-project/pull/133310
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to