https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/116051
>From 47a64959c79fbdec4e6108ddad8f2b2317ce0b76 Mon Sep 17 00:00:00 2001 From: Sergio Afonso <safon...@amd.com> Date: Wed, 27 Nov 2024 11:33:01 +0000 Subject: [PATCH] [OMPIRBuilder] Support runtime number of teams and threads, and SPMD mode This patch introduces a `TargetKernelRuntimeAttrs` structure to hold host-evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values passed to the runtime kernel offloading call. Additionally, kernel type information is used to influence target device code generation and the `IsSPMD` flag is replaced by `ExecFlags`, which provide more granularity. --- clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp | 5 +- .../llvm/Frontend/OpenMP/OMPIRBuilder.h | 38 ++- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 129 ++++++--- .../Frontend/OpenMPIRBuilderTest.cpp | 263 +++++++++++++++++- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 12 +- 5 files changed, 386 insertions(+), 61 deletions(-) diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp index 654a13d75ec810..1e2e693d91de72 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp @@ -20,6 +20,7 @@ #include "clang/AST/StmtVisitor.h" #include "clang/Basic/Cuda.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h" #include "llvm/Frontend/OpenMP/OMPGridValues.h" using namespace clang; @@ -745,7 +746,9 @@ void CGOpenMPRuntimeGPU::emitKernelInit(const OMPExecutableDirective &D, CodeGenFunction &CGF, EntryFunctionState &EST, bool IsSPMD) { llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs Attrs; - Attrs.IsSPMD = IsSPMD; + Attrs.ExecFlags = + IsSPMD ? 
llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD + : llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC; computeMinAndMaxThreadsAndTeams(D, CGF, Attrs); CGBuilderTy &Bld = CGF.Builder; diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 8ca3bc08b5ad49..7eceec3d8cf8f5 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -1389,9 +1389,6 @@ class OpenMPIRBuilder { /// Supporting functions for Reductions CodeGen. private: - /// Emit the llvm.used metadata. - void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List); - /// Get the id of the current thread on the GPU. Value *getGPUThreadID(); @@ -2013,6 +2010,13 @@ class OpenMPIRBuilder { /// Value. GlobalValue *createGlobalFlag(unsigned Value, StringRef Name); + /// Emit the llvm.used metadata. + void emitUsed(StringRef Name, ArrayRef<llvm::WeakTrackingVH> List); + + /// Emit the kernel execution mode. + GlobalVariable *emitKernelExecutionMode(StringRef KernelName, + omp::OMPTgtExecModeFlags Mode); + /// Generate control flow and cleanup for cancellation. /// /// \param CancelFlag Flag indicating if the cancellation is performed. @@ -2233,13 +2237,34 @@ class OpenMPIRBuilder { /// time. The number of max values will be 1 except for the case where /// ompx_bare is set. struct TargetKernelDefaultAttrs { - bool IsSPMD = false; + omp::OMPTgtExecModeFlags ExecFlags = + omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC; SmallVector<int32_t, 3> MaxTeams = {-1}; int32_t MinTeams = 1; SmallVector<int32_t, 3> MaxThreads = {-1}; int32_t MinThreads = 1; }; + /// Container to pass LLVM IR runtime values or constants related to the + /// number of teams and threads with which the kernel must be launched, as + /// well as the trip count of the loop, if it is an SPMD or Generic-SPMD + /// kernel. 
These must be defined in the host prior to the call to the kernel + /// launch OpenMP RTL function. + struct TargetKernelRuntimeAttrs { + SmallVector<Value *, 3> MaxTeams = {nullptr}; + Value *MinTeams = nullptr; + SmallVector<Value *, 3> TargetThreadLimit = {nullptr}; + SmallVector<Value *, 3> TeamsThreadLimit = {nullptr}; + + /// 'parallel' construct 'num_threads' clause value, if present and it is an + /// SPMD kernel. + Value *MaxThreads = nullptr; + + /// Total number of iterations of the SPMD or Generic-SPMD kernel or null if + /// it is a generic kernel. + Value *LoopTripCount = nullptr; + }; + /// Data structure that contains the needed information to construct the /// kernel args vector. struct TargetKernelArgs { @@ -2971,7 +2996,9 @@ class OpenMPIRBuilder { /// \param CodeGenIP The insertion point where the call to the outlined /// function should be emitted. /// \param EntryInfo The entry information about the function. - /// \param DefaultAttrs Structure containing the default numbers of threads + /// \param DefaultAttrs Structure containing the default attributes, including + /// numbers of threads and teams to launch the kernel with. + /// \param RuntimeAttrs Structure containing the runtime numbers of threads /// and teams to launch the kernel with. /// \param Inputs The input values to the region that will be passed. /// as arguments to the outlined function. 
@@ -2987,6 +3014,7 @@ class OpenMPIRBuilder { OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, const TargetKernelDefaultAttrs &DefaultAttrs, + const TargetKernelRuntimeAttrs &RuntimeAttrs, SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB, TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 82c7be79cae2af..77a852fee850ed 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -830,6 +830,38 @@ GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) { return GV; } +void OpenMPIRBuilder::emitUsed(StringRef Name, ArrayRef<WeakTrackingVH> List) { + if (List.empty()) + return; + + // Convert List to what ConstantArray needs. + SmallVector<Constant *, 8> UsedArray; + UsedArray.resize(List.size()); + for (unsigned I = 0, E = List.size(); I != E; ++I) + UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast( + cast<Constant>(&*List[I]), Builder.getPtrTy()); + + if (UsedArray.empty()) + return; + ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size()); + + auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage, + ConstantArray::get(ATy, UsedArray), Name); + + GV->setSection("llvm.metadata"); +} + +GlobalVariable * +OpenMPIRBuilder::emitKernelExecutionMode(StringRef KernelName, + OMPTgtExecModeFlags Mode) { + auto *Int8Ty = Builder.getInt8Ty(); + auto *GVMode = new GlobalVariable( + M, Int8Ty, /*isConstant=*/true, GlobalValue::WeakAnyLinkage, + ConstantInt::get(Int8Ty, Mode), Twine(KernelName, "_exec_mode")); + GVMode->setVisibility(GlobalVariable::ProtectedVisibility); + return GVMode; +} + Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr, uint32_t SrcLocStrSize, IdentFlag LocFlags, @@ -2260,28 +2292,6 @@ static OpenMPIRBuilder::InsertPointTy 
getInsertPointAfterInstr(Instruction *I) { return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT); } -void OpenMPIRBuilder::emitUsed(StringRef Name, - std::vector<WeakTrackingVH> &List) { - if (List.empty()) - return; - - // Convert List to what ConstantArray needs. - SmallVector<Constant *, 8> UsedArray; - UsedArray.resize(List.size()); - for (unsigned I = 0, E = List.size(); I != E; ++I) - UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast( - cast<Constant>(&*List[I]), Builder.getPtrTy()); - - if (UsedArray.empty()) - return; - ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size()); - - auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage, - ConstantArray::get(ATy, UsedArray), Name); - - GV->setSection("llvm.metadata"); -} - Value *OpenMPIRBuilder::getGPUThreadID() { return Builder.CreateCall( getOrCreateRuntimeFunction(M, @@ -6140,10 +6150,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit( uint32_t SrcLocStrSize; Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize); Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize); - Constant *IsSPMDVal = ConstantInt::getSigned( - Int8, Attrs.IsSPMD ? 
OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC); - Constant *UseGenericStateMachineVal = - ConstantInt::getSigned(Int8, !Attrs.IsSPMD); + Constant *IsSPMDVal = ConstantInt::getSigned(Int8, Attrs.ExecFlags); + Constant *UseGenericStateMachineVal = ConstantInt::getSigned( + Int8, Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD); Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true); Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0); @@ -6778,6 +6787,12 @@ static Expected<Function *> createOutlinedFunction( auto Func = Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M); + if (OMPBuilder.Config.isTargetDevice()) { + Value *ExecMode = + OMPBuilder.emitKernelExecutionMode(FuncName, DefaultAttrs.ExecFlags); + OMPBuilder.emitUsed("llvm.compiler.used", {ExecMode}); + } + // Save insert point. IRBuilder<>::InsertPointGuard IPG(Builder); // If there's a DISubprogram associated with current function, then @@ -7325,6 +7340,7 @@ static void emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, OpenMPIRBuilder::InsertPointTy AllocaIP, const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs, + const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs, Function *OutlinedFn, Constant *OutlinedFnID, SmallVectorImpl<Value *> &Args, OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB, @@ -7406,11 +7422,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, /*ForEndCall=*/false); SmallVector<Value *, 3> NumTeamsC; + for (auto [DefaultVal, RuntimeVal] : + zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams)) + NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal)); + + // Calculate number of threads: 0 if no clauses specified, otherwise it is the + // minimum between optional THREAD_LIMIT and NUM_THREADS clauses. 
+ auto InitMaxThreadsClause = [&Builder](Value *Clause) { + if (Clause) + Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(), + /*isSigned=*/false); + return Clause; + }; + auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) { + if (Clause) + Result = Result + ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause), + Result, Clause) + : Clause; + }; + + // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so + // the NUM_THREADS clause is overridden by THREAD_LIMIT. SmallVector<Value *, 3> NumThreadsC; - for (auto V : DefaultAttrs.MaxTeams) - NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V)); - for (auto V : DefaultAttrs.MaxThreads) - NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V)); + Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1 + ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads) + : nullptr; + + for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit, + RuntimeAttrs.TargetThreadLimit)) { + Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal); + Value *NumThreads = InitMaxThreadsClause(TargetVal); + + CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads); + CombineMaxThreadsClauses(MaxThreadsClause, NumThreads); + + NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0)); + } unsigned NumTargetItems = Info.NumberOfPtrs; // TODO: Use correct device ID @@ -7419,14 +7467,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize); Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize, llvm::omp::IdentFlag(0), 0); - // TODO: Use correct NumIterations - Value *NumIterations = Builder.getInt64(0); + + Value *TripCount = RuntimeAttrs.LoopTripCount + ?
Builder.CreateIntCast(RuntimeAttrs.LoopTripCount, + Builder.getInt64Ty(), + /*isSigned=*/false) + : Builder.getInt64(0); + // TODO: Use correct DynCGGroupMem Value *DynCGGroupMem = Builder.getInt32(0); - KArgs = OpenMPIRBuilder::TargetKernelArgs( - NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC, - DynCGGroupMem, HasNoWait); + KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount, + NumTeamsC, NumThreadsC, + DynCGGroupMem, HasNoWait); // The presence of certain clauses on the target directive require the // explicit generation of the target task. @@ -7451,6 +7504,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, const TargetKernelDefaultAttrs &DefaultAttrs, + const TargetKernelRuntimeAttrs &RuntimeAttrs, SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB, OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc, OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, @@ -7475,8 +7529,9 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( // to make a remote call (offload) to the previously outlined function // that represents the target region. Do that now. 
if (!Config.isTargetDevice()) - emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn, - OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait); + emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, + OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies, + HasNowait); return Builder.saveIP(); } diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index 670841aadafc2d..58a1a52207b119 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -6123,7 +6123,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { OMPBuilder.setConfig(Config); F->setName("func"); IRBuilder<> Builder(BB); - auto Int32Ty = Builder.getInt32Ty(); + auto *Int32Ty = Builder.getInt32Ty(); AllocaInst *APtr = Builder.CreateAlloca(Int32Ty, nullptr, "a_ptr"); AllocaInst *BPtr = Builder.CreateAlloca(Int32Ty, nullptr, "b_ptr"); @@ -6182,13 +6182,17 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17); OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL}); + OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { - /*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, - /*MinThreads=*/0}; - OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = - OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), - Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs, - GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); + /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, + /*MaxTeams=*/{10}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + RuntimeAttrs.TargetThreadLimit[0] = Builder.getInt32(20); + RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30); + RuntimeAttrs.MaxThreads = Builder.getInt32(40); + OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( + OmpLoc, 
/*IsOffloadEntry=*/true, Builder.saveIP(), Builder.saveIP(), + EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB, + SimpleArgAccessorCB); assert(AfterIP && "unexpected error"); Builder.restoreIP(*AfterIP); OMPBuilder.finalize(); @@ -6208,6 +6212,43 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { StringRef FunctionName = KernelLaunchFunc->getName(); EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel")); + // Check num_teams and num_threads in call arguments + EXPECT_TRUE(Call->arg_size() >= 4); + Value *NumTeamsArg = Call->getArgOperand(2); + EXPECT_TRUE(isa<ConstantInt>(NumTeamsArg)); + EXPECT_EQ(10U, cast<ConstantInt>(NumTeamsArg)->getZExtValue()); + Value *NumThreadsArg = Call->getArgOperand(3); + EXPECT_TRUE(isa<ConstantInt>(NumThreadsArg)); + EXPECT_EQ(20U, cast<ConstantInt>(NumThreadsArg)->getZExtValue()); + + // Check num_teams and num_threads kernel arguments (uses number 3 and 2, + // from the end, counting the call to __tgt_target_kernel as use number 0) + Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1); + EXPECT_TRUE(KernelArgs->getNumUses() >= 4); + Value *NumTeamsGetElemPtr = *std::next(KernelArgs->user_begin(), 3); + EXPECT_TRUE(isa<GetElementPtrInst>(NumTeamsGetElemPtr)); + Value *NumTeamsStore = NumTeamsGetElemPtr->getUniqueUndroppableUser(); + EXPECT_TRUE(isa<StoreInst>(NumTeamsStore)); + Value *NumTeamsStoreArg = cast<StoreInst>(NumTeamsStore)->getValueOperand(); + EXPECT_TRUE(isa<ConstantDataSequential>(NumTeamsStoreArg)); + auto *NumTeamsStoreValue = cast<ConstantDataSequential>(NumTeamsStoreArg); + EXPECT_EQ(3U, NumTeamsStoreValue->getNumElements()); + EXPECT_EQ(10U, NumTeamsStoreValue->getElementAsInteger(0)); + EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(1)); + EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(2)); + Value *NumThreadsGetElemPtr = *std::next(KernelArgs->user_begin(), 2); + EXPECT_TRUE(isa<GetElementPtrInst>(NumThreadsGetElemPtr)); + Value *NumThreadsStore =
NumThreadsGetElemPtr->getUniqueUndroppableUser(); + EXPECT_TRUE(isa<StoreInst>(NumThreadsStore)); + Value *NumThreadsStoreArg = + cast<StoreInst>(NumThreadsStore)->getValueOperand(); + EXPECT_TRUE(isa<ConstantDataSequential>(NumThreadsStoreArg)); + auto *NumThreadsStoreValue = cast<ConstantDataSequential>(NumThreadsStoreArg); + EXPECT_EQ(3U, NumThreadsStoreValue->getNumElements()); + EXPECT_EQ(20U, NumThreadsStoreValue->getElementAsInteger(0)); + EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(1)); + EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(2)); + // Check the fallback call BasicBlock *FallbackBlock = Branch->getSuccessor(0); Iter = FallbackBlock->rbegin(); @@ -6296,12 +6337,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2, /*Line=*/3, /*Count=*/0); + OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { - /*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, - /*MinThreads=*/0}; + /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, + /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs, - CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); + RuntimeAttrs, CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); assert(AfterIP && "unexpected error"); Builder.restoreIP(*AfterIP); @@ -6386,6 +6428,200 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { auto *ExitBlock = EntryBlockBranch->getSuccessor(1); EXPECT_EQ(ExitBlock->getName(), "worker.exit"); EXPECT_TRUE(isa<ReturnInst>(ExitBlock->getFirstNonPHI())); + + // Check global exec_mode. 
+ GlobalVariable *Used = M->getGlobalVariable("llvm.compiler.used"); + EXPECT_NE(Used, nullptr); + Constant *UsedInit = Used->getInitializer(); + EXPECT_NE(UsedInit, nullptr); + EXPECT_TRUE(isa<ConstantArray>(UsedInit)); + auto *UsedInitData = cast<ConstantArray>(UsedInit); + EXPECT_EQ(1U, UsedInitData->getNumOperands()); + Constant *ExecMode = UsedInitData->getOperand(0); + EXPECT_TRUE(isa<GlobalVariable>(ExecMode)); + Constant *ExecModeValue = cast<GlobalVariable>(ExecMode)->getInitializer(); + EXPECT_NE(ExecModeValue, nullptr); + EXPECT_TRUE(isa<ConstantInt>(ExecModeValue)); + EXPECT_EQ(OMP_TGT_EXEC_MODE_GENERIC, + cast<ConstantInt>(ExecModeValue)->getZExtValue()); +} + +TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + OpenMPIRBuilderConfig Config(/*IsTargetDevice=*/false, /*IsGPU=*/false, + /*OpenMPOffloadMandatory=*/false, + /*HasRequiresReverseOffload=*/false, + /*HasRequiresUnifiedAddress=*/false, + /*HasRequiresUnifiedSharedMemory=*/false, + /*HasRequiresDynamicAllocators=*/false); + OMPBuilder.setConfig(Config); + F->setName("func"); + IRBuilder<> Builder(BB); + + auto BodyGenCB = [&](InsertPointTy, + InsertPointTy CodeGenIP) -> InsertPointTy { + Builder.restoreIP(CodeGenIP); + return Builder.saveIP(); + }; + + auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&, + OpenMPIRBuilder::InsertPointTy, + OpenMPIRBuilder::InsertPointTy CodeGenIP) { + Builder.restoreIP(CodeGenIP); + return Builder.saveIP(); + }; + + SmallVector<Value *> Inputs; + OpenMPIRBuilder::MapInfosTy CombinedInfos; + auto GenMapInfoCB = + [&](OpenMPIRBuilder::InsertPointTy) -> OpenMPIRBuilder::MapInfosTy & { + return CombinedInfos; + }; + + TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17); + OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL}); + OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; + OpenMPIRBuilder::TargetKernelDefaultAttrs 
DefaultAttrs = { + /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD, + /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + RuntimeAttrs.LoopTripCount = Builder.getInt64(1000); + OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( + OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), Builder.saveIP(), + EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB, + SimpleArgAccessorCB); + assert(AfterIP && "unexpected error"); + Builder.restoreIP(*AfterIP); + OMPBuilder.finalize(); + Builder.CreateRetVoid(); + + // Check the kernel launch sequence + auto Iter = F->getEntryBlock().rbegin(); + EXPECT_TRUE(isa<BranchInst>(&*(Iter))); + BranchInst *Branch = dyn_cast<BranchInst>(&*(Iter)); + EXPECT_TRUE(isa<CmpInst>(&*(++Iter))); + EXPECT_TRUE(isa<CallInst>(&*(++Iter))); + CallInst *Call = dyn_cast<CallInst>(&*(Iter)); + + // Check that the kernel launch function is called + Function *KernelLaunchFunc = Call->getCalledFunction(); + EXPECT_NE(KernelLaunchFunc, nullptr); + StringRef FunctionName = KernelLaunchFunc->getName(); + EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel")); + + // Check the trip count kernel argument (use number 5 starting from the end + // and counting the call to __tgt_target_kernel as the first use) + Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1); + EXPECT_TRUE(KernelArgs->getNumUses() >= 6); + Value *TripCountGetElemPtr = *std::next(KernelArgs->user_begin(), 5); + EXPECT_TRUE(isa<GetElementPtrInst>(TripCountGetElemPtr)); + Value *TripCountStore = TripCountGetElemPtr->getUniqueUndroppableUser(); + EXPECT_TRUE(isa<StoreInst>(TripCountStore)); + Value *TripCountStoreArg = cast<StoreInst>(TripCountStore)->getValueOperand(); + EXPECT_TRUE(isa<ConstantInt>(TripCountStoreArg)); + EXPECT_EQ(1000U, cast<ConstantInt>(TripCountStoreArg)->getZExtValue()); + + // Check the fallback call + BasicBlock *FallbackBlock = Branch->getSuccessor(0); + Iter = 
FallbackBlock->rbegin(); + CallInst *FCall = dyn_cast<CallInst>(&*(++Iter)); + // 'F' has a dummy DISubprogram which causes OutlinedFunc to also + // have a DISubprogram. In this case, the call to OutlinedFunc needs + // to have a debug loc, otherwise verifier will complain. + FCall->setDebugLoc(DL); + EXPECT_NE(FCall, nullptr); + + // Check that the outlined function exists with the expected prefix + Function *OutlinedFunc = FCall->getCalledFunction(); + EXPECT_NE(OutlinedFunc, nullptr); + StringRef FunctionName2 = OutlinedFunc->getName(); + EXPECT_TRUE(FunctionName2.starts_with("__omp_offloading")); + + EXPECT_FALSE(verifyModule(*M, &errs())); +} + +TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) { + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.setConfig( + OpenMPIRBuilderConfig(/*IsTargetDevice=*/true, /*IsGPU=*/false, + /*OpenMPOffloadMandatory=*/false, + /*HasRequiresReverseOffload=*/false, + /*HasRequiresUnifiedAddress=*/false, + /*HasRequiresUnifiedSharedMemory=*/false, + /*HasRequiresDynamicAllocators=*/false)); + OMPBuilder.initialize(); + F->setName("func"); + IRBuilder<> Builder(BB); + OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); + + Function *OutlinedFn = nullptr; + SmallVector<Value *> CapturedArgs; + + auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&, + OpenMPIRBuilder::InsertPointTy, + OpenMPIRBuilder::InsertPointTy CodeGenIP) { + Builder.restoreIP(CodeGenIP); + return Builder.saveIP(); + }; + + OpenMPIRBuilder::MapInfosTy CombinedInfos; + auto GenMapInfoCB = + [&](OpenMPIRBuilder::InsertPointTy) -> OpenMPIRBuilder::MapInfosTy & { + return CombinedInfos; + }; + + auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy, + OpenMPIRBuilder::InsertPointTy CodeGenIP) + -> OpenMPIRBuilder::InsertPointTy { + Builder.restoreIP(CodeGenIP); + OutlinedFn = CodeGenIP.getBlock()->getParent(); + return Builder.saveIP(); + }; + + IRBuilder<>::InsertPoint EntryIP(&F->getEntryBlock(), + F->getEntryBlock().getFirstInsertionPt()); + 
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2, + /*Line=*/3, /*Count=*/0); + + OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; + OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { + /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD, + /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( + Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs, + RuntimeAttrs, CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); + assert(AfterIP && "unexpected error"); + Builder.restoreIP(*AfterIP); + + Builder.CreateRetVoid(); + OMPBuilder.finalize(); + + // Check outlined function + EXPECT_FALSE(verifyModule(*M, &errs())); + EXPECT_NE(OutlinedFn, nullptr); + EXPECT_NE(F, OutlinedFn); + + EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage()); + // Account for the "implicit" first argument. + EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3"); + EXPECT_EQ(OutlinedFn->arg_size(), 1U); + + // Check global exec_mode. 
+ GlobalVariable *Used = M->getGlobalVariable("llvm.compiler.used"); + EXPECT_NE(Used, nullptr); + Constant *UsedInit = Used->getInitializer(); + EXPECT_NE(UsedInit, nullptr); + EXPECT_TRUE(isa<ConstantArray>(UsedInit)); + auto *UsedInitData = cast<ConstantArray>(UsedInit); + EXPECT_EQ(1U, UsedInitData->getNumOperands()); + Constant *ExecMode = UsedInitData->getOperand(0); + EXPECT_TRUE(isa<GlobalVariable>(ExecMode)); + Constant *ExecModeValue = cast<GlobalVariable>(ExecMode)->getInitializer(); + EXPECT_NE(ExecModeValue, nullptr); + EXPECT_TRUE(isa<ConstantInt>(ExecModeValue)); + EXPECT_EQ(OMP_TGT_EXEC_MODE_SPMD, + cast<ConstantInt>(ExecModeValue)->getZExtValue()); } TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) { @@ -6454,12 +6690,13 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) { TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2, /*Line=*/3, /*Count=*/0); + OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { - /*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, - /*MinThreads=*/0}; + /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, + /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs, - CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); + RuntimeAttrs, CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); assert(AfterIP && "unexpected error"); Builder.restoreIP(*AfterIP); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index b2faefc6199485..d71d3a9d7421f9 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -4085,10 +4085,12 @@ 
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, allocaIP, codeGenIP); }; - // TODO: Populate default attributes based on the construct and clauses. + // TODO: Populate default and runtime attributes based on the construct and + // clauses. + llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs; llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = { - /*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, - /*MinThreads=*/0}; + /*ExecFlags=*/llvm::omp::OMP_TGT_EXEC_MODE_GENERIC, /*MaxTeams=*/{-1}, + /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; llvm::SmallVector<llvm::Value *, 4> kernelInput; for (size_t i = 0; i < mapVars.size(); ++i) { @@ -4113,8 +4115,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = moduleTranslation.getOpenMPBuilder()->createTarget( ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo, - defaultAttrs, kernelInput, genMapInfoCB, bodyCB, argAccessorCB, dds, - targetOp.getNowait()); + defaultAttrs, runtimeAttrs, kernelInput, genMapInfoCB, bodyCB, + argAccessorCB, dds, targetOp.getNowait()); if (failed(handleError(afterIP, opInst))) return failure(); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits