llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang Author: Shilei Tian (shiltian) <details> <summary>Changes</summary> --- Full diff: https://github.com/llvm/llvm-project/pull/102717.diff 4 Files Affected: - (modified) clang/lib/CodeGen/CGOpenMPRuntime.cpp (+17-12) - (modified) clang/test/OpenMP/target_teams_codegen.cpp (+6-6) - (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+13-13) - (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+18-13) ``````````diff diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp index 8c5e4aa9c037e2..6c0c8646898cc6 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -9588,15 +9588,17 @@ static void genMapInfo(const OMPExecutableDirective &D, CodeGenFunction &CGF, genMapInfo(MEHandler, CGF, CombinedInfo, OMPBuilder, MappedVarSet); } -static void emitNumTeamsForBareTargetDirective( +template <typename ClauseTy> +static void emitClauseForBareTargetDirective( CodeGenFunction &CGF, const OMPExecutableDirective &D, - llvm::SmallVectorImpl<llvm::Value *> &NumTeams) { - const auto *C = D.getSingleClause<OMPNumTeamsClause>(); - assert(!C->varlist_empty() && "ompx_bare requires explicit num_teams"); - CodeGenFunction::RunCleanupsScope NumTeamsScope(CGF); - for (auto *E : C->getNumTeams()) { + llvm::SmallVectorImpl<llvm::Value *> &Valuess) { + const auto *C = D.getSingleClause<ClauseTy>(); + assert(!C->varlist_empty() && + "ompx_bare requires explicit num_teams and thread_limit"); + CodeGenFunction::RunCleanupsScope Scope(CGF); + for (auto *E : C->varlist()) { llvm::Value *V = CGF.EmitScalarExpr(E); - NumTeams.push_back( + Valuess.push_back( CGF.Builder.CreateIntCast(V, CGF.Int32Ty, /*isSigned=*/true)); } } @@ -9672,14 +9674,17 @@ static void emitTargetCallKernelLaunch( bool IsBare = D.hasClausesOfKind<OMPXBareClause>(); SmallVector<llvm::Value *, 3> NumTeams; - if (IsBare) - emitNumTeamsForBareTargetDirective(CGF, D, NumTeams); - else + SmallVector<llvm::Value *, 3> NumThreads; + if (IsBare) { + emitClauseForBareTargetDirective<OMPNumTeamsClause>(CGF, D, NumTeams); + emitClauseForBareTargetDirective<OMPThreadLimitClause>(CGF, D, + NumThreads); + } else { NumTeams.push_back(OMPRuntime->emitNumTeamsForTargetDirective(CGF, D)); + NumThreads.push_back(OMPRuntime->emitNumThreadsForTargetDirective(CGF, D)); + } llvm::Value *DeviceID = emitDeviceID(Device, CGF); - llvm::Value *NumThreads = - OMPRuntime->emitNumThreadsForTargetDirective(CGF, D); llvm::Value *RTLoc = OMPRuntime->emitUpdateLocation(CGF, D.getBeginLoc()); llvm::Value *NumIterations = OMPRuntime->emitTargetNumIterationsCall(CGF, D, SizeEmitter); diff --git a/clang/test/OpenMP/target_teams_codegen.cpp b/clang/test/OpenMP/target_teams_codegen.cpp index 9cab8eef148833..13d44e127201bd 100644 --- a/clang/test/OpenMP/target_teams_codegen.cpp +++ b/clang/test/OpenMP/target_teams_codegen.cpp @@ -127,13 +127,13 @@ int foo(int n) { aa += 1; } - #pragma omp target teams ompx_bare num_teams(1, 2) thread_limit(1) + #pragma omp target teams ompx_bare num_teams(1, 2) thread_limit(1, 2) { a += 1; aa += 1; } - #pragma omp target teams ompx_bare num_teams(1, 2, 3) thread_limit(1) + #pragma omp target teams ompx_bare num_teams(1, 2, 3) thread_limit(1, 2, 3) { a += 1; aa += 1; @@ -667,7 +667,7 @@ int bar(int n){ // CHECK1-NEXT: [[TMP144:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS29]], i32 0, i32 10 // CHECK1-NEXT: store [3 x i32] [i32 1, i32 2, i32 0], ptr [[TMP144]], align 4 // CHECK1-NEXT: [[TMP145:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS29]], i32 0, i32 11 -// CHECK1-NEXT: store [3 x i32] [i32 1, i32 0, i32 0], ptr [[TMP145]], align 4 +// CHECK1-NEXT: store [3 x i32] [i32 1, i32 2, i32 0], ptr [[TMP145]], align 4 // CHECK1-NEXT: [[TMP146:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS29]], i32 0, i32 12 // CHECK1-NEXT: store i32 0, ptr [[TMP146]], align 4 // CHECK1-NEXT: [[TMP147:%.*]] = call i32 @__tgt_target_kernel(ptr @[[GLOB1]], i64 -1, i32 1, i32 1, ptr @.{{__omp_offloading_[0-9a-z]+_[0-9a-z]+}}__Z3fooi_l130.region_id, ptr [[KERNEL_ARGS29]]) @@ -720,7 +720,7 @@ int bar(int n){ // CHECK1-NEXT: [[TMP171:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS37]], i32 0, i32 10 // CHECK1-NEXT: store [3 x i32] [i32 1, i32 2, i32 3], ptr [[TMP171]], align 4 // CHECK1-NEXT: [[TMP172:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS37]], i32 0, i32 11 -// CHECK1-NEXT: store [3 x i32] [i32 1, i32 0, i32 0], ptr [[TMP172]], align 4 +// CHECK1-NEXT: store [3 x i32] [i32 1, i32 2, i32 3], ptr [[TMP172]], align 4 // CHECK1-NEXT: [[TMP173:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS37]], i32 0, i32 12 // CHECK1-NEXT: store i32 0, ptr [[TMP173]], align 4 // CHECK1-NEXT: [[TMP174:%.*]] = call i32 @__tgt_target_kernel(ptr @[[GLOB1]], i64 -1, i32 1, i32 1, ptr @.{{__omp_offloading_[0-9a-z]+_[0-9a-z]+}}__Z3fooi_l136.region_id, ptr [[KERNEL_ARGS37]]) @@ -2458,7 +2458,7 @@ int bar(int n){ // CHECK3-NEXT: [[TMP142:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS29]], i32 0, i32 10 // CHECK3-NEXT: store [3 x i32] [i32 1, i32 2, i32 0], ptr [[TMP142]], align 4 // CHECK3-NEXT: [[TMP143:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS29]], i32 0, i32 11 -// CHECK3-NEXT: store [3 x i32] [i32 1, i32 0, i32 0], ptr [[TMP143]], align 4 +// CHECK3-NEXT: store [3 x i32] [i32 1, i32 2, i32 0], ptr [[TMP143]], align 4 // CHECK3-NEXT: [[TMP144:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS29]], i32 0, i32 12 // CHECK3-NEXT: store i32 0, ptr [[TMP144]], align 4 // CHECK3-NEXT: [[TMP145:%.*]] = call i32 @__tgt_target_kernel(ptr @[[GLOB1]], i64 -1, i32 1, i32 1, ptr @.{{__omp_offloading_[0-9a-z]+_[0-9a-z]+}}__Z3fooi_l130.region_id, ptr [[KERNEL_ARGS29]]) @@ -2511,7 +2511,7 @@ int bar(int n){ // CHECK3-NEXT: [[TMP169:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS37]], i32 0, i32 10 // CHECK3-NEXT: store [3 x i32] [i32 1, i32 2, i32 3], ptr [[TMP169]], align 4 // CHECK3-NEXT: [[TMP170:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS37]], i32 0, i32 11 -// CHECK3-NEXT: store [3 x i32] [i32 1, i32 0, i32 0], ptr [[TMP170]], align 4 +// CHECK3-NEXT: store [3 x i32] [i32 1, i32 2, i32 3], ptr [[TMP170]], align 4 // CHECK3-NEXT: [[TMP171:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS37]], i32 0, i32 12 // CHECK3-NEXT: store i32 0, ptr [[TMP171]], align 4 // CHECK3-NEXT: [[TMP172:%.*]] = call i32 @__tgt_target_kernel(ptr @[[GLOB1]], i64 -1, i32 1, i32 1, ptr @.{{__omp_offloading_[0-9a-z]+_[0-9a-z]+}}__Z3fooi_l136.region_id, ptr [[KERNEL_ARGS37]]) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 9e4e7ebf2a5703..4be0159fb1dd9f 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -2195,7 +2195,7 @@ class OpenMPIRBuilder { /// The number of teams. ArrayRef<Value *> NumTeams; /// The number of threads. - Value *NumThreads = nullptr; + ArrayRef<Value *> NumThreads; /// The size of the dynamic shared memory. Value *DynCGGroupMem = nullptr; /// True if the kernel has 'no wait' clause. @@ -2205,7 +2205,8 @@ class OpenMPIRBuilder { TargetKernelArgs() {} TargetKernelArgs(unsigned NumTargetItems, TargetDataRTArgs RTArgs, Value *NumIterations, ArrayRef<Value *> NumTeams, - Value *NumThreads, Value *DynCGGroupMem, bool HasNoWait) + ArrayRef<Value *> NumThreads, Value *DynCGGroupMem, + bool HasNoWait) : NumTargetItems(NumTargetItems), RTArgs(RTArgs), NumIterations(NumIterations), NumTeams(NumTeams), NumThreads(NumThreads), DynCGGroupMem(DynCGGroupMem), @@ -2852,17 +2853,16 @@ class OpenMPIRBuilder { /// instructions for passed in target arguments where neccessary /// \param Dependencies A vector of DependData objects that carry // dependency information as passed in the depend clause - InsertPointTy createTarget(const LocationDescription &Loc, - bool IsOffloadEntry, - OpenMPIRBuilder::InsertPointTy AllocaIP, - OpenMPIRBuilder::InsertPointTy CodeGenIP, - TargetRegionEntryInfo &EntryInfo, - ArrayRef<int32_t> NumTeams, int32_t NumThreads, - SmallVectorImpl<Value *> &Inputs, - GenMapInfoCallbackTy GenMapInfoCB, - TargetBodyGenCallbackTy BodyGenCB, - TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, - SmallVector<DependData> Dependencies = {}); + InsertPointTy + createTarget(const LocationDescription &Loc, bool IsOffloadEntry, + OpenMPIRBuilder::InsertPointTy AllocaIP, + OpenMPIRBuilder::InsertPointTy CodeGenIP, + TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams, + ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs, + GenMapInfoCallbackTy GenMapInfoCB, + TargetBodyGenCallbackTy BodyGenCB, + TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, + SmallVector<DependData> Dependencies = {}); /// Returns __kmpc_for_static_init_* runtime function for the specified /// size \a IVSize and sign \a IVSigned. Will create a distribute call diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index b481520fa6c6f9..f46531cb3bad40 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -505,11 +505,14 @@ void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs, Value *NumTeams3D = Builder.CreateInsertValue(ZeroArray, KernelArgs.NumTeams[0], {0}); + Value *NumThreads3D = + Builder.CreateInsertValue(ZeroArray, KernelArgs.NumThreads[0], {0}); for (unsigned I = 1; I < std::min(KernelArgs.NumTeams.size(), MaxDim); ++I) NumTeams3D = Builder.CreateInsertValue(NumTeams3D, KernelArgs.NumTeams[I], {I}); - Value *NumThreads3D = - Builder.CreateInsertValue(ZeroArray, KernelArgs.NumThreads, {0}); + for (unsigned I = 1; I < std::min(KernelArgs.NumThreads.size(), MaxDim); ++I) + NumThreads3D = + Builder.CreateInsertValue(NumThreads3D, KernelArgs.NumTeams[I], {I}); ArgsVector = {Version, PointerNum, @@ -1114,9 +1117,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch( // __tgt_target_teams() launches a GPU kernel with the requested number // of teams and threads so no additional calls to the runtime are required. // Check the error code and execute the host version if required. - Builder.restoreIP(emitTargetKernel(Builder, AllocaIP, Return, RTLoc, DeviceID, - Args.NumTeams.front(), Args.NumThreads, - OutlinedFnID, ArgsVector)); + Builder.restoreIP(emitTargetKernel( + Builder, AllocaIP, Return, RTLoc, DeviceID, Args.NumTeams.front(), + Args.NumThreads.front(), OutlinedFnID, ArgsVector)); BasicBlock *OffloadFailedBlock = BasicBlock::Create(Builder.getContext(), "omp_offload.failed"); @@ -7075,8 +7078,8 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs( static void emitTargetCall( OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn, - Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams, int32_t NumThreads, - SmallVectorImpl<Value *> &Args, + Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams, + ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args, OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB, SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) { // Generate a function call to the host fallback implementation of the target @@ -7123,13 +7126,15 @@ static void emitTargetCall( /*ForEndCall=*/false); SmallVector<Value *, 3> NumTeamsC; + SmallVector<Value *, 3> NumThreadsC; for (auto V : NumTeams) NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V)); + for (auto V : NumThreads) + NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V)); unsigned NumTargetItems = Info.NumberOfPtrs; // TODO: Use correct device ID Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF); - Value *NumThreadsVal = Builder.getInt32(NumThreads); uint32_t SrcLocStrSize; Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize); Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize, @@ -7140,8 +7145,8 @@ static void emitTargetCall( Value *DynCGGroupMem = Builder.getInt32(0); OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, NumIterations, - NumTeamsC, NumThreadsVal, - DynCGGroupMem, HasNoWait); + NumTeamsC, NumThreadsC, DynCGGroupMem, + HasNoWait); // The presence of certain clauses on the target directive require the // explicit generation of the target task. @@ -7159,11 +7164,11 @@ static void emitTargetCall( OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget( const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, - ArrayRef<int32_t> NumTeams, int32_t NumThreads, + ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB, OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc, OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, - SmallVector<DependData> Dependenciess) { + SmallVector<DependData> Dependencies) { if (!updateToLocation(Loc)) return InsertPointTy(); @@ -7184,7 +7189,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget( // that represents the target region. Do that now. if (!Config.isTargetDevice()) emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams, - NumThreads, Args, GenMapInfoCB, Dependenciess); + NumThreads, Args, GenMapInfoCB, Dependencies); return Builder.saveIP(); } `````````` </details> https://github.com/llvm/llvm-project/pull/102717 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits