ggeorgakoudis created this revision.
Herald added subscribers: guansong, yaxunl.
ggeorgakoudis requested review of this revision.
Herald added a reviewer: jdoerfert.
Herald added subscribers: cfe-commits, sstefan1.
Herald added a project: clang.

Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D102107

Files:
  clang/lib/CodeGen/CGOpenMPRuntime.cpp
  clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
  clang/lib/CodeGen/CGStmtOpenMP.cpp
  clang/lib/CodeGen/CodeGenFunction.h

Index: clang/lib/CodeGen/CodeGenFunction.h
===================================================================
--- clang/lib/CodeGen/CodeGenFunction.h
+++ clang/lib/CodeGen/CodeGenFunction.h
@@ -3276,8 +3276,13 @@
   llvm::Function *EmitCapturedStmt(const CapturedStmt &S, CapturedRegionKind K);
   llvm::Function *GenerateCapturedStmtFunction(const CapturedStmt &S);
   Address GenerateCapturedStmtArgument(const CapturedStmt &S);
+  llvm::Function *
+  GenerateOpenMPCapturedStmtFunctionAggregate(const CapturedStmt &S,
+                                              SourceLocation Loc);
   llvm::Function *GenerateOpenMPCapturedStmtFunction(const CapturedStmt &S,
                                                      SourceLocation Loc);
+  void GenerateOpenMPCapturedVarsAggregate(
+      const CapturedStmt &S, SmallVectorImpl<llvm::Value *> &CapturedVars);
   void GenerateOpenMPCapturedVars(const CapturedStmt &S,
                                   SmallVectorImpl<llvm::Value *> &CapturedVars);
   void emitOMPSimpleStore(LValue LVal, RValue RVal, QualType RValTy,
Index: clang/lib/CodeGen/CGStmtOpenMP.cpp
===================================================================
--- clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -318,6 +318,32 @@
   return CGM.getSize(SizeInChars);
 }
 
+void CodeGenFunction::GenerateOpenMPCapturedVarsAggregate(
+    const CapturedStmt &S, SmallVectorImpl<llvm::Value *> &CapturedVars) {
+  const RecordDecl *RD = S.getCapturedRecordDecl();
+  QualType RecordTy = getContext().getRecordType(RD);
+  // Pack every capture of S into one temporary of the captured record type.
+  LValue AggLV = MakeAddrLValue(
+      CreateMemTemp(RecordTy, "omp.outlined.arg.agg."), RecordTy);
+
+  // Walk the capture initializers and the record fields in lockstep.
+  auto CurField = RD->field_begin();
+  for (CapturedStmt::const_capture_init_iterator I = S.capture_init_begin(),
+                                                 E = S.capture_init_end();
+       I != E; ++I, ++CurField) {
+    LValue LV = EmitLValueForFieldInitialization(AggLV, *CurField);
+    // Fields with a captured VLA type go through EmitLambdaVLACapture.
+    if (CurField->hasCapturedVLAType()) {
+      EmitLambdaVLACapture(CurField->getCapturedVLAType(), LV);
+    } else
+      // capturesThis, capturesVariableByCopy and capturesVariable cases:
+      // initialize the field from the capture's init expression *I.
+      EmitInitializerForField(*CurField, LV, *I);
+  }
+
+  CapturedVars.push_back(AggLV.getPointer(*this)); // aggregate address only
+}
+
 void CodeGenFunction::GenerateOpenMPCapturedVars(
     const CapturedStmt &S, SmallVectorImpl<llvm::Value *> &CapturedVars) {
   const RecordDecl *RD = S.getCapturedRecordDecl();
@@ -418,6 +444,101 @@
 };
 } // namespace
 
+static llvm::Function *emitOutlinedFunctionPrologueAggregate(
+    CodeGenFunction &CGF, FunctionArgList &Args,
+    llvm::MapVector<const Decl *, std::pair<const VarDecl *, Address>>
+        &LocalAddrs,
+    llvm::DenseMap<const Decl *, std::pair<const Expr *, llvm::Value *>>
+        &VLASizes,
+    llvm::Value *&CXXThisValue, const CapturedStmt &CS, SourceLocation Loc,
+    StringRef FunctionName) {
+  const CapturedDecl *CD = CS.getCapturedDecl();
+  const RecordDecl *RD = CS.getCapturedRecordDecl();
+  assert(CD->hasBody() && "missing CapturedDecl body");
+
+  CXXThisValue = nullptr;
+  // The only parameters are those declared on the CapturedDecl.
+  CodeGenModule &CGM = CGF.CGM;
+  ASTContext &Ctx = CGM.getContext();
+  Args.append(CD->param_begin(), CD->param_end());
+
+  // Declare an internal-linkage function returning void, named FunctionName.
+  const CGFunctionInfo &FuncInfo =
+      CGM.getTypes().arrangeBuiltinFunctionDeclaration(Ctx.VoidTy, Args);
+  llvm::FunctionType *FuncLLVMTy = CGM.getTypes().GetFunctionType(FuncInfo);
+
+  auto *F =
+      llvm::Function::Create(FuncLLVMTy, llvm::GlobalValue::InternalLinkage,
+                             FunctionName, &CGM.getModule());
+  CGM.SetInternalFunctionAttributes(CD, F, FuncInfo);
+  if (CD->isNothrow())
+    F->setDoesNotThrow();
+  F->setDoesNotRecurse();
+
+  // Start the body and load the context parameter (pointer to the record).
+  CGF.StartFunction(CD, Ctx.VoidTy, F, FuncInfo, Args, Loc, Loc);
+  Address ContextAddr = CGF.GetAddrOfLocalVar(CD->getContextParam());
+  llvm::Value *ContextV = CGF.Builder.CreateLoad(ContextAddr);
+  LValue ContextLV = CGF.MakeNaturalAlignAddrLValue(
+      ContextV, CGM.getContext().getTagDeclType(RD));
+  auto I = CS.captures().begin();
+  for (const FieldDecl *FD : RD->fields()) {
+    LValue FieldLV = CGF.EmitLValueForFieldInitialization(ContextLV, FD);
+    // Address of this capture's field inside the context record.
+    Address LocalAddr = FieldLV.getAddress(CGF);
+    // A pointer captured by copy needs no local copy: map the variable
+    // directly to its field in the context record.
+    if (I->capturesVariableByCopy() && FD->getType()->isAnyPointerType()) {
+      const VarDecl *CurVD = I->getCapturedVar();
+      LocalAddrs.insert({FD, {CurVD, LocalAddr}});
+      ++I;
+      continue;
+    }
+
+    LValue ArgLVal =
+        CGF.MakeAddrLValue(LocalAddr, FD->getType(), AlignmentSource::Decl);
+    if (FD->hasCapturedVLAType()) {
+      llvm::Value *ExprArg = CGF.EmitLoadOfScalar(ArgLVal, I->getLocation());
+      const VariableArrayType *VAT = FD->getCapturedVLAType();
+      VLASizes.try_emplace(FD, VAT->getSizeExpr(), ExprArg);
+    } else if (I->capturesVariable()) {
+      const VarDecl *Var = I->getCapturedVar();
+      QualType VarTy = Var->getType();
+      Address ArgAddr = ArgLVal.getAddress(CGF);
+      if (ArgLVal.getType()->isLValueReferenceType()) {
+        ArgAddr = CGF.EmitLoadOfReference(ArgLVal);
+      } else if (!VarTy->isVariablyModifiedType() || !VarTy->isPointerType()) {
+        assert(ArgLVal.getType()->isPointerType());
+        ArgAddr = CGF.EmitLoadOfPointer(
+            ArgAddr, ArgLVal.getType()->castAs<PointerType>());
+      }
+      LocalAddrs.insert(
+          {FD, {Var, Address(ArgAddr.getPointer(), Ctx.getDeclAlign(Var))}});
+    } else if (I->capturesVariableByCopy()) {
+      assert(!FD->getType()->isAnyPointerType() &&
+             "Not expecting a captured pointer.");
+      const VarDecl *Var = I->getCapturedVar();
+      Address CopyAddr = CGF.CreateMemTemp(FD->getType(), Ctx.getDeclAlign(FD),
+                                           Var->getName());
+      LValue CopyLVal =
+          CGF.MakeAddrLValue(CopyAddr, FD->getType(), AlignmentSource::Decl);
+
+      RValue ArgRVal = CGF.EmitLoadOfLValue(ArgLVal, I->getLocation());
+      CGF.EmitStoreThroughLValue(ArgRVal, CopyLVal);
+
+      LocalAddrs.insert({FD, {Var, CopyAddr}});
+    } else {
+      // Remaining case is a captured 'this': load it into CXXThisValue.
+      assert(I->capturesThis());
+      CXXThisValue = CGF.EmitLoadOfScalar(ArgLVal, I->getLocation());
+      LocalAddrs.insert({FD, {nullptr, ArgLVal.getAddress(CGF)}});
+    }
+    ++I;
+  }
+
+  return F;
+}
+
 static llvm::Function *emitOutlinedFunctionPrologue(
     CodeGenFunction &CGF, FunctionArgList &Args,
     llvm::MapVector<const Decl *, std::pair<const VarDecl *, Address>>
@@ -593,6 +714,37 @@
   return F;
 }
 
+llvm::Function *CodeGenFunction::GenerateOpenMPCapturedStmtFunctionAggregate(
+    const CapturedStmt &S, SourceLocation Loc) {
+  assert(
+      CapturedStmtInfo &&
+      "CapturedStmtInfo should be set when generating the captured function");
+  const CapturedDecl *CD = S.getCapturedDecl();
+  // The prologue fills Args, LocalAddrs, VLASizes and CXXThisValue.
+  FunctionArgList Args;
+  llvm::MapVector<const Decl *, std::pair<const VarDecl *, Address>> LocalAddrs;
+  llvm::DenseMap<const Decl *, std::pair<const Expr *, llvm::Value *>> VLASizes;
+  StringRef FunctionName = CapturedStmtInfo->getHelperName();
+  llvm::Function *F = emitOutlinedFunctionPrologueAggregate(
+      *this, Args, LocalAddrs, VLASizes, CXXThisValue, S, Loc, FunctionName);
+  CodeGenFunction::OMPPrivateScope LocalScope(*this);
+  for (const auto &LocalAddrPair : LocalAddrs) {
+    if (LocalAddrPair.second.first) { // entries without a VarDecl are skipped
+      LocalScope.addPrivate(LocalAddrPair.second.first, [&LocalAddrPair]() {
+        return LocalAddrPair.second.second;
+      });
+    }
+  }
+  (void)LocalScope.Privatize();
+  for (const auto &VLASizePair : VLASizes)
+    VLASizeMap[VLASizePair.second.first] = VLASizePair.second.second;
+  PGO.assignRegionCounters(GlobalDecl(CD), F);
+  CapturedStmtInfo->EmitBody(*this, CD->getBody());
+  (void)LocalScope.ForceCleanup();
+  FinishFunction(CD->getBodyRBrace());
+  return F;
+}
+
 llvm::Function *
 CodeGenFunction::GenerateOpenMPCapturedStmtFunction(const CapturedStmt &S,
                                                     SourceLocation Loc) {
@@ -1581,7 +1733,7 @@
   // The following lambda takes care of appending the lower and upper bound
   // parameters when necessary
   CodeGenBoundParameters(CGF, S, CapturedVars);
-  CGF.GenerateOpenMPCapturedVars(*CS, CapturedVars);
+  CGF.GenerateOpenMPCapturedVarsAggregate(*CS, CapturedVars);
   CGF.CGM.getOpenMPRuntime().emitParallelCall(CGF, S.getBeginLoc(), OutlinedFn,
                                               CapturedVars, IfCond);
 }
@@ -5298,7 +5450,7 @@
     const CapturedStmt *CS = S.getInnermostCapturedStmt();
     if (C) {
       llvm::SmallVector<llvm::Value *, 16> CapturedVars;
-      CGF.GenerateOpenMPCapturedVars(*CS, CapturedVars);
+      CGF.GenerateOpenMPCapturedVarsAggregate(*CS, CapturedVars);
       llvm::Function *OutlinedFn =
           emitOutlinedOrderedFunction(CGM, CS, S.getBeginLoc());
       CGM.getOpenMPRuntime().emitOutlinedFunctionCall(CGF, S.getBeginLoc(),
@@ -6029,7 +6181,7 @@
 
   OMPTeamsScope Scope(CGF, S);
   llvm::SmallVector<llvm::Value *, 16> CapturedVars;
-  CGF.GenerateOpenMPCapturedVars(*CS, CapturedVars);
+  CGF.GenerateOpenMPCapturedVarsAggregate(*CS, CapturedVars);
   CGF.CGM.getOpenMPRuntime().emitTeamsCall(CGF, S, S.getBeginLoc(), OutlinedFn,
                                            CapturedVars);
 }
Index: clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
===================================================================
--- clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
+++ clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
@@ -2106,12 +2106,14 @@
     // TODO: Is that needed?
     CodeGenFunction::OMPPrivateScope PrivateArgScope(CGF);
 
+    // Store addresses of global arguments to pass to the parallel call.
     Address CapturedVarsAddrs = CGF.CreateDefaultAlignTempAlloca(
         llvm::ArrayType::get(CGM.VoidPtrTy, CapturedVars.size()),
         "captured_vars_addrs");
-    // There's something to share.
+
+    // Store globalized values to push, pop through the global stack.
+    SmallVector<llvm::Value *, 4> GlobalValues;
     if (!CapturedVars.empty()) {
-      // Prepare for parallel region. Indicate the outlined function.
       ASTContext &Ctx = CGF.getContext();
       unsigned Idx = 0;
       for (llvm::Value *V : CapturedVars) {
@@ -2119,8 +2121,38 @@
         llvm::Value *PtrV;
         if (V->getType()->isIntegerTy())
           PtrV = Bld.CreateIntToPtr(V, CGF.VoidPtrTy);
-        else
-          PtrV = Bld.CreatePointerBitCastOrAddrSpaceCast(V, CGF.VoidPtrTy);
+        else {
+          assert(V->getType()->isPointerTy() &&
+                 "Expected Pointer Type to globalize.");
+          // Globalize and store pointer.
+          llvm::Type *PtrElemTy = V->getType()->getPointerElementType();
+          auto &DL = CGM.getDataLayout();
+          unsigned GlobalSize = DL.getTypeAllocSize(PtrElemTy);
+
+          // Use shared memory to store globalized pointer values, for now this
+          // should be the outlined args aggregate struct.
+          llvm::Value *GlobalSizeArg[] = {
+              llvm::ConstantInt::get(CGM.SizeTy, GlobalSize),
+              CGF.Builder.getInt16(/*UseSharedMemory*/ 1)};
+          llvm::Value *GlobalValue = CGF.EmitRuntimeCall(
+              OMPBuilder.getOrCreateRuntimeFunction(
+                  CGM.getModule(),
+                  IsInTTDRegion
+                      ? OMPRTL___kmpc_data_sharing_push_stack
+                      : OMPRTL___kmpc_data_sharing_coalesced_push_stack),
+              GlobalSizeArg);
+          GlobalValues.push_back(GlobalValue);
+
+          llvm::Value *CapturedVarVal = Bld.CreateAlignedLoad(
+              PtrElemTy, V, DL.getABITypeAlign(PtrElemTy));
+          llvm::Value *GlobalValueCast =
+              Bld.CreatePointerBitCastOrAddrSpaceCast(
+                  GlobalValue, PtrElemTy->getPointerTo());
+          Bld.CreateDefaultAlignedStore(CapturedVarVal, GlobalValueCast);
+
+          PtrV = Bld.CreatePointerBitCastOrAddrSpaceCast(GlobalValue,
+                                                         CGF.VoidPtrTy);
+        }
         CGF.EmitStoreOfScalar(PtrV, Dst, /*Volatile=*/false,
                               Ctx.getPointerType(Ctx.VoidPtrTy));
         ++Idx;
@@ -2133,8 +2165,9 @@
                                     /* isSigned */ false);
     else
       IfCondVal = llvm::ConstantInt::get(CGF.Int32Ty, 1);
-
     assert(IfCondVal && "Expected a value");
+
+    // Create the parallel call.
     llvm::Value *RTLoc = emitUpdateLocation(CGF, Loc);
     llvm::Value *Args[] = {
         RTLoc,
@@ -2150,6 +2183,14 @@
     CGF.EmitRuntimeCall(OMPBuilder.getOrCreateRuntimeFunction(
                             CGM.getModule(), OMPRTL___kmpc_parallel_51),
                         Args);
+
+    // Pop any globalized values from the global stack.
+    for (auto *V : GlobalValues) {
+      CGF.EmitRuntimeCall(
+          OMPBuilder.getOrCreateRuntimeFunction(
+              CGM.getModule(), OMPRTL___kmpc_data_sharing_pop_stack),
+          V);
+    }
   };
 
   RegionCodeGenTy RCG(ParallelGen);
@@ -4062,7 +4103,6 @@
                     D.getBeginLoc(), D.getBeginLoc());
 
   const auto *RD = CS.getCapturedRecordDecl();
-  auto CurField = RD->field_begin();
 
   Address ZeroAddr = CGF.CreateDefaultAlignTempAlloca(CGF.Int32Ty,
                                                       /*Name=*/".zero.addr");
@@ -4074,7 +4114,6 @@
   Args.emplace_back(ZeroAddr.getPointer());
 
   CGBuilderTy &Bld = CGF.Builder;
-  auto CI = CS.capture_begin();
 
   // Use global memory for data sharing.
   // Handle passing of global args to workers.
@@ -4121,23 +4160,27 @@
     ++Idx;
   }
   if (CS.capture_size() > 0) {
+    auto CI = CS.capture_begin();
+    // Load the outlined arg aggregate struct.
     ASTContext &CGFContext = CGF.getContext();
-    for (unsigned I = 0, E = CS.capture_size(); I < E; ++I, ++CI, ++CurField) {
-      QualType ElemTy = CurField->getType();
-      Address Src = Bld.CreateConstInBoundsGEP(SharedArgListAddress, I + Idx);
-      Address TypedAddress = Bld.CreatePointerBitCastOrAddrSpaceCast(
-          Src, CGF.ConvertTypeForMem(CGFContext.getPointerType(ElemTy)));
-      llvm::Value *Arg = CGF.EmitLoadOfScalar(TypedAddress,
-                                              /*Volatile=*/false,
-                                              CGFContext.getPointerType(ElemTy),
-                                              CI->getLocation());
-      if (CI->capturesVariableByCopy() &&
-          !CI->getCapturedVar()->getType()->isAnyPointerType()) {
-        Arg = castValueToType(CGF, Arg, ElemTy, CGFContext.getUIntPtrType(),
-                              CI->getLocation());
-      }
-      Args.emplace_back(Arg);
-    }
+    QualType RecordPointerTy =
+        CGFContext.getPointerType(CGFContext.getRecordType(RD));
+    Address Src = Bld.CreateConstInBoundsGEP(SharedArgListAddress, Idx);
+    Address TypedAddress = Bld.CreatePointerBitCastOrAddrSpaceCast(
+        Src, CGF.ConvertTypeForMem(CGFContext.getPointerType(RecordPointerTy)));
+    llvm::Value *Arg = CGF.EmitLoadOfScalar(
+        TypedAddress,
+        /*Volatile=*/false, CGFContext.getPointerType(RecordPointerTy),
+        CI->getLocation());
+    Args.emplace_back(Arg);
+  } else {
+    // If there are no captured arguments, use nullptr.
+    ASTContext &CGFContext = CGF.getContext();
+    QualType RecordPointerTy =
+        CGFContext.getPointerType(CGFContext.getRecordType(RD));
+    llvm::Value *Arg =
+        llvm::Constant::getNullValue(CGF.ConvertTypeForMem(RecordPointerTy));
+    Args.emplace_back(Arg);
   }
 
   emitOutlinedFunctionCall(CGF, D.getBeginLoc(), OutlinedParallelFn, Args);
Index: clang/lib/CodeGen/CGOpenMPRuntime.cpp
===================================================================
--- clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -1280,7 +1280,7 @@
   CGOpenMPOutlinedRegionInfo CGInfo(*CS, ThreadIDVar, CodeGen, InnermostKind,
                                     HasCancel, OutlinedHelperName);
   CodeGenFunction::CGCapturedStmtRAII CapInfoRAII(CGF, &CGInfo);
-  return CGF.GenerateOpenMPCapturedStmtFunction(*CS, D.getBeginLoc());
+  return CGF.GenerateOpenMPCapturedStmtFunctionAggregate(*CS, D.getBeginLoc());
 }
 
 llvm::Function *CGOpenMPRuntime::emitParallelOutlinedFunction(
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to