Author: Jan Sjodin
Date: 2023-06-30T10:40:21-04:00
New Revision: ac65cc1215132b4725ed64b726671867a028e275


LOG: [OpenMP][OpenMPIRBuilder] Migrate kernel launch code and host fallback 
code generation from Clang to the OpenMPIRBuilder

This patch refactors the code generation that emits the offloading kernel
launch and moves the core portion to the OpenMPIRBuilder so that it can be used
from flang in the future.

Reviewed By: jdoerfert

Differential Revision:




diff  --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index 619f5629b6de78..dfc8f71ef43583 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -9573,6 +9573,253 @@ llvm::Value 
   return llvm::ConstantInt::get(CGF.Int64Ty, 0);
+static void
+emitTargetCallFallback(CGOpenMPRuntime *OMPRuntime, llvm::Function *OutlinedFn,
+                       const OMPExecutableDirective &D,
+                       llvm::SmallVectorImpl<llvm::Value *> &CapturedVars,
+                       bool RequiresOuterTask, const CapturedStmt &CS,
+                       bool OffloadingMandatory, CodeGenFunction &CGF) {
+  if (OffloadingMandatory) {
+    CGF.Builder.CreateUnreachable();
+  } else {
+    if (RequiresOuterTask) {
+      CapturedVars.clear();
+      CGF.GenerateOpenMPCapturedVars(CS, CapturedVars);
+    }
+    OMPRuntime->emitOutlinedFunctionCall(CGF, D.getBeginLoc(), OutlinedFn,
+                                         CapturedVars);
+  }
+}
+static llvm::Value *emitDeviceID(
+    llvm::PointerIntPair<const Expr *, 2, OpenMPDeviceClauseModifier> Device,
+    CodeGenFunction &CGF) {
+  // Emit device ID if any.
+  llvm::Value *DeviceID;
+  if (Device.getPointer()) {
+    assert((Device.getInt() == OMPC_DEVICE_unknown ||
+            Device.getInt() == OMPC_DEVICE_device_num) &&
+           "Expected device_num modifier.");
+    llvm::Value *DevVal = CGF.EmitScalarExpr(Device.getPointer());
+    DeviceID =
+        CGF.Builder.CreateIntCast(DevVal, CGF.Int64Ty, /*isSigned=*/true);
+  } else {
+    DeviceID = CGF.Builder.getInt64(OMP_DEVICEID_UNDEF);
+  }
+  return DeviceID;
+}
+llvm::Value *emitDynCGGroupMem(const OMPExecutableDirective &D,
+                               CodeGenFunction &CGF) {
+  llvm::Value *DynCGroupMem = CGF.Builder.getInt32(0);
+  if (auto *DynMemClause = D.getSingleClause<OMPXDynCGroupMemClause>()) {
+    CodeGenFunction::RunCleanupsScope DynCGroupMemScope(CGF);
+    llvm::Value *DynCGroupMemVal = CGF.EmitScalarExpr(
+        DynMemClause->getSize(), /*IgnoreResultAssign=*/true);
+    DynCGroupMem = CGF.Builder.CreateIntCast(DynCGroupMemVal, CGF.Int32Ty,
+                                             /*isSigned=*/false);
+  }
+  return DynCGroupMem;
+}
+static void emitTargetCallKernelLaunch(
+    CGOpenMPRuntime *OMPRuntime, llvm::Function *OutlinedFn,
+    const OMPExecutableDirective &D,
+    llvm::SmallVectorImpl<llvm::Value *> &CapturedVars, bool RequiresOuterTask,
+    const CapturedStmt &CS, bool OffloadingMandatory,
+    llvm::PointerIntPair<const Expr *, 2, OpenMPDeviceClauseModifier> Device,
+    llvm::Value *OutlinedFnID, CodeGenFunction::OMPTargetDataInfo &InputInfo,
+    llvm::Value *&MapTypesArray, llvm::Value *&MapNamesArray,
+    llvm::function_ref<llvm::Value *(CodeGenFunction &CGF,
+                                     const OMPLoopDirective &D)>
+        SizeEmitter,
+    CodeGenFunction &CGF, CodeGenModule &CGM) {
+  llvm::OpenMPIRBuilder &OMPBuilder = OMPRuntime->getOMPBuilder();
+  // Fill up the arrays with all the captured variables.
+  MappableExprsHandler::MapCombinedInfoTy CombinedInfo;
+  // Get mappable expression information.
+  MappableExprsHandler MEHandler(D, CGF);
+  llvm::DenseMap<llvm::Value *, llvm::Value *> LambdaPointers;
+  llvm::DenseSet<CanonicalDeclPtr<const Decl>> MappedVarSet;
+  auto RI = CS.getCapturedRecordDecl()->field_begin();
+  auto *CV = CapturedVars.begin();
+  for (CapturedStmt::const_capture_iterator CI = CS.capture_begin(),
+                                            CE = CS.capture_end();
+       CI != CE; ++CI, ++RI, ++CV) {
+    MappableExprsHandler::MapCombinedInfoTy CurInfo;
+    MappableExprsHandler::StructRangeInfoTy PartialStruct;
+    // VLA sizes are passed to the outlined region by copy and do not have map
+    // information associated.
+    if (CI->capturesVariableArrayType()) {
+      CurInfo.Exprs.push_back(nullptr);
+      CurInfo.BasePointers.push_back(*CV);
+      CurInfo.DevicePtrDecls.push_back(nullptr);
+      CurInfo.Pointers.push_back(*CV);
+      CurInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
+          CGF.getTypeSize(RI->getType()), CGF.Int64Ty, /*isSigned=*/true));
+      // Copy to the device as an argument. No need to retrieve it.
+      CurInfo.Types.push_back(OpenMPOffloadMappingFlags::OMP_MAP_LITERAL |
+                              OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM |
+                              OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
+      CurInfo.Mappers.push_back(nullptr);
+    } else {
+      // If we have any information in the map clause, we use it, otherwise we
+      // just do a default mapping.
+      MEHandler.generateInfoForCapture(CI, *CV, CurInfo, PartialStruct);
+      if (!CI->capturesThis())
+        MappedVarSet.insert(CI->getCapturedVar());
+      else
+        MappedVarSet.insert(nullptr);
+      if (CurInfo.BasePointers.empty() && !PartialStruct.Base.isValid())
+        MEHandler.generateDefaultMapInfo(*CI, **RI, *CV, CurInfo);
+      // Generate correct mapping for variables captured by reference in
+      // lambdas.
+      if (CI->capturesVariable())
+        MEHandler.generateInfoForLambdaCaptures(CI->getCapturedVar(), *CV,
+                                                CurInfo, LambdaPointers);
+    }
+    // We expect to have at least an element of information for this capture.
+    assert((!CurInfo.BasePointers.empty() || PartialStruct.Base.isValid()) &&
+           "Non-existing map pointer for capture!");
+    assert(CurInfo.BasePointers.size() == CurInfo.Pointers.size() &&
+           CurInfo.BasePointers.size() == CurInfo.Sizes.size() &&
+           CurInfo.BasePointers.size() == CurInfo.Types.size() &&
+           CurInfo.BasePointers.size() == CurInfo.Mappers.size() &&
+           "Inconsistent map information sizes!");
+    // If there is an entry in PartialStruct it means we have a struct with
+    // individual members mapped. Emit an extra combined entry.
+    if (PartialStruct.Base.isValid()) {
+      CombinedInfo.append(PartialStruct.PreliminaryMapData);
+      MEHandler.emitCombinedEntry(
+          CombinedInfo, CurInfo.Types, PartialStruct, CI->capturesThis(),
+          nullptr, !PartialStruct.PreliminaryMapData.BasePointers.empty());
+    }
+    // We need to append the results of this capture to what we already have.
+    CombinedInfo.append(CurInfo);
+  }
+  // Adjust MEMBER_OF flags for the lambdas captures.
+  MEHandler.adjustMemberOfForLambdaCaptures(
+      LambdaPointers, CombinedInfo.BasePointers, CombinedInfo.Pointers,
+      CombinedInfo.Types);
+  // Map any list items in a map clause that were not captures because they
+  // weren't referenced within the construct.
+  MEHandler.generateAllInfo(CombinedInfo, MappedVarSet);
+  CGOpenMPRuntime::TargetDataInfo Info;
+  // Fill up the arrays and create the arguments.
+  emitOffloadingArrays(CGF, CombinedInfo, Info, OMPBuilder);
+  bool EmitDebug = CGF.CGM.getCodeGenOpts().getDebugInfo() !=
+                   llvm::codegenoptions::NoDebugInfo;
+  OMPBuilder.emitOffloadingArraysArgument(CGF.Builder, Info.RTArgs, Info,
+                                          EmitDebug,
+                                          /*ForEndCall=*/false);
+  InputInfo.NumberOfTargetItems = Info.NumberOfPtrs;
+  InputInfo.BasePointersArray = Address(Info.RTArgs.BasePointersArray,
+                                        CGF.VoidPtrTy, CGM.getPointerAlign());
+  InputInfo.PointersArray =
+      Address(Info.RTArgs.PointersArray, CGF.VoidPtrTy, CGM.getPointerAlign());
+  InputInfo.SizesArray =
+      Address(Info.RTArgs.SizesArray, CGF.Int64Ty, CGM.getPointerAlign());
+  InputInfo.MappersArray =
+      Address(Info.RTArgs.MappersArray, CGF.VoidPtrTy, CGM.getPointerAlign());
+  MapTypesArray = Info.RTArgs.MapTypesArray;
+  MapNamesArray = Info.RTArgs.MapNamesArray;
+  auto &&ThenGen = [&OMPRuntime, OutlinedFn, &D, &CapturedVars,
+                    RequiresOuterTask, &CS, OffloadingMandatory, Device,
+                    OutlinedFnID, &InputInfo, &MapTypesArray, &MapNamesArray,
+                    SizeEmitter](CodeGenFunction &CGF, PrePostActionTy &) {
+    bool IsReverseOffloading = Device.getInt() == OMPC_DEVICE_ancestor;
+    if (IsReverseOffloading) {
+      // Reverse offloading is not supported, so just execute on the host.
+      // FIXME: This fallback solution is incorrect since it ignores the
+      // OMP_TARGET_OFFLOAD environment variable. Instead it would be better to
+      // assert here and ensure SEMA emits an error.
+      emitTargetCallFallback(OMPRuntime, OutlinedFn, D, CapturedVars,
+                             RequiresOuterTask, CS, OffloadingMandatory, CGF);
+      return;
+    }
+    bool HasNoWait = D.hasClausesOfKind<OMPNowaitClause>();
+    unsigned NumTargetItems = InputInfo.NumberOfTargetItems;
+    llvm::Value *BasePointersArray = InputInfo.BasePointersArray.getPointer();
+    llvm::Value *PointersArray = InputInfo.PointersArray.getPointer();
+    llvm::Value *SizesArray = InputInfo.SizesArray.getPointer();
+    llvm::Value *MappersArray = InputInfo.MappersArray.getPointer();
+    auto &&EmitTargetCallFallbackCB =
+        [&OMPRuntime, OutlinedFn, &D, &CapturedVars, RequiresOuterTask, &CS,
+         OffloadingMandatory, &CGF](llvm::OpenMPIRBuilder::InsertPointTy IP)
+        -> llvm::OpenMPIRBuilder::InsertPointTy {
+      CGF.Builder.restoreIP(IP);
+      emitTargetCallFallback(OMPRuntime, OutlinedFn, D, CapturedVars,
+                             RequiresOuterTask, CS, OffloadingMandatory, CGF);
+      return CGF.Builder.saveIP();
+    };
+    llvm::Value *DeviceID = emitDeviceID(Device, CGF);
+    llvm::Value *NumTeams = OMPRuntime->emitNumTeamsForTargetDirective(CGF, D);
+    llvm::Value *NumThreads =
+        OMPRuntime->emitNumThreadsForTargetDirective(CGF, D);
+    llvm::Value *RTLoc = OMPRuntime->emitUpdateLocation(CGF, D.getBeginLoc());
+    llvm::Value *NumIterations =
+        OMPRuntime->emitTargetNumIterationsCall(CGF, D, SizeEmitter);
+    llvm::Value *DynCGGroupMem = emitDynCGGroupMem(D, CGF);
+    llvm::OpenMPIRBuilder::InsertPointTy AllocaIP(
+        CGF.AllocaInsertPt->getParent(), CGF.AllocaInsertPt->getIterator());
+    llvm::OpenMPIRBuilder::TargetDataRTArgs RTArgs(
+        BasePointersArray, PointersArray, SizesArray, MapTypesArray,
+        nullptr /* MapTypesArrayEnd */, MappersArray, MapNamesArray);
+    llvm::OpenMPIRBuilder::TargetKernelArgs Args(
+        NumTargetItems, RTArgs, NumIterations, NumTeams, NumThreads,
+        DynCGGroupMem, HasNoWait);
+    CGF.Builder.restoreIP(OMPRuntime->getOMPBuilder().emitKernelLaunch(
+        CGF.Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, Args,
+        DeviceID, RTLoc, AllocaIP));
+  };
+  if (RequiresOuterTask)
+    CGF.EmitOMPTargetTaskBasedDirective(D, ThenGen, InputInfo);
+  else
+    OMPRuntime->emitInlinedDirective(CGF, D.getDirectiveKind(), ThenGen);
+}
+static void
+emitTargetCallElse(CGOpenMPRuntime *OMPRuntime, llvm::Function *OutlinedFn,
+                   const OMPExecutableDirective &D,
+                   llvm::SmallVectorImpl<llvm::Value *> &CapturedVars,
+                   bool RequiresOuterTask, const CapturedStmt &CS,
+                   bool OffloadingMandatory, CodeGenFunction &CGF) {
+  // Notify that the host version must be executed.
+  auto &&ElseGen =
+      [&OMPRuntime, OutlinedFn, &D, &CapturedVars, RequiresOuterTask, &CS,
+       OffloadingMandatory](CodeGenFunction &CGF, PrePostActionTy &) {
+        emitTargetCallFallback(OMPRuntime, OutlinedFn, D, CapturedVars,
+                               RequiresOuterTask, CS, OffloadingMandatory, CGF);
+      };
+  if (RequiresOuterTask) {
+    CodeGenFunction::OMPTargetDataInfo InputInfo;
+    CGF.EmitOMPTargetTaskBasedDirective(D, ElseGen, InputInfo);
+  } else {
+    OMPRuntime->emitInlinedDirective(CGF, D.getDirectiveKind(), ElseGen);
+  }
+}
 void CGOpenMPRuntime::emitTargetCall(
     CodeGenFunction &CGF, const OMPExecutableDirective &D,
     llvm::Function *OutlinedFn, llvm::Value *OutlinedFnID, const Expr *IfCond,
@@ -9602,263 +9849,24 @@ void CGOpenMPRuntime::emitTargetCall(
   CodeGenFunction::OMPTargetDataInfo InputInfo;
   llvm::Value *MapTypesArray = nullptr;
   llvm::Value *MapNamesArray = nullptr;
-  // Generate code for the host fallback function.
-  auto &&FallbackGen = [this, OutlinedFn, &D, &CapturedVars, RequiresOuterTask,
-                        &CS, OffloadingMandatory](CodeGenFunction &CGF) {
-    if (OffloadingMandatory) {
-      CGF.Builder.CreateUnreachable();
-    } else {
-      if (RequiresOuterTask) {
-        CapturedVars.clear();
-        CGF.GenerateOpenMPCapturedVars(CS, CapturedVars);
-      }
-      emitOutlinedFunctionCall(CGF, D.getBeginLoc(), OutlinedFn, CapturedVars);
-    }
-  };
-  // Fill up the pointer arrays and transfer execution to the device.
-  auto &&ThenGen = [this, Device, OutlinedFnID, &D, &InputInfo, &MapTypesArray,
-                    &MapNamesArray, SizeEmitter,
-                    FallbackGen](CodeGenFunction &CGF, PrePostActionTy &) {
-    if (Device.getInt() == OMPC_DEVICE_ancestor) {
-      // Reverse offloading is not supported, so just execute on the host.
-      FallbackGen(CGF);
-      return;
-    }
-    // On top of the arrays that were filled up, the target offloading call
-    // takes as arguments the device id as well as the host pointer. The host
-    // pointer is used by the runtime library to identify the current target
-    // region, so it only has to be unique and not necessarily point to
-    // anything. It could be the pointer to the outlined function that
-    // implements the target region, but we aren't using that so that the
-    // compiler doesn't need to keep that, and could therefore inline the host
-    // function if proven worthwhile during optimization.
-    // From this point on, we need to have an ID of the target region defined.
-    assert(OutlinedFnID && "Invalid outlined function ID!");
-    (void)OutlinedFnID;
-    // Emit device ID if any.
-    llvm::Value *DeviceID;
-    if (Device.getPointer()) {
-      assert((Device.getInt() == OMPC_DEVICE_unknown ||
-              Device.getInt() == OMPC_DEVICE_device_num) &&
-             "Expected device_num modifier.");
-      llvm::Value *DevVal = CGF.EmitScalarExpr(Device.getPointer());
-      DeviceID =
-          CGF.Builder.CreateIntCast(DevVal, CGF.Int64Ty, /*isSigned=*/true);
-    } else {
-      DeviceID = CGF.Builder.getInt64(OMP_DEVICEID_UNDEF);
-    }
-    // Emit the number of elements in the offloading arrays.
-    llvm::Value *PointerNum =
-        CGF.Builder.getInt32(InputInfo.NumberOfTargetItems);
-    // Return value of the runtime offloading call.
-    llvm::Value *Return;
-    llvm::Value *NumTeams = emitNumTeamsForTargetDirective(CGF, D);
-    llvm::Value *NumThreads = emitNumThreadsForTargetDirective(CGF, D);
-    // Source location for the ident struct
-    llvm::Value *RTLoc = emitUpdateLocation(CGF, D.getBeginLoc());
-    // Get tripcount for the target loop-based directive.
-    llvm::Value *NumIterations =
-        emitTargetNumIterationsCall(CGF, D, SizeEmitter);
-    llvm::Value *DynCGroupMem = CGF.Builder.getInt32(0);
-    if (auto *DynMemClause = D.getSingleClause<OMPXDynCGroupMemClause>()) {
-      CodeGenFunction::RunCleanupsScope DynCGroupMemScope(CGF);
-      llvm::Value *DynCGroupMemVal = CGF.EmitScalarExpr(
-          DynMemClause->getSize(), /*IgnoreResultAssign=*/true);
-      DynCGroupMem = CGF.Builder.CreateIntCast(DynCGroupMemVal, CGF.Int32Ty,
-                                               /*isSigned=*/false);
-    }
-    llvm::Value *ZeroArray =
-        llvm::Constant::getNullValue(llvm::ArrayType::get(CGF.CGM.Int32Ty, 3));
-    bool HasNoWait = D.hasClausesOfKind<OMPNowaitClause>();
-    llvm::Value *Flags = CGF.Builder.getInt64(HasNoWait);
-    llvm::Value *NumTeams3D =
-        CGF.Builder.CreateInsertValue(ZeroArray, NumTeams, {0});
-    llvm::Value *NumThreads3D =
-        CGF.Builder.CreateInsertValue(ZeroArray, NumThreads, {0});
-    // Arguments for the target kernel.
-    SmallVector<llvm::Value *> KernelArgs{
-        CGF.Builder.getInt32(/* Version */ 2),
-        PointerNum,
-        InputInfo.BasePointersArray.getPointer(),
-        InputInfo.PointersArray.getPointer(),
-        InputInfo.SizesArray.getPointer(),
-        MapTypesArray,
-        MapNamesArray,
-        InputInfo.MappersArray.getPointer(),
-        NumIterations,
-        Flags,
-        NumTeams3D,
-        NumThreads3D,
-        DynCGroupMem,
-    };
-    llvm::OpenMPIRBuilder::InsertPointTy AllocaIP(
-        CGF.AllocaInsertPt->getParent(), CGF.AllocaInsertPt->getIterator());
-    // The target region is an outlined function launched by the runtime
-    // via calls to __tgt_target_kernel().
-    //
-    // Note that on the host and CPU targets, the runtime implementation of
-    // these calls simply call the outlined function without forking threads.
-    // The outlined functions themselves have runtime calls to
-    // __kmpc_fork_teams() and __kmpc_fork() for this purpose, codegen'd by
-    // the compiler in emitTeamsCall() and emitParallelCall().
-    //
-    // In contrast, on the NVPTX target, the implementation of
-    // __tgt_target_teams() launches a GPU kernel with the requested number
-    // of teams and threads so no additional calls to the runtime are required.
-    // Check the error code and execute the host version if required.
-    CGF.Builder.restoreIP(OMPBuilder.emitTargetKernel(
-        CGF.Builder, AllocaIP, Return, RTLoc, DeviceID, NumTeams, NumThreads,
-        OutlinedFnID, KernelArgs));
-    llvm::BasicBlock *OffloadFailedBlock =
-        CGF.createBasicBlock("omp_offload.failed");
-    llvm::BasicBlock *OffloadContBlock =
-        CGF.createBasicBlock("omp_offload.cont");
-    llvm::Value *Failed = CGF.Builder.CreateIsNotNull(Return);
-    CGF.Builder.CreateCondBr(Failed, OffloadFailedBlock, OffloadContBlock);
-    CGF.EmitBlock(OffloadFailedBlock);
-    FallbackGen(CGF);
-    CGF.EmitBranch(OffloadContBlock);
-    CGF.EmitBlock(OffloadContBlock, /*IsFinished=*/true);
+  auto &&TargetThenGen = [this, OutlinedFn, &D, &CapturedVars,
+                          RequiresOuterTask, &CS, OffloadingMandatory, Device,
+                          OutlinedFnID, &InputInfo, &MapTypesArray,
+                          &MapNamesArray, SizeEmitter](CodeGenFunction &CGF,
+                                                       PrePostActionTy &) {
+    emitTargetCallKernelLaunch(this, OutlinedFn, D, CapturedVars,
+                               RequiresOuterTask, CS, OffloadingMandatory,
+                               Device, OutlinedFnID, InputInfo, MapTypesArray,
+                               MapNamesArray, SizeEmitter, CGF, CGM);
-  // Notify that the host version must be executed.
-  auto &&ElseGen = [FallbackGen](CodeGenFunction &CGF, PrePostActionTy &) {
-    FallbackGen(CGF);
-  };
-  auto &&TargetThenGen = [this, &ThenGen, &D, &InputInfo, &MapTypesArray,
-                          &MapNamesArray, &CapturedVars, RequiresOuterTask,
-                          &CS](CodeGenFunction &CGF, PrePostActionTy &) {
-    // Fill up the arrays with all the captured variables.
-    MappableExprsHandler::MapCombinedInfoTy CombinedInfo;
-    // Get mappable expression information.
-    MappableExprsHandler MEHandler(D, CGF);
-    llvm::DenseMap<llvm::Value *, llvm::Value *> LambdaPointers;
-    llvm::DenseSet<CanonicalDeclPtr<const Decl>> MappedVarSet;
-    auto RI = CS.getCapturedRecordDecl()->field_begin();
-    auto *CV = CapturedVars.begin();
-    for (CapturedStmt::const_capture_iterator CI = CS.capture_begin(),
-                                              CE = CS.capture_end();
-         CI != CE; ++CI, ++RI, ++CV) {
-      MappableExprsHandler::MapCombinedInfoTy CurInfo;
-      MappableExprsHandler::StructRangeInfoTy PartialStruct;
-      // VLA sizes are passed to the outlined region by copy and do not have map
-      // information associated.
-      if (CI->capturesVariableArrayType()) {
-        CurInfo.Exprs.push_back(nullptr);
-        CurInfo.BasePointers.push_back(*CV);
-        CurInfo.DevicePtrDecls.push_back(nullptr);
-        CurInfo.Pointers.push_back(*CV);
-        CurInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
-            CGF.getTypeSize(RI->getType()), CGF.Int64Ty, /*isSigned=*/true));
-        // Copy to the device as an argument. No need to retrieve it.
-        CurInfo.Types.push_back(
-            OpenMPOffloadMappingFlags::OMP_MAP_LITERAL |
-            OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM |
-            OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
-        CurInfo.Mappers.push_back(nullptr);
-      } else {
-        // If we have any information in the map clause, we use it, otherwise we
-        // just do a default mapping.
-        MEHandler.generateInfoForCapture(CI, *CV, CurInfo, PartialStruct);
-        if (!CI->capturesThis())
-          MappedVarSet.insert(CI->getCapturedVar());
-        else
-          MappedVarSet.insert(nullptr);
-        if (CurInfo.BasePointers.empty() && !PartialStruct.Base.isValid())
-          MEHandler.generateDefaultMapInfo(*CI, **RI, *CV, CurInfo);
-        // Generate correct mapping for variables captured by reference in
-        // lambdas.
-        if (CI->capturesVariable())
-          MEHandler.generateInfoForLambdaCaptures(CI->getCapturedVar(), *CV,
-                                                  CurInfo, LambdaPointers);
-      }
-      // We expect to have at least an element of information for this capture.
-      assert((!CurInfo.BasePointers.empty() || PartialStruct.Base.isValid()) &&
-             "Non-existing map pointer for capture!");
-      assert(CurInfo.BasePointers.size() == CurInfo.Pointers.size() &&
-             CurInfo.BasePointers.size() == CurInfo.Sizes.size() &&
-             CurInfo.BasePointers.size() == CurInfo.Types.size() &&
-             CurInfo.BasePointers.size() == CurInfo.Mappers.size() &&
-             "Inconsistent map information sizes!");
-      // If there is an entry in PartialStruct it means we have a struct with
-      // individual members mapped. Emit an extra combined entry.
-      if (PartialStruct.Base.isValid()) {
-        CombinedInfo.append(PartialStruct.PreliminaryMapData);
-        MEHandler.emitCombinedEntry(
-            CombinedInfo, CurInfo.Types, PartialStruct, CI->capturesThis(),
-            nullptr, !PartialStruct.PreliminaryMapData.BasePointers.empty());
-      }
-      // We need to append the results of this capture to what we already have.
-      CombinedInfo.append(CurInfo);
-    }
-    // Adjust MEMBER_OF flags for the lambdas captures.
-    MEHandler.adjustMemberOfForLambdaCaptures(
-        LambdaPointers, CombinedInfo.BasePointers, CombinedInfo.Pointers,
-        CombinedInfo.Types);
-    // Map any list items in a map clause that were not captures because they
-    // weren't referenced within the construct.
-    MEHandler.generateAllInfo(CombinedInfo, MappedVarSet);
-    CGOpenMPRuntime::TargetDataInfo Info;
-    // Fill up the arrays and create the arguments.
-    emitOffloadingArrays(CGF, CombinedInfo, Info, OMPBuilder);
-    bool EmitDebug = CGF.CGM.getCodeGenOpts().getDebugInfo() !=
-                     llvm::codegenoptions::NoDebugInfo;
-    OMPBuilder.emitOffloadingArraysArgument(CGF.Builder, Info.RTArgs, Info,
-                                            EmitDebug,
-                                            /*ForEndCall=*/false);
-    InputInfo.NumberOfTargetItems = Info.NumberOfPtrs;
-    InputInfo.BasePointersArray = Address(Info.RTArgs.BasePointersArray,
-                                          CGF.VoidPtrTy, CGM.getPointerAlign());
-    InputInfo.PointersArray = Address(Info.RTArgs.PointersArray, CGF.VoidPtrTy,
-                                      CGM.getPointerAlign());
-    InputInfo.SizesArray =
-        Address(Info.RTArgs.SizesArray, CGF.Int64Ty, CGM.getPointerAlign());
-    InputInfo.MappersArray =
-        Address(Info.RTArgs.MappersArray, CGF.VoidPtrTy, CGM.getPointerAlign());
-    MapTypesArray = Info.RTArgs.MapTypesArray;
-    MapNamesArray = Info.RTArgs.MapNamesArray;
-    if (RequiresOuterTask)
-      CGF.EmitOMPTargetTaskBasedDirective(D, ThenGen, InputInfo);
-    else
-      emitInlinedDirective(CGF, D.getDirectiveKind(), ThenGen);
-  };
-  auto &&TargetElseGen = [this, &ElseGen, &D, RequiresOuterTask](
-                             CodeGenFunction &CGF, PrePostActionTy &) {
-    if (RequiresOuterTask) {
-      CodeGenFunction::OMPTargetDataInfo InputInfo;
-      CGF.EmitOMPTargetTaskBasedDirective(D, ElseGen, InputInfo);
-    } else {
-      emitInlinedDirective(CGF, D.getDirectiveKind(), ElseGen);
-    }
-  };
+  auto &&TargetElseGen =
+      [this, OutlinedFn, &D, &CapturedVars, RequiresOuterTask, &CS,
+       OffloadingMandatory](CodeGenFunction &CGF, PrePostActionTy &) {
+        emitTargetCallElse(this, OutlinedFn, D, CapturedVars, RequiresOuterTask,
+                           CS, OffloadingMandatory, CGF);
+      };
   // If we have a target function ID it means that we need to support
   // offloading, otherwise, just execute on the host. We need to execute on 

diff  --git a/clang/lib/CodeGen/CGOpenMPRuntime.h b/clang/lib/CodeGen/CGOpenMPRuntime.h
index b1164dc3f72904..4370c95e03feac 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.h
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.h
@@ -327,42 +327,6 @@ class CGOpenMPRuntime {
                                                 bool IsOffloadEntry,
                                                 const RegionCodeGenTy 
-  /// Emits object of ident_t type with info for source location.
-  /// \param Flags Flags for OpenMP location.
-  /// \param EmitLoc emit source location with debug-info is off.
-  ///
-  llvm::Value *emitUpdateLocation(CodeGenFunction &CGF, SourceLocation Loc,
-                                  unsigned Flags = 0, bool EmitLoc = false);
-  /// Emit the number of teams for a target directive.  Inspect the num_teams
-  /// clause associated with a teams construct combined or closely nested
-  /// with the target directive.
-  ///
-  /// Emit a team of size one for directives such as 'target parallel' that
-  /// have no associated teams construct.
-  ///
-  /// Otherwise, return nullptr.
-  const Expr *getNumTeamsExprForTargetDirective(CodeGenFunction &CGF,
-                                                const OMPExecutableDirective &D,
-                                                int32_t &DefaultVal);
-  llvm::Value *emitNumTeamsForTargetDirective(CodeGenFunction &CGF,
-                                              const OMPExecutableDirective &D);
-  /// Emit the number of threads for a target directive.  Inspect the
-  /// thread_limit clause associated with a teams construct combined or closely
-  /// nested with the target directive.
-  ///
-  /// Emit the num_threads clause for directives such as 'target parallel' that
-  /// have no associated teams construct.
-  ///
-  /// Otherwise, return nullptr.
-  const Expr *
-  getNumThreadsExprForTargetDirective(CodeGenFunction &CGF,
-                                      const OMPExecutableDirective &D,
-                                      int32_t &DefaultVal);
-  llvm::Value *
-  emitNumThreadsForTargetDirective(CodeGenFunction &CGF,
-                                   const OMPExecutableDirective &D);
   /// Returns pointer to ident_t type.
   llvm::Type *getIdentTyPointerTy();
@@ -652,15 +616,6 @@ class CGOpenMPRuntime {
                             llvm::Function *TaskFunction, QualType SharedsTy,
                             Address Shareds, const OMPTaskDataTy &Data);
-  /// Return the trip count of loops associated with constructs / 'target teams
-  /// distribute' and 'teams distribute parallel for'. \param SizeEmitter Emits
-  /// the int64 value for the number of iterations of the associated loop.
-  llvm::Value *emitTargetNumIterationsCall(
-      CodeGenFunction &CGF, const OMPExecutableDirective &D,
-      llvm::function_ref<llvm::Value *(CodeGenFunction &CGF,
-                                       const OMPLoopDirective &D)>
-          SizeEmitter);
   /// Emit update for lastprivate conditional data.
   void emitLastprivateConditionalUpdate(CodeGenFunction &CGF, LValue IVLVal,
                                         StringRef UniqueDeclName, LValue LVal,
@@ -687,6 +642,51 @@ class CGOpenMPRuntime {
   virtual ~CGOpenMPRuntime() {}
   virtual void clear();
+  /// Emits object of ident_t type with info for source location.
+  /// \param Flags Flags for OpenMP location.
+  /// \param EmitLoc emit source location with debug-info is off.
+  ///
+  llvm::Value *emitUpdateLocation(CodeGenFunction &CGF, SourceLocation Loc,
+                                  unsigned Flags = 0, bool EmitLoc = false);
+  /// Emit the number of teams for a target directive.  Inspect the num_teams
+  /// clause associated with a teams construct combined or closely nested
+  /// with the target directive.
+  ///
+  /// Emit a team of size one for directives such as 'target parallel' that
+  /// have no associated teams construct.
+  ///
+  /// Otherwise, return nullptr.
+  const Expr *getNumTeamsExprForTargetDirective(CodeGenFunction &CGF,
+                                                const OMPExecutableDirective &D,
+                                                int32_t &DefaultVal);
+  llvm::Value *emitNumTeamsForTargetDirective(CodeGenFunction &CGF,
+                                              const OMPExecutableDirective &D);
+  /// Emit the number of threads for a target directive.  Inspect the
+  /// thread_limit clause associated with a teams construct combined or closely
+  /// nested with the target directive.
+  ///
+  /// Emit the num_threads clause for directives such as 'target parallel' that
+  /// have no associated teams construct.
+  ///
+  /// Otherwise, return nullptr.
+  const Expr *
+  getNumThreadsExprForTargetDirective(CodeGenFunction &CGF,
+                                      const OMPExecutableDirective &D,
+                                      int32_t &DefaultVal);
+  llvm::Value *
+  emitNumThreadsForTargetDirective(CodeGenFunction &CGF,
+                                   const OMPExecutableDirective &D);
+  /// Return the trip count of loops associated with constructs / 'target teams
+  /// distribute' and 'teams distribute parallel for'. \param SizeEmitter Emits
+  /// the int64 value for the number of iterations of the associated loop.
+  llvm::Value *emitTargetNumIterationsCall(
+      CodeGenFunction &CGF, const OMPExecutableDirective &D,
+      llvm::function_ref<llvm::Value *(CodeGenFunction &CGF,
+                                       const OMPLoopDirective &D)>
+          SizeEmitter);
   /// Returns true if the current target is a GPU.
   virtual bool isTargetCodegen() const { return false; }

diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
index 948cc128c65ca3..32dcdd587f3b31 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
@@ -72,6 +72,9 @@ enum class IdentFlag {
 #define OMP_IDENT_FLAG(Enum, ...) constexpr auto Enum = omp::IdentFlag::Enum;
 #include "llvm/Frontend/OpenMP/OMPKinds.def"
+// Version of the kernel argument format used by the omp runtime.
 /// \note This needs to be kept in sync with kmp.h enum sched_type.
 /// Todo: Update kmp.h to include this file, and remove the enums in kmp.h
 enum class OMPScheduleType {

diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index ac16e827e622d8..ed0c923ceaca12 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1533,7 +1533,6 @@ class OpenMPIRBuilder {
   /// Container for the arguments used to pass data to the runtime library.
   struct TargetDataRTArgs {
-    explicit TargetDataRTArgs() {}
     /// The array of base pointer passed to the runtime library.
     Value *BasePointersArray = nullptr;
     /// The array of section pointers passed to the runtime library.
@@ -1553,8 +1552,53 @@ class OpenMPIRBuilder {
     /// The array of original declaration names of mapped pointers sent to the
     /// runtime library for debugging
     Value *MapNamesArray = nullptr;
+    explicit TargetDataRTArgs() {}
+    explicit TargetDataRTArgs(Value *BasePointersArray, Value *PointersArray,
+                              Value *SizesArray, Value *MapTypesArray,
+                              Value *MapTypesArrayEnd, Value *MappersArray,
+                              Value *MapNamesArray)
+        : BasePointersArray(BasePointersArray), PointersArray(PointersArray),
+          SizesArray(SizesArray), MapTypesArray(MapTypesArray),
+          MapTypesArrayEnd(MapTypesArrayEnd), MappersArray(MappersArray),
+          MapNamesArray(MapNamesArray) {}
+  /// Data structure that contains the needed information to construct the
+  /// kernel args vector.
+  struct TargetKernelArgs {
+    /// Number of arguments passed to the runtime library.
+    unsigned NumTargetItems;
+    /// Arguments passed to the runtime library.
+    TargetDataRTArgs RTArgs;
+    /// The number of iterations.
+    Value *NumIterations;
+    /// The number of teams.
+    Value *NumTeams;
+    /// The number of threads.
+    Value *NumThreads;
+    /// The size of the dynamic shared memory.
+    Value *DynCGGroupMem;
+    /// True if the kernel has a 'nowait' clause.
+    bool HasNoWait;
+    /// Constructor for TargetKernelArgs; stores each argument in its member.
+    TargetKernelArgs(unsigned NumTargetItems, TargetDataRTArgs RTArgs,
+                     Value *NumIterations, Value *NumTeams, Value *NumThreads,
+                     Value *DynCGGroupMem, bool HasNoWait)
+        : NumTargetItems(NumTargetItems), RTArgs(RTArgs),
+          NumIterations(NumIterations), NumTeams(NumTeams),
+          NumThreads(NumThreads), DynCGGroupMem(DynCGGroupMem),
+          HasNoWait(HasNoWait) {}
+  };
+  /// Create the kernel args vector used by emitTargetKernel. This function
+  /// creates various constant values that are used in the resulting args
+  /// vector.
+  static void getKernelArgsVector(TargetKernelArgs &KernelArgs,
+                                  IRBuilderBase &Builder,
+                                  SmallVector<Value *> &ArgsVector);
   /// Struct that keeps the information that should be kept throughout
   /// a 'target data' region.
   class TargetDataInfo {
@@ -1636,6 +1680,28 @@ class OpenMPIRBuilder {
+  /// Callback function type for functions emitting the host fallback code that
+  /// is executed when the kernel launch fails. It takes an insertion point as
+  /// parameter where the code should be emitted. It returns an insertion point
+  /// that points right after the emitted code.
+  using EmitFallbackCallbackTy = function_ref<InsertPointTy(InsertPointTy)>;
+  /// Generate a target region entry call and host fallback call.
+  ///
+  /// \param Loc The location at which the request originated and is fulfilled.
+  /// \param OutlinedFn The outlined kernel function.
+  /// \param OutlinedFnID The outlined function ID.
+  /// \param EmitTargetCallFallbackCB Call back function to generate host
+  ///        fallback code.
+  /// \param Args Data structure holding information about the kernel arguments.
+  /// \param DeviceID Identifier for the device via the 'device' clause.
+  /// \param RTLoc Source location identifier.
+  /// \param AllocaIP The insertion point to be used for alloca instructions.
+  /// \return The insertion point following the emitted code.
+  InsertPointTy emitKernelLaunch(
+      const LocationDescription &Loc, Function *OutlinedFn, Value *OutlinedFnID,
+      EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
+      Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP);
   /// Emit the arguments to be passed to the runtime library based on the
   /// arrays of base pointers, pointers, sizes, map types, and mappers.  If
   /// ForEndCall, emit map types to be passed for the end of the region instead

diff  --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp 
index 90d831c3bb583b..8c3ff591af1e35 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -331,6 +331,35 @@ BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase 
&Builder, bool CreateBranch,
   return splitBB(Builder, CreateBranch, Old->getName() + Suffix);
+void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
+                                          IRBuilderBase &Builder,
+                                          SmallVector<Value *> &ArgsVector) {
+  Value *Version = Builder.getInt32(OMP_KERNEL_ARG_VERSION);
+  Value *PointerNum = Builder.getInt32(KernelArgs.NumTargetItems);
+  auto Int32Ty = Type::getInt32Ty(Builder.getContext());
+  Value *ZeroArray = Constant::getNullValue(ArrayType::get(Int32Ty, 3));
+  Value *Flags = Builder.getInt64(KernelArgs.HasNoWait);
+  Value *NumTeams3D =
+      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumTeams, {0});
+  Value *NumThreads3D =
+      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumThreads, {0});
+  ArgsVector = {Version,
+                PointerNum,
+                KernelArgs.RTArgs.BasePointersArray,
+                KernelArgs.RTArgs.PointersArray,
+                KernelArgs.RTArgs.SizesArray,
+                KernelArgs.RTArgs.MapTypesArray,
+                KernelArgs.RTArgs.MapNamesArray,
+                KernelArgs.RTArgs.MappersArray,
+                KernelArgs.NumIterations,
+                Flags,
+                NumTeams3D,
+                NumThreads3D,
+                KernelArgs.DynCGGroupMem};
 void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
   LLVMContext &Ctx = Fn.getContext();
   Triple T(M.getTargetTriple());
@@ -880,6 +909,67 @@ OpenMPIRBuilder::InsertPointTy 
   return Builder.saveIP();
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
+    const LocationDescription &Loc, Function *OutlinedFn, Value *OutlinedFnID,
+    EmitFallbackCallbackTy emitTargetCallFallbackCB, TargetKernelArgs &Args,
+    Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+  Builder.restoreIP(Loc.IP);
+  // On top of the arrays that were filled up, the target offloading call
+  // takes as arguments the device id as well as the host pointer. The host
+  // pointer is used by the runtime library to identify the current target
+  // region, so it only has to be unique and not necessarily point to
+  // anything. It could be the pointer to the outlined function that
+  // implements the target region, but we aren't using that so that the
+  // compiler doesn't need to keep that, and could therefore inline the host
+  // function if proven worthwhile during optimization.
+  // From this point on, we need to have an ID of the target region defined.
+  assert(OutlinedFnID && "Invalid outlined function ID!");
+  (void)OutlinedFnID;
+  // Return value of the runtime offloading call.
+  Value *Return;
+  // Arguments for the target kernel.
+  SmallVector<Value *> ArgsVector;
+  getKernelArgsVector(Args, Builder, ArgsVector);
+  // The target region is an outlined function launched by the runtime
+  // via calls to __tgt_target_kernel().
+  //
+  // Note that on the host and CPU targets, the runtime implementation of
+  // these calls simply call the outlined function without forking threads.
+  // The outlined functions themselves have runtime calls to
+  // __kmpc_fork_teams() and __kmpc_fork() for this purpose, codegen'd by
+  // the compiler in emitTeamsCall() and emitParallelCall().
+  //
+  // In contrast, on the NVPTX target, the implementation of
+  // __tgt_target_teams() launches a GPU kernel with the requested number
+  // of teams and threads so no additional calls to the runtime are required.
+  // Check the error code and execute the host version if required.
+  Builder.restoreIP(emitTargetKernel(Builder, AllocaIP, Return, RTLoc, DeviceID,
+                                     Args.NumTeams, Args.NumThreads,
+                                     OutlinedFnID, ArgsVector));
+  BasicBlock *OffloadFailedBlock =
+      BasicBlock::Create(Builder.getContext(), "omp_offload.failed");
+  BasicBlock *OffloadContBlock =
+      BasicBlock::Create(Builder.getContext(), "omp_offload.cont");
+  Value *Failed = Builder.CreateIsNotNull(Return);
+  Builder.CreateCondBr(Failed, OffloadFailedBlock, OffloadContBlock);
+  auto CurFn = Builder.GetInsertBlock()->getParent();
+  emitBlock(OffloadFailedBlock, CurFn);
+  Builder.restoreIP(emitTargetCallFallbackCB(Builder.saveIP()));
+  emitBranch(OffloadContBlock);
+  emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
+  return Builder.saveIP();
 void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
                                                FinalizeCallbackTy ExitCB) {

cfe-commits mailing list

Reply via email to