https://github.com/shraiysh updated https://github.com/llvm/llvm-project/pull/67723
>From 6aabc3c10ea2d587120b74966b7ce96f9b8167af Mon Sep 17 00:00:00 2001 From: Shraiysh Vaishay <shraiysh.vais...@amd.com> Date: Thu, 28 Sep 2023 13:35:07 -0500 Subject: [PATCH 1/4] [OpenMPIRBuilder] Remove wrapper function in `createTask` This patch removes the wrapper function in `OpenMPIRBuilder::createTask`. The outlined function is directly of the form that is expected by the runtime library calls. This also fixes the global thread ID argument, which should be used whenever `kmpc_global_thread_num()` is called inside the outlined function. --- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 106 ++++++++---------- .../Frontend/OpenMPIRBuilderTest.cpp | 56 +++++---- mlir/test/Target/LLVMIR/openmp-llvm.mlir | 51 +++------ 3 files changed, 99 insertions(+), 114 deletions(-) diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 9c70d384e55db2b..54012b488c6b671 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -35,6 +35,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" @@ -1496,6 +1497,14 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB, bool Tied, Value *Final, Value *IfCondition, SmallVector<DependData> Dependencies) { + // We create a temporary i32 value that will represent the global tid after + // outlining. + SmallVector<Instruction *, 4> ToBeDeleted; + Builder.restoreIP(AllocaIP); + AllocaInst *TIDAddr = Builder.CreateAlloca(Int32, nullptr, "tid.addr"); + LoadInst *TID = Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use"); + ToBeDeleted.append({TID, TIDAddr}); + if (!updateToLocation(Loc)) return InsertPointTy(); @@ -1523,41 +1532,27 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, BasicBlock *TaskAllocaBB = splitBB(Builder, /*CreateBranch=*/true, "task.alloca"); + // Fake use of TID + Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin()); + BinaryOperator *AddInst = + dyn_cast<BinaryOperator>(Builder.CreateAdd(TID, Builder.getInt32(10))); + ToBeDeleted.push_back(AddInst); + OutlineInfo OI; OI.EntryBB = TaskAllocaBB; OI.OuterAllocaBB = AllocaIP.getBlock(); OI.ExitBB = TaskExitBB; - OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, - Dependencies](Function &OutlinedFn) { - // The input IR here looks like the following- - // ``` - // func @current_fn() { - // outlined_fn(%args) - // } - // func @outlined_fn(%args) { ... } - // ``` - // - // This is changed to the following- - // - // ``` - // func @current_fn() { - // runtime_call(..., wrapper_fn, ...) - // } - // func @wrapper_fn(..., %args) { - // outlined_fn(%args) - // } - // func @outlined_fn(%args) { ... } - // ``` - - // The stale call instruction will be replaced with a new call instruction - // for runtime call with a wrapper function. + OI.ExcludeArgsFromAggregate = {TID}; + OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies, + TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) { + // Replace the Stale CI by appropriate RTL function call. assert(OutlinedFn.getNumUses() == 1 && "there must be a single user for the outlined function"); CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back()); // HasShareds is true if any variables are captured in the outlined region, // false otherwise. - bool HasShareds = StaleCI->arg_size() > 0; + bool HasShareds = StaleCI->arg_size() > 1; Builder.SetInsertPoint(StaleCI); // Gather the arguments for emitting the runtime call for @@ -1595,7 +1590,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, Value *SharedsSize = Builder.getInt64(0); if (HasShareds) { AllocaInst *ArgStructAlloca = - dyn_cast<AllocaInst>(StaleCI->getArgOperand(0)); + dyn_cast<AllocaInst>(StaleCI->getArgOperand(1)); assert(ArgStructAlloca && "Unable to find the alloca instruction corresponding to arguments " "for extracted function"); @@ -1606,31 +1601,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, SharedsSize = Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType)); } - - // Argument - task_entry (the wrapper function) - // If the outlined function has some captured variables (i.e. HasShareds is - // true), then the wrapper function will have an additional argument (the - // struct containing captured variables). Otherwise, no such argument will - // be present. - SmallVector<Type *> WrapperArgTys{Builder.getInt32Ty()}; - if (HasShareds) - WrapperArgTys.push_back(OutlinedFn.getArg(0)->getType()); - FunctionCallee WrapperFuncVal = M.getOrInsertFunction( - (Twine(OutlinedFn.getName()) + ".wrapper").str(), - FunctionType::get(Builder.getInt32Ty(), WrapperArgTys, false)); - Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee()); - // Emit the @__kmpc_omp_task_alloc runtime call // The runtime call returns a pointer to an area where the task captured // variables must be copied before the task is run (TaskData) CallInst *TaskData = Builder.CreateCall( TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags, /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize, - /*task_func=*/WrapperFunc}); + /*task_func=*/&OutlinedFn}); // Copy the arguments for outlined function if (HasShareds) { - Value *Shareds = StaleCI->getArgOperand(0); + Value *Shareds = StaleCI->getArgOperand(1); Align Alignment = TaskData->getPointerAlignment(M.getDataLayout()); Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData); Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment, @@ -1697,10 +1678,9 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, if (IfCondition) { // `SplitBlockAndInsertIfThenElse` requires the block to have a // terminator. - BasicBlock *NewBasicBlock = - splitBB(Builder, /*CreateBranch=*/true, "if.end"); + splitBB(Builder, /*CreateBranch=*/true, "if.end"); Instruction *IfTerminator = - NewBasicBlock->getSinglePredecessor()->getTerminator(); + Builder.GetInsertPoint()->getParent()->getTerminator(); Instruction *ThenTI = IfTerminator, *ElseTI = nullptr; Builder.SetInsertPoint(IfTerminator); SplitBlockAndInsertIfThenElse(IfCondition, IfTerminator, &ThenTI, @@ -1711,10 +1691,12 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, Function *TaskCompleteFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0); Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData}); + CallInst *CI = nullptr; if (HasShareds) - Builder.CreateCall(WrapperFunc, {ThreadID, TaskData}); + CI = Builder.CreateCall(&OutlinedFn, {ThreadID, TaskData}); else - Builder.CreateCall(WrapperFunc, {ThreadID}); + CI = Builder.CreateCall(&OutlinedFn, {ThreadID}); + CI->setDebugLoc(StaleCI->getDebugLoc()); Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData}); Builder.SetInsertPoint(ThenTI); } @@ -1736,18 +1718,28 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, StaleCI->eraseFromParent(); - // Emit the body for wrapper function - BasicBlock *WrapperEntryBB = - BasicBlock::Create(M.getContext(), "", WrapperFunc); - Builder.SetInsertPoint(WrapperEntryBB); + Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin()); if (HasShareds) { - llvm::Value *Shareds = - Builder.CreateLoad(VoidPtr, WrapperFunc->getArg(1)); - Builder.CreateCall(&OutlinedFn, {Shareds}); - } else { - Builder.CreateCall(&OutlinedFn); + LoadInst *Shareds = Builder.CreateLoad(VoidPtr, OutlinedFn.getArg(1)); + OutlinedFn.getArg(1)->replaceUsesWithIf( + Shareds, [Shareds](Use &U) { return U.getUser() != Shareds; }); + } + + // Replace kmpc_global_thread_num() calls with the global thread id + // argument. + OutlinedFn.getArg(0)->setName("global.tid"); + FunctionCallee TIDRTLFn = + getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num); + for (Instruction &Inst : instructions(OutlinedFn)) { + CallInst *CI = dyn_cast<CallInst>(&Inst); + if (!CI) + continue; + if (CI->getCalledFunction() == TIDRTLFn.getCallee()) + CI->replaceAllUsesWith(OutlinedFn.getArg(0)); } - Builder.CreateRet(Builder.getInt32(0)); + + for (Instruction *I : ToBeDeleted) + I->eraseFromParent(); }; addOutlineInfo(std::move(OI)); diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index fd524f6067ee0ea..643b34270c01693 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -5486,25 +5486,28 @@ TEST_F(OpenMPIRBuilderTest, CreateTask) { 24); // 64-bit pointer + 128-bit integer // Verify Wrapper function - Function *WrapperFunc = + Function *OutlinedFn = dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts()); - ASSERT_NE(WrapperFunc, nullptr); + ASSERT_NE(OutlinedFn, nullptr); - LoadInst *SharedsLoad = dyn_cast<LoadInst>(WrapperFunc->begin()->begin()); + LoadInst *SharedsLoad = dyn_cast<LoadInst>(OutlinedFn->begin()->begin()); ASSERT_NE(SharedsLoad, nullptr); - EXPECT_EQ(SharedsLoad->getPointerOperand(), WrapperFunc->getArg(1)); - - EXPECT_FALSE(WrapperFunc->isDeclaration()); - CallInst *OutlinedFnCall = - dyn_cast<CallInst>(++WrapperFunc->begin()->begin()); - ASSERT_NE(OutlinedFnCall, nullptr); - EXPECT_EQ(WrapperFunc->getArg(0)->getType(), Builder.getInt32Ty()); - EXPECT_EQ(OutlinedFnCall->getArgOperand(0), - WrapperFunc->getArg(1)->uses().begin()->getUser()); + EXPECT_EQ(SharedsLoad->getPointerOperand(), OutlinedFn->getArg(1)); + + EXPECT_FALSE(OutlinedFn->isDeclaration()); + EXPECT_EQ(OutlinedFn->getArg(0)->getType(), Builder.getInt32Ty()); + + // Verify that the data argument is used only once, and that too in the load + // instruction that is then used for accessing shared data. + Value *DataPtr = OutlinedFn->getArg(1); + EXPECT_EQ(DataPtr->getNumUses(), 1); + EXPECT_TRUE(isa<LoadInst>(DataPtr->uses().begin()->getUser())); + Value *Data = DataPtr->uses().begin()->getUser(); + EXPECT_TRUE(all_of(Data->uses(), [](Use &U) { + return isa<GetElementPtrInst>(U.getUser()); + })); // Verify the presence of `trunc` and `icmp` instructions in Outlined function - Function *OutlinedFn = OutlinedFnCall->getCalledFunction(); - ASSERT_NE(OutlinedFn, nullptr); EXPECT_TRUE(any_of(instructions(OutlinedFn), [](Instruction &inst) { return isa<TruncInst>(&inst); })); EXPECT_TRUE(any_of(instructions(OutlinedFn), @@ -5547,6 +5550,14 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskNoArgs) { Builder.CreateRetVoid(); EXPECT_FALSE(verifyModule(*M, &errs())); + + // Check that the outlined function has only one argument. + CallInst *TaskAllocCall = dyn_cast<CallInst>( + OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc) + ->user_back()); + Function *OutlinedFn = dyn_cast<Function>(TaskAllocCall->getArgOperand(5)); + ASSERT_NE(OutlinedFn, nullptr); + ASSERT_EQ(OutlinedFn->arg_size(), 1); } TEST_F(OpenMPIRBuilderTest, CreateTaskUntied) { @@ -5658,8 +5669,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) { F->setName("func"); IRBuilder<> Builder(BB); auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {}; - IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP(); BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split"); + IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP(); Builder.SetInsertPoint(BodyBB); Value *Final = Builder.CreateICmp( CmpInst::Predicate::ICMP_EQ, F->getArg(0), @@ -5711,8 +5722,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskIfCondition) { F->setName("func"); IRBuilder<> Builder(BB); auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {}; - IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP(); BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split"); + IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP(); Builder.SetInsertPoint(BodyBB); Value *IfCondition = Builder.CreateICmp( CmpInst::Predicate::ICMP_EQ, F->getArg(0), @@ -5758,15 +5769,16 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskIfCondition) { ->user_back()); ASSERT_NE(TaskBeginIfCall, nullptr); ASSERT_NE(TaskCompleteCall, nullptr); - Function *WrapperFunc = + Function *OulinedFn = dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts()); - ASSERT_NE(WrapperFunc, nullptr); - CallInst *WrapperFuncCall = dyn_cast<CallInst>(WrapperFunc->user_back()); - ASSERT_NE(WrapperFuncCall, nullptr); + ASSERT_NE(OulinedFn, nullptr); + CallInst *OulinedFnCall = dyn_cast<CallInst>(OulinedFn->user_back()); + ASSERT_NE(OulinedFnCall, nullptr); EXPECT_EQ(TaskBeginIfCall->getParent(), IfConditionBranchInst->getSuccessor(1)); - EXPECT_EQ(TaskBeginIfCall->getNextNonDebugInstruction(), WrapperFuncCall); - EXPECT_EQ(WrapperFuncCall->getNextNonDebugInstruction(), TaskCompleteCall); + + EXPECT_EQ(TaskBeginIfCall->getNextNonDebugInstruction(), OulinedFnCall); + EXPECT_EQ(OulinedFnCall->getNextNonDebugInstruction(), TaskCompleteCall); } TEST_F(OpenMPIRBuilderTest, CreateTaskgroup) { diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir index 28b0113a19d61b8..2cd561cb021075f 100644 --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -2209,7 +2209,7 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) { // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}}) // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, - // CHECK-SAME: i64 0, ptr @[[wrapper_fn:.+]]) + // CHECK-SAME: i64 0, ptr @[[outlined_fn:.+]]) // CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]]) omp.task { %n = llvm.mlir.constant(1 : i64) : i64 @@ -2222,7 +2222,7 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) { llvm.return } -// CHECK: define internal void @[[outlined_fn:.+]]() +// CHECK: define internal void @[[outlined_fn]](i32 %[[global_tid:[^ ,]+]]) // CHECK: task.alloca{{.*}}: // CHECK: br label %[[task_body:[^, ]+]] // CHECK: [[task_body]]: @@ -2236,12 +2236,6 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) { // CHECK: [[exit_stub]]: // CHECK: ret void - -// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}) { -// CHECK: call void @[[outlined_fn]]() -// CHECK: ret i32 0 -// CHECK: } - // ----- // CHECK-LABEL: define void @omp_task_with_deps @@ -2259,7 +2253,7 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr<i32>) { // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}}) // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, - // CHECK-SAME: i64 0, ptr @[[wrapper_fn:.+]]) + // CHECK-SAME: i64 0, ptr @[[outlined_fn:.+]]) // CHECK: call i32 @__kmpc_omp_task_with_deps(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]], {{.*}}) omp.task depend(taskdependin -> %zaddr : !llvm.ptr<i32>) { %n = llvm.mlir.constant(1 : i64) : i64 @@ -2272,7 +2266,7 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr<i32>) { llvm.return } -// CHECK: define internal void @[[outlined_fn:.+]]() +// CHECK: define internal void @[[outlined_fn]](i32 %[[global_tid:[^ ,]+]]) // CHECK: task.alloca{{.*}}: // CHECK: br label %[[task_body:[^, ]+]] // CHECK: [[task_body]]: @@ -2286,11 +2280,6 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr<i32>) { // CHECK: [[exit_stub]]: // CHECK: ret void -// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}) { -// CHECK: call void @[[outlined_fn]]() -// CHECK: ret i32 0 -// CHECK: } - // ----- // CHECK-LABEL: define void @omp_task @@ -2304,7 +2293,7 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} { // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}}) // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 16, - // CHECK-SAME: ptr @[[wrapper_fn:.+]]) + // CHECK-SAME: ptr @[[outlined_fn:.+]]) // CHECK: %[[shareds:.+]] = load ptr, ptr %[[task_data]] // CHECK: call void @llvm.memcpy.p0.p0.i64(ptr {{.+}} %[[shareds]], ptr {{.+}}, i64 16, i1 false) // CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]]) @@ -2321,8 +2310,9 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} { } } -// CHECK: define internal void @[[outlined_fn:.+]](ptr %[[task_data:.+]]) +// CHECK: define internal void @[[outlined_fn]](i32 %[[global_tid:[^ ,]+]], ptr %[[task_data:.+]]) // CHECK: task.alloca{{.*}}: +// CHECK: %[[shareds:.+]] = load ptr, ptr %[[task_data]] // CHECK: br label %[[task_body:[^, ]+]] // CHECK: [[task_body]]: // CHECK: br label %[[task_region:[^, ]+]] @@ -2333,13 +2323,6 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} { // CHECK: [[exit_stub]]: // CHECK: ret void - -// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}, ptr %[[task_data:.+]]) { -// CHECK: %[[shareds:.+]] = load ptr, ptr %1, align 8 -// CHECK: call void @[[outlined_fn]](ptr %[[shareds]]) -// CHECK: ret i32 0 -// CHECK: } - // ----- llvm.func @par_task_(%arg0: !llvm.ptr<i32> {fir.bindc_name = "a"}) { @@ -2355,14 +2338,12 @@ llvm.func @par_task_(%arg0: !llvm.ptr<i32> {fir.bindc_name = "a"}) { } // CHECK-LABEL: @par_task_ -// CHECK: %[[TASK_ALLOC:.*]] = call ptr @__kmpc_omp_task_alloc({{.*}}ptr @par_task_..omp_par.wrapper) +// CHECK: %[[TASK_ALLOC:.*]] = call ptr @__kmpc_omp_task_alloc({{.*}}ptr @[[task_outlined_fn:.+]]) // CHECK: call i32 @__kmpc_omp_task({{.*}}, ptr %[[TASK_ALLOC]]) -// CHECK-LABEL: define internal void @par_task_..omp_par +// CHECK: define internal void @[[task_outlined_fn]] // CHECK: %[[ARG_ALLOC:.*]] = alloca { ptr }, align 8 -// CHECK: call void ({{.*}}) @__kmpc_fork_call({{.*}}, ptr @par_task_..omp_par..omp_par, ptr %[[ARG_ALLOC]]) -// CHECK: define internal void @par_task_..omp_par..omp_par -// CHECK: define i32 @par_task_..omp_par.wrapper -// CHECK: call void @par_task_..omp_par +// CHECK: call void ({{.*}}) @__kmpc_fork_call({{.*}}, ptr @[[parallel_outlined_fn:.+]], ptr %[[ARG_ALLOC]]) +// CHECK: define internal void @[[parallel_outlined_fn]] // ----- llvm.func @foo() -> () @@ -2432,7 +2413,7 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) { // CHECK: br label %[[codeRepl:[^,]+]] // CHECK: [[codeRepl]]: // CHECK: %[[omp_global_thread_num_t1:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}}) -// CHECK: %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 40, i64 0, ptr @omp_taskgroup_task..omp_par.wrapper) +// CHECK: %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 40, i64 0, ptr @[[outlined_task_fn:.+]]) // CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], ptr %[[t1_alloc]]) // CHECK: br label %[[task_exit:[^,]+]] // CHECK: [[task_exit]]: @@ -2445,7 +2426,7 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) { // CHECK: %[[gep3:.+]] = getelementptr { i32, i32, ptr }, ptr %[[structArg]], i32 0, i32 2 // CHECK: store ptr %[[zaddr]], ptr %[[gep3]], align 8 // CHECK: %[[omp_global_thread_num_t2:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}}) -// CHECK: %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 40, i64 16, ptr @omp_taskgroup_task..omp_par.1.wrapper) +// CHECK: %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 40, i64 16, ptr @[[outlined_task_fn:.+]]) // CHECK: %[[shareds:.+]] = load ptr, ptr %[[t2_alloc]] // CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[shareds]], ptr align 1 %[[structArg]], i64 16, i1 false) // CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], ptr %[[t2_alloc]]) @@ -2617,7 +2598,7 @@ llvm.func @omp_task_final(%boolexpr: i1) { // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}}) // CHECK: %[[final_flag:.+]] = select i1 %[[boolexpr]], i32 2, i32 0 // CHECK: %[[task_flags:.+]] = or i32 %[[final_flag]], 1 -// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 %[[task_flags]], i64 40, i64 0, ptr @omp_task_final..omp_par.wrapper) +// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 %[[task_flags]], i64 40, i64 0, ptr @[[task_outlined_fn:.+]]) // CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]]) // CHECK: br label %[[task_exit:[^,]+]] // CHECK: [[task_exit]]: @@ -2648,14 +2629,14 @@ llvm.func @omp_task_if(%boolexpr: i1) { // CHECK: br label %[[codeRepl:[^,]+]] // CHECK: [[codeRepl]]: // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}}) -// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 0, ptr @omp_task_if..omp_par.wrapper) +// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 0, ptr @[[task_outlined_fn:.+]]) // CHECK: br i1 %[[boolexpr]], label %[[true_label:[^,]+]], label %[[false_label:[^,]+]] // CHECK: [[true_label]]: // CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]]) // CHECK: br label %[[if_else_exit:[^,]+]] // CHECK: [[false_label:[^,]+]]: ; preds = %codeRepl // CHECK: call void @__kmpc_omp_task_begin_if0(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]]) -// CHECK: %{{.+}} = call i32 @omp_task_if..omp_par.wrapper(i32 %[[omp_global_thread_num]]) +// CHECK: call void @[[task_outlined_fn]](i32 %[[omp_global_thread_num]]) // CHECK: call void @__kmpc_omp_task_complete_if0(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]]) // CHECK: br label %[[if_else_exit]] // CHECK: [[if_else_exit]]: >From a1a9438b5e00170030b419a7736053422745cbc6 Mon Sep 17 00:00:00 2001 From: Shraiysh Vaishay <shraiysh.vais...@amd.com> Date: Mon, 2 Oct 2023 09:22:30 -0500 Subject: [PATCH 2/4] Remove outlining for teams too. --- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 178 +++++++++--------- .../Frontend/OpenMPIRBuilderTest.cpp | 22 +-- 2 files changed, 95 insertions(+), 105 deletions(-) diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 54012b488c6b671..a5a73bcc10c48e3 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -341,6 +341,44 @@ BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch, return splitBB(Builder, CreateBranch, Old->getName() + Suffix); } +// This function creates a fake integer value and a fake use for the integer +// value. It returns the fake value created. This is useful in modeling the +// extra arguments to the outlined functions. +Value *createFakeIntVal(IRBuilder<> &Builder, + OpenMPIRBuilder::InsertPointTy OuterAllocaIP, + std::stack<Instruction *> &ToBeDeleted, + OpenMPIRBuilder::InsertPointTy InnerAllocaIP, + const Twine &Name = "", bool AsPtr = true) { + Builder.restoreIP(OuterAllocaIP); + Instruction *FakeVal; + AllocaInst *FakeValAddr = + Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, Name + ".addr"); + ToBeDeleted.push(FakeValAddr); + + if (AsPtr) + FakeVal = FakeValAddr; + else { + FakeVal = + Builder.CreateLoad(Builder.getInt32Ty(), FakeValAddr, Name + ".val"); + ToBeDeleted.push(FakeVal); + } + + // We only need TIDAddr and ZeroAddr for modeling purposes to get the + // associated arguments in the outlined function, so we delete them later. + + // Fake use of TID + Builder.restoreIP(InnerAllocaIP); + Instruction *UseFakeVal; + if (AsPtr) + UseFakeVal = + Builder.CreateLoad(Builder.getInt32Ty(), FakeVal, Name + ".use"); + else + UseFakeVal = + cast<BinaryOperator>(Builder.CreateAdd(FakeVal, Builder.getInt32(10))); + ToBeDeleted.push(UseFakeVal); + return FakeVal; +} + //===----------------------------------------------------------------------===// // OpenMPIRBuilderConfig //===----------------------------------------------------------------------===// @@ -1497,13 +1535,6 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB, bool Tied, Value *Final, Value *IfCondition, SmallVector<DependData> Dependencies) { - // We create a temporary i32 value that will represent the global tid after - // outlining. - SmallVector<Instruction *, 4> ToBeDeleted; - Builder.restoreIP(AllocaIP); - AllocaInst *TIDAddr = Builder.CreateAlloca(Int32, nullptr, "tid.addr"); - LoadInst *TID = Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use"); - ToBeDeleted.append({TID, TIDAddr}); if (!updateToLocation(Loc)) return InsertPointTy(); @@ -1532,19 +1563,24 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, BasicBlock *TaskAllocaBB = splitBB(Builder, /*CreateBranch=*/true, "task.alloca"); - // Fake use of TID - Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin()); - BinaryOperator *AddInst = - dyn_cast<BinaryOperator>(Builder.CreateAdd(TID, Builder.getInt32(10))); - ToBeDeleted.push_back(AddInst); + InsertPointTy TaskAllocaIP = + InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin()); + InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin()); + BodyGenCB(TaskAllocaIP, TaskBodyIP); + Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin()); OutlineInfo OI; OI.EntryBB = TaskAllocaBB; OI.OuterAllocaBB = AllocaIP.getBlock(); OI.ExitBB = TaskExitBB; - OI.ExcludeArgsFromAggregate = {TID}; + + // Add the thread ID argument. + std::stack<Instruction *> ToBeDeleted; + OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal( + Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false)); + OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies, - TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) { + TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) mutable { // Replace the Stale CI by appropriate RTL function call. assert(OutlinedFn.getNumUses() == 1 && "there must be a single user for the outlined function"); @@ -1670,7 +1706,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, // br label %exit // else: // call @__kmpc_omp_task_begin_if0(...) - // call @wrapper_fn(...) + // call @outlined_fn(...) // call @__kmpc_omp_task_complete_if0(...) // br label %exit // exit: @@ -1725,31 +1761,14 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, Shareds, [Shareds](Use &U) { return U.getUser() != Shareds; }); } - // Replace kmpc_global_thread_num() calls with the global thread id - // argument. - OutlinedFn.getArg(0)->setName("global.tid"); - FunctionCallee TIDRTLFn = - getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num); - for (Instruction &Inst : instructions(OutlinedFn)) { - CallInst *CI = dyn_cast<CallInst>(&Inst); - if (!CI) - continue; - if (CI->getCalledFunction() == TIDRTLFn.getCallee()) - CI->replaceAllUsesWith(OutlinedFn.getArg(0)); + while (!ToBeDeleted.empty()) { + ToBeDeleted.top()->eraseFromParent(); + ToBeDeleted.pop(); } - - for (Instruction *I : ToBeDeleted) - I->eraseFromParent(); }; addOutlineInfo(std::move(OI)); - InsertPointTy TaskAllocaIP = - InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin()); - InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin()); - BodyGenCB(TaskAllocaIP, TaskBodyIP); - Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin()); - return Builder.saveIP(); } @@ -5740,6 +5759,7 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc, BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.entry"); Builder.SetInsertPoint(BodyBB, BodyBB->begin()); } + InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin()); // The current basic block is split into four basic blocks. After outlining, // they will be mapped as follows: @@ -5763,84 +5783,62 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc, BasicBlock *AllocaBB = splitBB(Builder, /*CreateBranch=*/true, "teams.alloca"); + // Generate the body of teams. + InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin()); + InsertPointTy CodeGenIP(BodyBB, BodyBB->begin()); + BodyGenCB(AllocaIP, CodeGenIP); + OutlineInfo OI; OI.EntryBB = AllocaBB; OI.ExitBB = ExitBB; OI.OuterAllocaBB = &OuterAllocaBB; - OI.PostOutlineCB = [this, Ident](Function &OutlinedFn) { - // The input IR here looks like the following- - // ``` - // func @current_fn() { - // outlined_fn(%args) - // } - // func @outlined_fn(%args) { ... } - // ``` - // - // This is changed to the following- - // - // ``` - // func @current_fn() { - // runtime_call(..., wrapper_fn, ...) - // } - // func @wrapper_fn(..., %args) { - // outlined_fn(%args) - // } - // func @outlined_fn(%args) { ... } - // ``` + // Insert fake values for global tid and bound tid. + std::stack<Instruction *> ToBeDeleted; + OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal( + Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true)); + OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal( + Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true)); + + OI.PostOutlineCB = [this, Ident, ToBeDeleted](Function &OutlinedFn) mutable { // The stale call instruction will be replaced with a new call instruction - // for runtime call with a wrapper function. + // for runtime call with the outlined function. assert(OutlinedFn.getNumUses() == 1 && "there must be a single user for the outlined function"); CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back()); + ToBeDeleted.push(StaleCI); + + assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) && + "Outlined function must have two or three arguments only"); - // Create the wrapper function. - SmallVector<Type *> WrapperArgTys{Builder.getPtrTy(), Builder.getPtrTy()}; - for (auto &Arg : OutlinedFn.args()) - WrapperArgTys.push_back(Arg.getType()); - FunctionCallee WrapperFuncVal = M.getOrInsertFunction( - (Twine(OutlinedFn.getName()) + ".teams").str(), - FunctionType::get(Builder.getVoidTy(), WrapperArgTys, false)); - Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee()); - WrapperFunc->getArg(0)->setName("global_tid"); - WrapperFunc->getArg(1)->setName("bound_tid"); - if (WrapperFunc->arg_size() > 2) - WrapperFunc->getArg(2)->setName("data"); - - // Emit the body of the wrapper function - just a call to outlined function - // and return statement. - BasicBlock *WrapperEntryBB = - BasicBlock::Create(M.getContext(), "entrybb", WrapperFunc); - Builder.SetInsertPoint(WrapperEntryBB); - SmallVector<Value *> Args; - for (size_t ArgIndex = 2; ArgIndex < WrapperFunc->arg_size(); ArgIndex++) - Args.push_back(WrapperFunc->getArg(ArgIndex)); - Builder.CreateCall(&OutlinedFn, Args); - Builder.CreateRetVoid(); - - OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline); + bool HasShared = OutlinedFn.arg_size() == 3; + + OutlinedFn.getArg(0)->setName("global.tid.ptr"); + OutlinedFn.getArg(1)->setName("bound.tid.ptr"); + if (HasShared) + OutlinedFn.getArg(2)->setName("data"); // Call to the runtime function for teams in the current function. assert(StaleCI && "Error while outlining - no CallInst user found for the " "outlined function."); Builder.SetInsertPoint(StaleCI); - Args = {Ident, Builder.getInt32(StaleCI->arg_size()), WrapperFunc}; - for (Use &Arg : StaleCI->args()) - Args.push_back(Arg); + SmallVector<Value *> Args = {Ident, Builder.getInt32(StaleCI->arg_size()), + &OutlinedFn}; + if (HasShared) + Args.push_back(StaleCI->getArgOperand(2)); Builder.CreateCall(getOrCreateRuntimeFunctionPtr( omp::RuntimeFunction::OMPRTL___kmpc_fork_teams), Args); - StaleCI->eraseFromParent(); + + while (!ToBeDeleted.empty()) { + ToBeDeleted.top()->eraseFromParent(); + ToBeDeleted.pop(); + } }; addOutlineInfo(std::move(OI)); - // Generate the body of teams. - InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin()); - InsertPointTy CodeGenIP(BodyBB, BodyBB->begin()); - BodyGenCB(AllocaIP, CodeGenIP); - Builder.SetInsertPoint(ExitBB, ExitBB->begin()); return Builder.saveIP(); diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index 643b34270c01693..c4b0389c89c7c60 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -4057,25 +4057,17 @@ TEST_F(OpenMPIRBuilderTest, CreateTeams) { ASSERT_NE(SrcSrc, nullptr); // Verify the outlined function signature. - Function *WrapperFn = + Function *OutlinedFn = dyn_cast<Function>(TeamsForkCall->getArgOperand(2)->stripPointerCasts()); - ASSERT_NE(WrapperFn, nullptr); - EXPECT_FALSE(WrapperFn->isDeclaration()); - EXPECT_TRUE(WrapperFn->arg_size() >= 3); - EXPECT_EQ(WrapperFn->getArg(0)->getType(), Builder.getPtrTy()); // global_tid - EXPECT_EQ(WrapperFn->getArg(1)->getType(), Builder.getPtrTy()); // bound_tid - EXPECT_EQ(WrapperFn->getArg(2)->getType(), + ASSERT_NE(OutlinedFn, nullptr); + EXPECT_FALSE(OutlinedFn->isDeclaration()); + EXPECT_TRUE(OutlinedFn->arg_size() >= 3); + EXPECT_EQ(OutlinedFn->getArg(0)->getType(), Builder.getPtrTy()); // global_tid + EXPECT_EQ(OutlinedFn->getArg(1)->getType(), Builder.getPtrTy()); // bound_tid + EXPECT_EQ(OutlinedFn->getArg(2)->getType(), Builder.getPtrTy()); // captured args // Check for TruncInst and ICmpInst in the outlined function. - inst_range Instructions = instructions(WrapperFn); - auto OutlinedFnInst = find_if( - Instructions, [](Instruction &Inst) { return isa<CallInst>(&Inst); }); - ASSERT_NE(OutlinedFnInst, Instructions.end()); - CallInst *OutlinedFnCI = dyn_cast<CallInst>(&*OutlinedFnInst); - ASSERT_NE(OutlinedFnCI, nullptr); - Function *OutlinedFn = OutlinedFnCI->getCalledFunction(); - EXPECT_TRUE(any_of(instructions(OutlinedFn), [](Instruction &inst) { return isa<TruncInst>(&inst); })); EXPECT_TRUE(any_of(instructions(OutlinedFn), >From 4b71558a1936983e1eeebfee98de6b4d8f1062cc Mon Sep 17 00:00:00 2001 From: Shraiysh Vaishay <shraiysh.vais...@amd.com> Date: Mon, 2 Oct 2023 09:26:57 -0500 Subject: [PATCH 3/4] Remove unintentional include for InstIterator.h --- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index a5a73bcc10c48e3..f62d244a2dc4c68 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -35,7 +35,6 @@ #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InstIterator.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" >From 7c95d29b677c6107f81b0c26c139a34475a6fe81 Mon Sep 17 00:00:00 2001 From: Shraiysh Vaishay <shraiysh.vais...@amd.com> Date: Mon, 2 Oct 2023 09:50:26 -0500 Subject: [PATCH 4/4] Fix insertpoint after createTask --- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index f62d244a2dc4c68..5ed2a345a14dd04 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -1566,7 +1566,6 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin()); InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin()); BodyGenCB(TaskAllocaIP, TaskBodyIP); - Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin()); OutlineInfo OI; OI.EntryBB = TaskAllocaBB; @@ -1767,6 +1766,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, }; addOutlineInfo(std::move(OI)); + Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin()); return Builder.saveIP(); } _______________________________________________ lldb-commits mailing list lldb-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits