================
@@ -3688,27 +3708,95 @@ static Function *getFreshReductionFunc(Module &M) {
                           ".omp.reduction.func", &M);
 }
 
-OpenMPIRBuilder::InsertPointOrErrorTy
-OpenMPIRBuilder::createReductions(const LocationDescription &Loc,
-                                  InsertPointTy AllocaIP,
-                                  ArrayRef<ReductionInfo> ReductionInfos,
-                                  ArrayRef<bool> IsByRef, bool IsNoWait) {
-  assert(ReductionInfos.size() == IsByRef.size());
-  for (const ReductionInfo &RI : ReductionInfos) {
-    (void)RI;
-    assert(RI.Variable && "expected non-null variable");
-    assert(RI.PrivateVariable && "expected non-null private variable");
-    assert(RI.ReductionGen && "expected non-null reduction generator callback");
-    assert(RI.Variable->getType() == RI.PrivateVariable->getType() &&
-           "expected variables and their private equivalents to have the same "
-           "type");
-    assert(RI.Variable->getType()->isPointerTy() &&
-           "expected variables to be pointers");
+static Error populateReductionFunction(
+    Function *ReductionFunc,
+    ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
+    IRBuilder<> &Builder, ArrayRef<bool> IsByRef, bool IsGPU) {
+  Module *Module = ReductionFunc->getParent();
+  BasicBlock *ReductionFuncBlock =
+      BasicBlock::Create(Module->getContext(), "", ReductionFunc);
+  Builder.SetInsertPoint(ReductionFuncBlock);
+  Value *LHSArrayPtr = nullptr;
+  Value *RHSArrayPtr = nullptr;
+  if (IsGPU) {
+    // Need to alloca memory here and deal with the pointers before getting
+    // LHS/RHS pointers out
+    //
+    Argument *Arg0 = ReductionFunc->getArg(0);
+    Argument *Arg1 = ReductionFunc->getArg(1);
+    Type *Arg0Type = Arg0->getType();
+    Type *Arg1Type = Arg1->getType();
+
+    Value *LHSAlloca =
+        Builder.CreateAlloca(Arg0Type, nullptr, Arg0->getName() + ".addr");
+    Value *RHSAlloca =
+        Builder.CreateAlloca(Arg1Type, nullptr, Arg1->getName() + ".addr");
+    Value *LHSAddrCast =
+        Builder.CreatePointerBitCastOrAddrSpaceCast(LHSAlloca, Arg0Type);
+    Value *RHSAddrCast =
+        Builder.CreatePointerBitCastOrAddrSpaceCast(RHSAlloca, Arg1Type);
+    Builder.CreateStore(Arg0, LHSAddrCast);
+    Builder.CreateStore(Arg1, RHSAddrCast);
+    LHSArrayPtr = Builder.CreateLoad(Arg0Type, LHSAddrCast);
+    RHSArrayPtr = Builder.CreateLoad(Arg1Type, RHSAddrCast);
+  } else {
+    LHSArrayPtr = ReductionFunc->getArg(0);
+    RHSArrayPtr = ReductionFunc->getArg(1);
   }
 
+  unsigned NumReductions = ReductionInfos.size();
+  Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions);
+
+  for (auto En : enumerate(ReductionInfos)) {
+    const OpenMPIRBuilder::ReductionInfo &RI = En.value();
+    Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
+        RedArrayTy, LHSArrayPtr, 0, En.index());
+    Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
+    Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        LHSI8Ptr, RI.Variable->getType());
+    Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
+    Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
+        RedArrayTy, RHSArrayPtr, 0, En.index());
+    Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
+    Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        RHSI8Ptr, RI.PrivateVariable->getType());
+    Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
+    Value *Reduced;
+    OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
+        RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
+    if (!AfterIP)
+      return AfterIP.takeError();
+
+    Builder.restoreIP(*AfterIP);
+    // TODO: Consider flagging an error.
+    if (!Builder.GetInsertBlock())
+      return Error::success();
----------------
skatrak wrote:

If the reduction callback returns an invalid IP, that seems like an
implementation bug rather than an error condition caused by the input that the
user should know about. Maybe we should just assert in this case, since there
are plenty of other places where we assume the returned IP is valid if the
callback didn't return an `llvm::Error`.

Having said that, I can see there are also a few other places where the IP is 
checked, so feel free to leave this if you think it's a better option. At some 
point, we should probably decide on one single approach and apply it 
everywhere, though.

https://github.com/llvm/llvm-project/pull/133310
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to