This revision was landed with ongoing or failed builds.
This revision was automatically updated to reflect the committed changes.
Closed by commit rGac65cc121513: [OpenMP][OpenMPIRBuilder] Migrate kernel launch code and host fallback code… (authored by jsjodin).
Herald added a project: clang.
Herald added a subscriber: cfe-commits.

Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D151035/new/

https://reviews.llvm.org/D151035

Files:
  clang/lib/CodeGen/CGOpenMPRuntime.cpp
  clang/lib/CodeGen/CGOpenMPRuntime.h
  llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
  llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
  llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Index: llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
===================================================================
--- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -331,6 +331,35 @@
   return splitBB(Builder, CreateBranch, Old->getName() + Suffix);
 }
 
+void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
+                                          IRBuilderBase &Builder,
+                                          SmallVector<Value *> &ArgsVector) {
+  Value *Version = Builder.getInt32(OMP_KERNEL_ARG_VERSION);
+  Value *PointerNum = Builder.getInt32(KernelArgs.NumTargetItems);
+  auto Int32Ty = Type::getInt32Ty(Builder.getContext());
+  Value *ZeroArray = Constant::getNullValue(ArrayType::get(Int32Ty, 3));
+  Value *Flags = Builder.getInt64(KernelArgs.HasNoWait);
+
+  Value *NumTeams3D =
+      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumTeams, {0});
+  Value *NumThreads3D =
+      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumThreads, {0});
+
+  ArgsVector = {Version,
+                PointerNum,
+                KernelArgs.RTArgs.BasePointersArray,
+                KernelArgs.RTArgs.PointersArray,
+                KernelArgs.RTArgs.SizesArray,
+                KernelArgs.RTArgs.MapTypesArray,
+                KernelArgs.RTArgs.MapNamesArray,
+                KernelArgs.RTArgs.MappersArray,
+                KernelArgs.NumIterations,
+                Flags,
+                NumTeams3D,
+                NumThreads3D,
+                KernelArgs.DynCGGroupMem};
+}
+
 void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
   LLVMContext &Ctx = Fn.getContext();
   Triple T(M.getTargetTriple());
@@ -880,6 +909,67 @@
   return Builder.saveIP();
 }
 
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
+    const LocationDescription &Loc, Function *OutlinedFn, Value *OutlinedFnID,
+    EmitFallbackCallbackTy emitTargetCallFallbackCB, TargetKernelArgs &Args,
+    Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {
+
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+
+  Builder.restoreIP(Loc.IP);
+  // On top of the arrays that were filled up, the target offloading call
+  // takes as arguments the device id as well as the host pointer. The host
+  // pointer is used by the runtime library to identify the current target
+  // region, so it only has to be unique and not necessarily point to
+  // anything. It could be the pointer to the outlined function that
+  // implements the target region, but we aren't using that so that the
+  // compiler doesn't need to keep that, and could therefore inline the host
+  // function if proven worthwhile during optimization.
+
+  // From this point on, we need to have an ID of the target region defined.
+  assert(OutlinedFnID && "Invalid outlined function ID!");
+  (void)OutlinedFnID;
+
+  // Return value of the runtime offloading call.
+  Value *Return;
+
+  // Arguments for the target kernel.
+  SmallVector<Value *> ArgsVector;
+  getKernelArgsVector(Args, Builder, ArgsVector);
+
+  // The target region is an outlined function launched by the runtime
+  // via calls to __tgt_target_kernel().
+  //
+  // Note that on the host and CPU targets, the runtime implementation of
+  // these calls simply call the outlined function without forking threads.
+  // The outlined functions themselves have runtime calls to
+  // __kmpc_fork_teams() and __kmpc_fork() for this purpose, codegen'd by
+  // the compiler in emitTeamsCall() and emitParallelCall().
+  //
+  // In contrast, on the NVPTX target, the implementation of
+  // __tgt_target_teams() launches a GPU kernel with the requested number
+  // of teams and threads so no additional calls to the runtime are required.
+  // Check the error code and execute the host version if required.
+  Builder.restoreIP(emitTargetKernel(Builder, AllocaIP, Return, RTLoc, DeviceID,
+                                     Args.NumTeams, Args.NumThreads,
+                                     OutlinedFnID, ArgsVector));
+
+  BasicBlock *OffloadFailedBlock =
+      BasicBlock::Create(Builder.getContext(), "omp_offload.failed");
+  BasicBlock *OffloadContBlock =
+      BasicBlock::Create(Builder.getContext(), "omp_offload.cont");
+  Value *Failed = Builder.CreateIsNotNull(Return);
+  Builder.CreateCondBr(Failed, OffloadFailedBlock, OffloadContBlock);
+
+  auto CurFn = Builder.GetInsertBlock()->getParent();
+  emitBlock(OffloadFailedBlock, CurFn);
+  Builder.restoreIP(emitTargetCallFallbackCB(Builder.saveIP()));
+  emitBranch(OffloadContBlock);
+  emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
+  return Builder.saveIP();
+}
+
 void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
                                                omp::Directive CanceledDirective,
                                                FinalizeCallbackTy ExitCB) {
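
[Reviewer note, not part of the patch] For readers skimming getKernelArgsVector() above: the thirteen slots it fills correspond to the kernel-args record consumed by the runtime, in the order sketched below. This is documentation only; the enum and its KA_* names are hypothetical and simply mirror the order of ArgsVector in the hunk above.

  // Index -> meaning of the vector filled by getKernelArgsVector().
  // Illustrative only; not part of the patch.
  enum KernelArgIndex : unsigned {
    KA_Version = 0,     // OMP_KERNEL_ARG_VERSION (2 in this patch)
    KA_NumArgs,         // KernelArgs.NumTargetItems
    KA_BasePtrs,        // RTArgs.BasePointersArray
    KA_Ptrs,            // RTArgs.PointersArray
    KA_Sizes,           // RTArgs.SizesArray
    KA_MapTypes,        // RTArgs.MapTypesArray
    KA_MapNames,        // RTArgs.MapNamesArray
    KA_Mappers,         // RTArgs.MappersArray
    KA_Tripcount,       // KernelArgs.NumIterations
    KA_Flags,           // bit 0 = HasNoWait
    KA_NumTeams3D,      // {NumTeams, 0, 0}
    KA_NumThreads3D,    // {NumThreads, 0, 0}
    KA_DynCGroupMem     // KernelArgs.DynCGGroupMem
  };
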
Index: llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1533,7 +1533,6 @@
 
   /// Container for the arguments used to pass data to the runtime library.
   struct TargetDataRTArgs {
-    explicit TargetDataRTArgs() {}
     /// The array of base pointer passed to the runtime library.
     Value *BasePointersArray = nullptr;
     /// The array of section pointers passed to the runtime library.
@@ -1553,8 +1552,53 @@
     /// The array of original declaration names of mapped pointers sent to the
     /// runtime library for debugging
     Value *MapNamesArray = nullptr;
+
+    explicit TargetDataRTArgs() {}
+    explicit TargetDataRTArgs(Value *BasePointersArray, Value *PointersArray,
+                              Value *SizesArray, Value *MapTypesArray,
+                              Value *MapTypesArrayEnd, Value *MappersArray,
+                              Value *MapNamesArray)
+        : BasePointersArray(BasePointersArray), PointersArray(PointersArray),
+          SizesArray(SizesArray), MapTypesArray(MapTypesArray),
+          MapTypesArrayEnd(MapTypesArrayEnd), MappersArray(MappersArray),
+          MapNamesArray(MapNamesArray) {}
   };
 
+  /// Data structure that contains the needed information to construct the
+  /// kernel args vector.
+  struct TargetKernelArgs {
+    /// Number of arguments passed to the runtime library.
+    unsigned NumTargetItems;
+    /// Arguments passed to the runtime library.
+    TargetDataRTArgs RTArgs;
+    /// The number of iterations.
+    Value *NumIterations;
+    /// The number of teams.
+    Value *NumTeams;
+    /// The number of threads.
+    Value *NumThreads;
+    /// The size of the dynamic shared memory.
+    Value *DynCGGroupMem;
+    /// True if the kernel has 'no wait' clause.
+    bool HasNoWait;
+
+    /// Constructor for TargetKernelArgs
+    TargetKernelArgs(unsigned NumTargetItems, TargetDataRTArgs RTArgs,
+                     Value *NumIterations, Value *NumTeams, Value *NumThreads,
+                     Value *DynCGGroupMem, bool HasNoWait)
+        : NumTargetItems(NumTargetItems), RTArgs(RTArgs),
+          NumIterations(NumIterations), NumTeams(NumTeams),
+          NumThreads(NumThreads), DynCGGroupMem(DynCGGroupMem),
+          HasNoWait(HasNoWait) {}
+  };
+
+  /// Create the kernel args vector used by emitTargetKernel. This function
+  /// creates various constant values that are used in the resulting args
+  /// vector.
+  static void getKernelArgsVector(TargetKernelArgs &KernelArgs,
+                                  IRBuilderBase &Builder,
+                                  SmallVector<Value *> &ArgsVector);
+
   /// Struct that keeps the information that should be kept throughout
   /// a 'target data' region.
   class TargetDataInfo {
@@ -1636,6 +1680,28 @@
     }
   };
 
+  /// Callback function type for functions emitting the host fallback code that
+  /// is executed when the kernel launch fails. It takes an insertion point as
+  /// parameter where the code should be emitted. It returns an insertion point
+  /// that points right after the emitted code.
+  using EmitFallbackCallbackTy = function_ref<InsertPointTy(InsertPointTy)>;
+
+  /// Generate a target region entry call and host fallback call.
+  ///
+  /// \param Loc The location at which the request originated and is fulfilled.
+  /// \param OutlinedFn The outlined kernel function.
+  /// \param OutlinedFnID The outlined function ID.
+  /// \param EmitTargetCallFallbackCB Callback function to generate host
+  ///        fallback code.
+  /// \param Args Data structure holding information about the kernel arguments.
+  /// \param DeviceID Identifier for the device via the 'device' clause.
+  /// \param RTLoc Source location identifier.
+  /// \param AllocaIP The insertion point to be used for alloca instructions.
+  InsertPointTy emitKernelLaunch(
+      const LocationDescription &Loc, Function *OutlinedFn, Value *OutlinedFnID,
+      EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
+      Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP);
+
   /// Emit the arguments to be passed to the runtime library based on the
   /// arrays of base pointers, pointers, sizes, map types, and mappers.  If
   /// ForEndCall, emit map types to be passed for the end of the region instead
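
[Reviewer note, not part of the patch] A minimal sketch of how a frontend is expected to drive the new emitKernelLaunch() entry point, using only the declarations added above. The wrapper function, its parameter names, and the trivial no-op fallback body are hypothetical; the real call site is in the CGOpenMPRuntime.cpp hunk below.

  #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
  using namespace llvm;

  static void sketchKernelLaunch(
      OpenMPIRBuilder &OMPBuilder,
      const OpenMPIRBuilder::LocationDescription &Loc, Function *OutlinedFn,
      Value *OutlinedFnID, OpenMPIRBuilder::TargetDataRTArgs RTArgs,
      Value *NumIterations, Value *NumTeams, Value *NumThreads,
      Value *DynCGGroupMem, Value *DeviceID, Value *RTLoc,
      OpenMPIRBuilder::InsertPointTy AllocaIP, unsigned NumTargetItems,
      bool HasNoWait) {
    // Bundle everything the __tgt_target_kernel() call needs.
    OpenMPIRBuilder::TargetKernelArgs Args(NumTargetItems, RTArgs,
                                           NumIterations, NumTeams, NumThreads,
                                           DynCGGroupMem, HasNoWait);

    // Host fallback, emitted on the omp_offload.failed path; a real frontend
    // would emit a call to the host version of the outlined region here.
    auto FallbackCB = [&](OpenMPIRBuilder::InsertPointTy IP)
        -> OpenMPIRBuilder::InsertPointTy { return IP; };

    OMPBuilder.emitKernelLaunch(Loc, OutlinedFn, OutlinedFnID, FallbackCB,
                                Args, DeviceID, RTLoc, AllocaIP);
  }
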
Index: llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
@@ -72,6 +72,9 @@
 #define OMP_IDENT_FLAG(Enum, ...) constexpr auto Enum = omp::IdentFlag::Enum;
 #include "llvm/Frontend/OpenMP/OMPKinds.def"
 
+// Version of the kernel argument format used by the omp runtime.
+#define OMP_KERNEL_ARG_VERSION 2
+
 /// \note This needs to be kept in sync with kmp.h enum sched_type.
 /// Todo: Update kmp.h to include this file, and remove the enums in kmp.h
 enum class OMPScheduleType {
Index: clang/lib/CodeGen/CGOpenMPRuntime.h
===================================================================
--- clang/lib/CodeGen/CGOpenMPRuntime.h
+++ clang/lib/CodeGen/CGOpenMPRuntime.h
@@ -327,42 +327,6 @@
                                                 bool IsOffloadEntry,
                                                 const RegionCodeGenTy &CodeGen);
 
-  /// Emits object of ident_t type with info for source location.
-  /// \param Flags Flags for OpenMP location.
-  /// \param EmitLoc emit source location with debug-info is off.
-  ///
-  llvm::Value *emitUpdateLocation(CodeGenFunction &CGF, SourceLocation Loc,
-                                  unsigned Flags = 0, bool EmitLoc = false);
-
-  /// Emit the number of teams for a target directive.  Inspect the num_teams
-  /// clause associated with a teams construct combined or closely nested
-  /// with the target directive.
-  ///
-  /// Emit a team of size one for directives such as 'target parallel' that
-  /// have no associated teams construct.
-  ///
-  /// Otherwise, return nullptr.
-  const Expr *getNumTeamsExprForTargetDirective(CodeGenFunction &CGF,
-                                                const OMPExecutableDirective &D,
-                                                int32_t &DefaultVal);
-  llvm::Value *emitNumTeamsForTargetDirective(CodeGenFunction &CGF,
-                                              const OMPExecutableDirective &D);
-  /// Emit the number of threads for a target directive.  Inspect the
-  /// thread_limit clause associated with a teams construct combined or closely
-  /// nested with the target directive.
-  ///
-  /// Emit the num_threads clause for directives such as 'target parallel' that
-  /// have no associated teams construct.
-  ///
-  /// Otherwise, return nullptr.
-  const Expr *
-  getNumThreadsExprForTargetDirective(CodeGenFunction &CGF,
-                                      const OMPExecutableDirective &D,
-                                      int32_t &DefaultVal);
-  llvm::Value *
-  emitNumThreadsForTargetDirective(CodeGenFunction &CGF,
-                                   const OMPExecutableDirective &D);
-
   /// Returns pointer to ident_t type.
   llvm::Type *getIdentTyPointerTy();
 
@@ -652,15 +616,6 @@
                             llvm::Function *TaskFunction, QualType SharedsTy,
                             Address Shareds, const OMPTaskDataTy &Data);
 
-  /// Return the trip count of loops associated with constructs / 'target teams
-  /// distribute' and 'teams distribute parallel for'. \param SizeEmitter Emits
-  /// the int64 value for the number of iterations of the associated loop.
-  llvm::Value *emitTargetNumIterationsCall(
-      CodeGenFunction &CGF, const OMPExecutableDirective &D,
-      llvm::function_ref<llvm::Value *(CodeGenFunction &CGF,
-                                       const OMPLoopDirective &D)>
-          SizeEmitter);
-
   /// Emit update for lastprivate conditional data.
   void emitLastprivateConditionalUpdate(CodeGenFunction &CGF, LValue IVLVal,
                                         StringRef UniqueDeclName, LValue LVal,
@@ -687,6 +642,51 @@
   virtual ~CGOpenMPRuntime() {}
   virtual void clear();
 
+  /// Emits object of ident_t type with info for source location.
+  /// \param Flags Flags for OpenMP location.
+  /// \param EmitLoc emit source location with debug-info is off.
+  ///
+  llvm::Value *emitUpdateLocation(CodeGenFunction &CGF, SourceLocation Loc,
+                                  unsigned Flags = 0, bool EmitLoc = false);
+
+  /// Emit the number of teams for a target directive.  Inspect the num_teams
+  /// clause associated with a teams construct combined or closely nested
+  /// with the target directive.
+  ///
+  /// Emit a team of size one for directives such as 'target parallel' that
+  /// have no associated teams construct.
+  ///
+  /// Otherwise, return nullptr.
+  const Expr *getNumTeamsExprForTargetDirective(CodeGenFunction &CGF,
+                                                const OMPExecutableDirective &D,
+                                                int32_t &DefaultVal);
+  llvm::Value *emitNumTeamsForTargetDirective(CodeGenFunction &CGF,
+                                              const OMPExecutableDirective &D);
+  /// Emit the number of threads for a target directive.  Inspect the
+  /// thread_limit clause associated with a teams construct combined or closely
+  /// nested with the target directive.
+  ///
+  /// Emit the num_threads clause for directives such as 'target parallel' that
+  /// have no associated teams construct.
+  ///
+  /// Otherwise, return nullptr.
+  const Expr *
+  getNumThreadsExprForTargetDirective(CodeGenFunction &CGF,
+                                      const OMPExecutableDirective &D,
+                                      int32_t &DefaultVal);
+  llvm::Value *
+  emitNumThreadsForTargetDirective(CodeGenFunction &CGF,
+                                   const OMPExecutableDirective &D);
+
+  /// Return the trip count of loops associated with constructs / 'target teams
+  /// distribute' and 'teams distribute parallel for'. \param SizeEmitter Emits
+  /// the int64 value for the number of iterations of the associated loop.
+  llvm::Value *emitTargetNumIterationsCall(
+      CodeGenFunction &CGF, const OMPExecutableDirective &D,
+      llvm::function_ref<llvm::Value *(CodeGenFunction &CGF,
+                                       const OMPLoopDirective &D)>
+          SizeEmitter);
+
   /// Returns true if the current target is a GPU.
   virtual bool isTargetCodegen() const { return false; }
 
Index: clang/lib/CodeGen/CGOpenMPRuntime.cpp
===================================================================
--- clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -9573,6 +9573,253 @@
   return llvm::ConstantInt::get(CGF.Int64Ty, 0);
 }
 
+static void
+emitTargetCallFallback(CGOpenMPRuntime *OMPRuntime, llvm::Function *OutlinedFn,
+                       const OMPExecutableDirective &D,
+                       llvm::SmallVectorImpl<llvm::Value *> &CapturedVars,
+                       bool RequiresOuterTask, const CapturedStmt &CS,
+                       bool OffloadingMandatory, CodeGenFunction &CGF) {
+  if (OffloadingMandatory) {
+    CGF.Builder.CreateUnreachable();
+  } else {
+    if (RequiresOuterTask) {
+      CapturedVars.clear();
+      CGF.GenerateOpenMPCapturedVars(CS, CapturedVars);
+    }
+    OMPRuntime->emitOutlinedFunctionCall(CGF, D.getBeginLoc(), OutlinedFn,
+                                         CapturedVars);
+  }
+}
+
+static llvm::Value *emitDeviceID(
+    llvm::PointerIntPair<const Expr *, 2, OpenMPDeviceClauseModifier> Device,
+    CodeGenFunction &CGF) {
+  // Emit device ID if any.
+  llvm::Value *DeviceID;
+  if (Device.getPointer()) {
+    assert((Device.getInt() == OMPC_DEVICE_unknown ||
+            Device.getInt() == OMPC_DEVICE_device_num) &&
+           "Expected device_num modifier.");
+    llvm::Value *DevVal = CGF.EmitScalarExpr(Device.getPointer());
+    DeviceID =
+        CGF.Builder.CreateIntCast(DevVal, CGF.Int64Ty, /*isSigned=*/true);
+  } else {
+    DeviceID = CGF.Builder.getInt64(OMP_DEVICEID_UNDEF);
+  }
+  return DeviceID;
+}
+
+static llvm::Value *emitDynCGGroupMem(const OMPExecutableDirective &D,
+                                       CodeGenFunction &CGF) {
+  llvm::Value *DynCGroupMem = CGF.Builder.getInt32(0);
+
+  if (auto *DynMemClause = D.getSingleClause<OMPXDynCGroupMemClause>()) {
+    CodeGenFunction::RunCleanupsScope DynCGroupMemScope(CGF);
+    llvm::Value *DynCGroupMemVal = CGF.EmitScalarExpr(
+        DynMemClause->getSize(), /*IgnoreResultAssign=*/true);
+    DynCGroupMem = CGF.Builder.CreateIntCast(DynCGroupMemVal, CGF.Int32Ty,
+                                             /*isSigned=*/false);
+  }
+  return DynCGroupMem;
+}
+
+static void emitTargetCallKernelLaunch(
+    CGOpenMPRuntime *OMPRuntime, llvm::Function *OutlinedFn,
+    const OMPExecutableDirective &D,
+    llvm::SmallVectorImpl<llvm::Value *> &CapturedVars, bool RequiresOuterTask,
+    const CapturedStmt &CS, bool OffloadingMandatory,
+    llvm::PointerIntPair<const Expr *, 2, OpenMPDeviceClauseModifier> Device,
+    llvm::Value *OutlinedFnID, CodeGenFunction::OMPTargetDataInfo &InputInfo,
+    llvm::Value *&MapTypesArray, llvm::Value *&MapNamesArray,
+    llvm::function_ref<llvm::Value *(CodeGenFunction &CGF,
+                                     const OMPLoopDirective &D)>
+        SizeEmitter,
+    CodeGenFunction &CGF, CodeGenModule &CGM) {
+  llvm::OpenMPIRBuilder &OMPBuilder = OMPRuntime->getOMPBuilder();
+
+  // Fill up the arrays with all the captured variables.
+  MappableExprsHandler::MapCombinedInfoTy CombinedInfo;
+
+  // Get mappable expression information.
+  MappableExprsHandler MEHandler(D, CGF);
+  llvm::DenseMap<llvm::Value *, llvm::Value *> LambdaPointers;
+  llvm::DenseSet<CanonicalDeclPtr<const Decl>> MappedVarSet;
+
+  auto RI = CS.getCapturedRecordDecl()->field_begin();
+  auto *CV = CapturedVars.begin();
+  for (CapturedStmt::const_capture_iterator CI = CS.capture_begin(),
+                                            CE = CS.capture_end();
+       CI != CE; ++CI, ++RI, ++CV) {
+    MappableExprsHandler::MapCombinedInfoTy CurInfo;
+    MappableExprsHandler::StructRangeInfoTy PartialStruct;
+
+    // VLA sizes are passed to the outlined region by copy and do not have map
+    // information associated.
+    if (CI->capturesVariableArrayType()) {
+      CurInfo.Exprs.push_back(nullptr);
+      CurInfo.BasePointers.push_back(*CV);
+      CurInfo.DevicePtrDecls.push_back(nullptr);
+      CurInfo.Pointers.push_back(*CV);
+      CurInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
+          CGF.getTypeSize(RI->getType()), CGF.Int64Ty, /*isSigned=*/true));
+      // Copy to the device as an argument. No need to retrieve it.
+      CurInfo.Types.push_back(OpenMPOffloadMappingFlags::OMP_MAP_LITERAL |
+                              OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM |
+                              OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
+      CurInfo.Mappers.push_back(nullptr);
+    } else {
+      // If we have any information in the map clause, we use it, otherwise we
+      // just do a default mapping.
+      MEHandler.generateInfoForCapture(CI, *CV, CurInfo, PartialStruct);
+      if (!CI->capturesThis())
+        MappedVarSet.insert(CI->getCapturedVar());
+      else
+        MappedVarSet.insert(nullptr);
+      if (CurInfo.BasePointers.empty() && !PartialStruct.Base.isValid())
+        MEHandler.generateDefaultMapInfo(*CI, **RI, *CV, CurInfo);
+      // Generate correct mapping for variables captured by reference in
+      // lambdas.
+      if (CI->capturesVariable())
+        MEHandler.generateInfoForLambdaCaptures(CI->getCapturedVar(), *CV,
+                                                CurInfo, LambdaPointers);
+    }
+    // We expect to have at least an element of information for this capture.
+    assert((!CurInfo.BasePointers.empty() || PartialStruct.Base.isValid()) &&
+           "Non-existing map pointer for capture!");
+    assert(CurInfo.BasePointers.size() == CurInfo.Pointers.size() &&
+           CurInfo.BasePointers.size() == CurInfo.Sizes.size() &&
+           CurInfo.BasePointers.size() == CurInfo.Types.size() &&
+           CurInfo.BasePointers.size() == CurInfo.Mappers.size() &&
+           "Inconsistent map information sizes!");
+
+    // If there is an entry in PartialStruct it means we have a struct with
+    // individual members mapped. Emit an extra combined entry.
+    if (PartialStruct.Base.isValid()) {
+      CombinedInfo.append(PartialStruct.PreliminaryMapData);
+      MEHandler.emitCombinedEntry(
+          CombinedInfo, CurInfo.Types, PartialStruct, CI->capturesThis(),
+          nullptr, !PartialStruct.PreliminaryMapData.BasePointers.empty());
+    }
+
+    // We need to append the results of this capture to what we already have.
+    CombinedInfo.append(CurInfo);
+  }
+  // Adjust MEMBER_OF flags for the lambdas captures.
+  MEHandler.adjustMemberOfForLambdaCaptures(
+      LambdaPointers, CombinedInfo.BasePointers, CombinedInfo.Pointers,
+      CombinedInfo.Types);
+  // Map any list items in a map clause that were not captures because they
+  // weren't referenced within the construct.
+  MEHandler.generateAllInfo(CombinedInfo, MappedVarSet);
+
+  CGOpenMPRuntime::TargetDataInfo Info;
+  // Fill up the arrays and create the arguments.
+  emitOffloadingArrays(CGF, CombinedInfo, Info, OMPBuilder);
+  bool EmitDebug = CGF.CGM.getCodeGenOpts().getDebugInfo() !=
+                   llvm::codegenoptions::NoDebugInfo;
+  OMPBuilder.emitOffloadingArraysArgument(CGF.Builder, Info.RTArgs, Info,
+                                          EmitDebug,
+                                          /*ForEndCall=*/false);
+
+  InputInfo.NumberOfTargetItems = Info.NumberOfPtrs;
+  InputInfo.BasePointersArray = Address(Info.RTArgs.BasePointersArray,
+                                        CGF.VoidPtrTy, CGM.getPointerAlign());
+  InputInfo.PointersArray =
+      Address(Info.RTArgs.PointersArray, CGF.VoidPtrTy, CGM.getPointerAlign());
+  InputInfo.SizesArray =
+      Address(Info.RTArgs.SizesArray, CGF.Int64Ty, CGM.getPointerAlign());
+  InputInfo.MappersArray =
+      Address(Info.RTArgs.MappersArray, CGF.VoidPtrTy, CGM.getPointerAlign());
+  MapTypesArray = Info.RTArgs.MapTypesArray;
+  MapNamesArray = Info.RTArgs.MapNamesArray;
+
+  auto &&ThenGen = [&OMPRuntime, OutlinedFn, &D, &CapturedVars,
+                    RequiresOuterTask, &CS, OffloadingMandatory, Device,
+                    OutlinedFnID, &InputInfo, &MapTypesArray, &MapNamesArray,
+                    SizeEmitter](CodeGenFunction &CGF, PrePostActionTy &) {
+    bool IsReverseOffloading = Device.getInt() == OMPC_DEVICE_ancestor;
+
+    if (IsReverseOffloading) {
+      // Reverse offloading is not supported, so just execute on the host.
+      // FIXME: This fallback solution is incorrect since it ignores the
+      // OMP_TARGET_OFFLOAD environment variable. Instead it would be better to
+      // assert here and ensure SEMA emits an error.
+      emitTargetCallFallback(OMPRuntime, OutlinedFn, D, CapturedVars,
+                             RequiresOuterTask, CS, OffloadingMandatory, CGF);
+      return;
+    }
+
+    bool HasNoWait = D.hasClausesOfKind<OMPNowaitClause>();
+    unsigned NumTargetItems = InputInfo.NumberOfTargetItems;
+
+    llvm::Value *BasePointersArray = InputInfo.BasePointersArray.getPointer();
+    llvm::Value *PointersArray = InputInfo.PointersArray.getPointer();
+    llvm::Value *SizesArray = InputInfo.SizesArray.getPointer();
+    llvm::Value *MappersArray = InputInfo.MappersArray.getPointer();
+
+    auto &&EmitTargetCallFallbackCB =
+        [&OMPRuntime, OutlinedFn, &D, &CapturedVars, RequiresOuterTask, &CS,
+         OffloadingMandatory, &CGF](llvm::OpenMPIRBuilder::InsertPointTy IP)
+        -> llvm::OpenMPIRBuilder::InsertPointTy {
+      CGF.Builder.restoreIP(IP);
+      emitTargetCallFallback(OMPRuntime, OutlinedFn, D, CapturedVars,
+                             RequiresOuterTask, CS, OffloadingMandatory, CGF);
+      return CGF.Builder.saveIP();
+    };
+
+    llvm::Value *DeviceID = emitDeviceID(Device, CGF);
+    llvm::Value *NumTeams = OMPRuntime->emitNumTeamsForTargetDirective(CGF, D);
+    llvm::Value *NumThreads =
+        OMPRuntime->emitNumThreadsForTargetDirective(CGF, D);
+    llvm::Value *RTLoc = OMPRuntime->emitUpdateLocation(CGF, D.getBeginLoc());
+    llvm::Value *NumIterations =
+        OMPRuntime->emitTargetNumIterationsCall(CGF, D, SizeEmitter);
+    llvm::Value *DynCGGroupMem = emitDynCGGroupMem(D, CGF);
+    llvm::OpenMPIRBuilder::InsertPointTy AllocaIP(
+        CGF.AllocaInsertPt->getParent(), CGF.AllocaInsertPt->getIterator());
+
+    llvm::OpenMPIRBuilder::TargetDataRTArgs RTArgs(
+        BasePointersArray, PointersArray, SizesArray, MapTypesArray,
+        nullptr /* MapTypesArrayEnd */, MappersArray, MapNamesArray);
+
+    llvm::OpenMPIRBuilder::TargetKernelArgs Args(
+        NumTargetItems, RTArgs, NumIterations, NumTeams, NumThreads,
+        DynCGGroupMem, HasNoWait);
+
+    CGF.Builder.restoreIP(OMPRuntime->getOMPBuilder().emitKernelLaunch(
+        CGF.Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, Args,
+        DeviceID, RTLoc, AllocaIP));
+  };
+
+  if (RequiresOuterTask)
+    CGF.EmitOMPTargetTaskBasedDirective(D, ThenGen, InputInfo);
+  else
+    OMPRuntime->emitInlinedDirective(CGF, D.getDirectiveKind(), ThenGen);
+}
+
+static void
+emitTargetCallElse(CGOpenMPRuntime *OMPRuntime, llvm::Function *OutlinedFn,
+                   const OMPExecutableDirective &D,
+                   llvm::SmallVectorImpl<llvm::Value *> &CapturedVars,
+                   bool RequiresOuterTask, const CapturedStmt &CS,
+                   bool OffloadingMandatory, CodeGenFunction &CGF) {
+
+  // Notify that the host version must be executed.
+  auto &&ElseGen =
+      [&OMPRuntime, OutlinedFn, &D, &CapturedVars, RequiresOuterTask, &CS,
+       OffloadingMandatory](CodeGenFunction &CGF, PrePostActionTy &) {
+        emitTargetCallFallback(OMPRuntime, OutlinedFn, D, CapturedVars,
+                               RequiresOuterTask, CS, OffloadingMandatory, CGF);
+      };
+
+  if (RequiresOuterTask) {
+    CodeGenFunction::OMPTargetDataInfo InputInfo;
+    CGF.EmitOMPTargetTaskBasedDirective(D, ElseGen, InputInfo);
+  } else {
+    OMPRuntime->emitInlinedDirective(CGF, D.getDirectiveKind(), ElseGen);
+  }
+}
+
 void CGOpenMPRuntime::emitTargetCall(
     CodeGenFunction &CGF, const OMPExecutableDirective &D,
     llvm::Function *OutlinedFn, llvm::Value *OutlinedFnID, const Expr *IfCond,
@@ -9602,263 +9849,24 @@
   CodeGenFunction::OMPTargetDataInfo InputInfo;
   llvm::Value *MapTypesArray = nullptr;
   llvm::Value *MapNamesArray = nullptr;
-  // Generate code for the host fallback function.
-  auto &&FallbackGen = [this, OutlinedFn, &D, &CapturedVars, RequiresOuterTask,
-                        &CS, OffloadingMandatory](CodeGenFunction &CGF) {
-    if (OffloadingMandatory) {
-      CGF.Builder.CreateUnreachable();
-    } else {
-      if (RequiresOuterTask) {
-        CapturedVars.clear();
-        CGF.GenerateOpenMPCapturedVars(CS, CapturedVars);
-      }
-      emitOutlinedFunctionCall(CGF, D.getBeginLoc(), OutlinedFn, CapturedVars);
-    }
-  };
-  // Fill up the pointer arrays and transfer execution to the device.
-  auto &&ThenGen = [this, Device, OutlinedFnID, &D, &InputInfo, &MapTypesArray,
-                    &MapNamesArray, SizeEmitter,
-                    FallbackGen](CodeGenFunction &CGF, PrePostActionTy &) {
-    if (Device.getInt() == OMPC_DEVICE_ancestor) {
-      // Reverse offloading is not supported, so just execute on the host.
-      FallbackGen(CGF);
-      return;
-    }
-
-    // On top of the arrays that were filled up, the target offloading call
-    // takes as arguments the device id as well as the host pointer. The host
-    // pointer is used by the runtime library to identify the current target
-    // region, so it only has to be unique and not necessarily point to
-    // anything. It could be the pointer to the outlined function that
-    // implements the target region, but we aren't using that so that the
-    // compiler doesn't need to keep that, and could therefore inline the host
-    // function if proven worthwhile during optimization.
 
-    // From this point on, we need to have an ID of the target region defined.
-    assert(OutlinedFnID && "Invalid outlined function ID!");
-    (void)OutlinedFnID;
-
-    // Emit device ID if any.
-    llvm::Value *DeviceID;
-    if (Device.getPointer()) {
-      assert((Device.getInt() == OMPC_DEVICE_unknown ||
-              Device.getInt() == OMPC_DEVICE_device_num) &&
-             "Expected device_num modifier.");
-      llvm::Value *DevVal = CGF.EmitScalarExpr(Device.getPointer());
-      DeviceID =
-          CGF.Builder.CreateIntCast(DevVal, CGF.Int64Ty, /*isSigned=*/true);
-    } else {
-      DeviceID = CGF.Builder.getInt64(OMP_DEVICEID_UNDEF);
-    }
-
-    // Emit the number of elements in the offloading arrays.
-    llvm::Value *PointerNum =
-        CGF.Builder.getInt32(InputInfo.NumberOfTargetItems);
-
-    // Return value of the runtime offloading call.
-    llvm::Value *Return;
-
-    llvm::Value *NumTeams = emitNumTeamsForTargetDirective(CGF, D);
-    llvm::Value *NumThreads = emitNumThreadsForTargetDirective(CGF, D);
-
-    // Source location for the ident struct
-    llvm::Value *RTLoc = emitUpdateLocation(CGF, D.getBeginLoc());
-
-    // Get tripcount for the target loop-based directive.
-    llvm::Value *NumIterations =
-        emitTargetNumIterationsCall(CGF, D, SizeEmitter);
-
-    llvm::Value *DynCGroupMem = CGF.Builder.getInt32(0);
-    if (auto *DynMemClause = D.getSingleClause<OMPXDynCGroupMemClause>()) {
-      CodeGenFunction::RunCleanupsScope DynCGroupMemScope(CGF);
-      llvm::Value *DynCGroupMemVal = CGF.EmitScalarExpr(
-          DynMemClause->getSize(), /*IgnoreResultAssign=*/true);
-      DynCGroupMem = CGF.Builder.CreateIntCast(DynCGroupMemVal, CGF.Int32Ty,
-                                               /*isSigned=*/false);
-    }
-
-    llvm::Value *ZeroArray =
-        llvm::Constant::getNullValue(llvm::ArrayType::get(CGF.CGM.Int32Ty, 3));
-
-    bool HasNoWait = D.hasClausesOfKind<OMPNowaitClause>();
-    llvm::Value *Flags = CGF.Builder.getInt64(HasNoWait);
-
-    llvm::Value *NumTeams3D =
-        CGF.Builder.CreateInsertValue(ZeroArray, NumTeams, {0});
-    llvm::Value *NumThreads3D =
-        CGF.Builder.CreateInsertValue(ZeroArray, NumThreads, {0});
-
-    // Arguments for the target kernel.
-    SmallVector<llvm::Value *> KernelArgs{
-        CGF.Builder.getInt32(/* Version */ 2),
-        PointerNum,
-        InputInfo.BasePointersArray.getPointer(),
-        InputInfo.PointersArray.getPointer(),
-        InputInfo.SizesArray.getPointer(),
-        MapTypesArray,
-        MapNamesArray,
-        InputInfo.MappersArray.getPointer(),
-        NumIterations,
-        Flags,
-        NumTeams3D,
-        NumThreads3D,
-        DynCGroupMem,
-    };
-
-    llvm::OpenMPIRBuilder::InsertPointTy AllocaIP(
-        CGF.AllocaInsertPt->getParent(), CGF.AllocaInsertPt->getIterator());
-
-    // The target region is an outlined function launched by the runtime
-    // via calls to __tgt_target_kernel().
-    //
-    // Note that on the host and CPU targets, the runtime implementation of
-    // these calls simply call the outlined function without forking threads.
-    // The outlined functions themselves have runtime calls to
-    // __kmpc_fork_teams() and __kmpc_fork() for this purpose, codegen'd by
-    // the compiler in emitTeamsCall() and emitParallelCall().
-    //
-    // In contrast, on the NVPTX target, the implementation of
-    // __tgt_target_teams() launches a GPU kernel with the requested number
-    // of teams and threads so no additional calls to the runtime are required.
-    // Check the error code and execute the host version if required.
-    CGF.Builder.restoreIP(OMPBuilder.emitTargetKernel(
-        CGF.Builder, AllocaIP, Return, RTLoc, DeviceID, NumTeams, NumThreads,
-        OutlinedFnID, KernelArgs));
-
-    llvm::BasicBlock *OffloadFailedBlock =
-        CGF.createBasicBlock("omp_offload.failed");
-    llvm::BasicBlock *OffloadContBlock =
-        CGF.createBasicBlock("omp_offload.cont");
-    llvm::Value *Failed = CGF.Builder.CreateIsNotNull(Return);
-    CGF.Builder.CreateCondBr(Failed, OffloadFailedBlock, OffloadContBlock);
-
-    CGF.EmitBlock(OffloadFailedBlock);
-    FallbackGen(CGF);
-
-    CGF.EmitBranch(OffloadContBlock);
-
-    CGF.EmitBlock(OffloadContBlock, /*IsFinished=*/true);
+  auto &&TargetThenGen = [this, OutlinedFn, &D, &CapturedVars,
+                          RequiresOuterTask, &CS, OffloadingMandatory, Device,
+                          OutlinedFnID, &InputInfo, &MapTypesArray,
+                          &MapNamesArray, SizeEmitter](CodeGenFunction &CGF,
+                                                       PrePostActionTy &) {
+    emitTargetCallKernelLaunch(this, OutlinedFn, D, CapturedVars,
+                               RequiresOuterTask, CS, OffloadingMandatory,
+                               Device, OutlinedFnID, InputInfo, MapTypesArray,
+                               MapNamesArray, SizeEmitter, CGF, CGM);
   };
 
-  // Notify that the host version must be executed.
-  auto &&ElseGen = [FallbackGen](CodeGenFunction &CGF, PrePostActionTy &) {
-    FallbackGen(CGF);
-  };
-
-  auto &&TargetThenGen = [this, &ThenGen, &D, &InputInfo, &MapTypesArray,
-                          &MapNamesArray, &CapturedVars, RequiresOuterTask,
-                          &CS](CodeGenFunction &CGF, PrePostActionTy &) {
-    // Fill up the arrays with all the captured variables.
-    MappableExprsHandler::MapCombinedInfoTy CombinedInfo;
-
-    // Get mappable expression information.
-    MappableExprsHandler MEHandler(D, CGF);
-    llvm::DenseMap<llvm::Value *, llvm::Value *> LambdaPointers;
-    llvm::DenseSet<CanonicalDeclPtr<const Decl>> MappedVarSet;
-
-    auto RI = CS.getCapturedRecordDecl()->field_begin();
-    auto *CV = CapturedVars.begin();
-    for (CapturedStmt::const_capture_iterator CI = CS.capture_begin(),
-                                              CE = CS.capture_end();
-         CI != CE; ++CI, ++RI, ++CV) {
-      MappableExprsHandler::MapCombinedInfoTy CurInfo;
-      MappableExprsHandler::StructRangeInfoTy PartialStruct;
-
-      // VLA sizes are passed to the outlined region by copy and do not have map
-      // information associated.
-      if (CI->capturesVariableArrayType()) {
-        CurInfo.Exprs.push_back(nullptr);
-        CurInfo.BasePointers.push_back(*CV);
-        CurInfo.DevicePtrDecls.push_back(nullptr);
-        CurInfo.Pointers.push_back(*CV);
-        CurInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
-            CGF.getTypeSize(RI->getType()), CGF.Int64Ty, /*isSigned=*/true));
-        // Copy to the device as an argument. No need to retrieve it.
-        CurInfo.Types.push_back(
-            OpenMPOffloadMappingFlags::OMP_MAP_LITERAL |
-            OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM |
-            OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
-        CurInfo.Mappers.push_back(nullptr);
-      } else {
-        // If we have any information in the map clause, we use it, otherwise we
-        // just do a default mapping.
-        MEHandler.generateInfoForCapture(CI, *CV, CurInfo, PartialStruct);
-        if (!CI->capturesThis())
-          MappedVarSet.insert(CI->getCapturedVar());
-        else
-          MappedVarSet.insert(nullptr);
-        if (CurInfo.BasePointers.empty() && !PartialStruct.Base.isValid())
-          MEHandler.generateDefaultMapInfo(*CI, **RI, *CV, CurInfo);
-        // Generate correct mapping for variables captured by reference in
-        // lambdas.
-        if (CI->capturesVariable())
-          MEHandler.generateInfoForLambdaCaptures(CI->getCapturedVar(), *CV,
-                                                  CurInfo, LambdaPointers);
-      }
-      // We expect to have at least an element of information for this capture.
-      assert((!CurInfo.BasePointers.empty() || PartialStruct.Base.isValid()) &&
-             "Non-existing map pointer for capture!");
-      assert(CurInfo.BasePointers.size() == CurInfo.Pointers.size() &&
-             CurInfo.BasePointers.size() == CurInfo.Sizes.size() &&
-             CurInfo.BasePointers.size() == CurInfo.Types.size() &&
-             CurInfo.BasePointers.size() == CurInfo.Mappers.size() &&
-             "Inconsistent map information sizes!");
-
-      // If there is an entry in PartialStruct it means we have a struct with
-      // individual members mapped. Emit an extra combined entry.
-      if (PartialStruct.Base.isValid()) {
-        CombinedInfo.append(PartialStruct.PreliminaryMapData);
-        MEHandler.emitCombinedEntry(
-            CombinedInfo, CurInfo.Types, PartialStruct, CI->capturesThis(),
-            nullptr, !PartialStruct.PreliminaryMapData.BasePointers.empty());
-      }
-
-      // We need to append the results of this capture to what we already have.
-      CombinedInfo.append(CurInfo);
-    }
-    // Adjust MEMBER_OF flags for the lambdas captures.
-    MEHandler.adjustMemberOfForLambdaCaptures(
-        LambdaPointers, CombinedInfo.BasePointers, CombinedInfo.Pointers,
-        CombinedInfo.Types);
-    // Map any list items in a map clause that were not captures because they
-    // weren't referenced within the construct.
-    MEHandler.generateAllInfo(CombinedInfo, MappedVarSet);
-
-    CGOpenMPRuntime::TargetDataInfo Info;
-    // Fill up the arrays and create the arguments.
-    emitOffloadingArrays(CGF, CombinedInfo, Info, OMPBuilder);
-    bool EmitDebug = CGF.CGM.getCodeGenOpts().getDebugInfo() !=
-                     llvm::codegenoptions::NoDebugInfo;
-    OMPBuilder.emitOffloadingArraysArgument(CGF.Builder, Info.RTArgs, Info,
-                                            EmitDebug,
-                                            /*ForEndCall=*/false);
-
-    InputInfo.NumberOfTargetItems = Info.NumberOfPtrs;
-    InputInfo.BasePointersArray = Address(Info.RTArgs.BasePointersArray,
-                                          CGF.VoidPtrTy, CGM.getPointerAlign());
-    InputInfo.PointersArray = Address(Info.RTArgs.PointersArray, CGF.VoidPtrTy,
-                                      CGM.getPointerAlign());
-    InputInfo.SizesArray =
-        Address(Info.RTArgs.SizesArray, CGF.Int64Ty, CGM.getPointerAlign());
-    InputInfo.MappersArray =
-        Address(Info.RTArgs.MappersArray, CGF.VoidPtrTy, CGM.getPointerAlign());
-    MapTypesArray = Info.RTArgs.MapTypesArray;
-    MapNamesArray = Info.RTArgs.MapNamesArray;
-    if (RequiresOuterTask)
-      CGF.EmitOMPTargetTaskBasedDirective(D, ThenGen, InputInfo);
-    else
-      emitInlinedDirective(CGF, D.getDirectiveKind(), ThenGen);
-  };
-
-  auto &&TargetElseGen = [this, &ElseGen, &D, RequiresOuterTask](
-                             CodeGenFunction &CGF, PrePostActionTy &) {
-    if (RequiresOuterTask) {
-      CodeGenFunction::OMPTargetDataInfo InputInfo;
-      CGF.EmitOMPTargetTaskBasedDirective(D, ElseGen, InputInfo);
-    } else {
-      emitInlinedDirective(CGF, D.getDirectiveKind(), ElseGen);
-    }
-  };
+  auto &&TargetElseGen =
+      [this, OutlinedFn, &D, &CapturedVars, RequiresOuterTask, &CS,
+       OffloadingMandatory](CodeGenFunction &CGF, PrePostActionTy &) {
+        emitTargetCallElse(this, OutlinedFn, D, CapturedVars, RequiresOuterTask,
+                           CS, OffloadingMandatory, CGF);
+      };
 
   // If we have a target function ID it means that we need to support
   // offloading, otherwise, just execute on the host. We need to execute on host