https://github.com/mtrofin created https://github.com/llvm/llvm-project/pull/106154
None

>From eda80fe012239f907df3ee3c4d6d94c93d9d4df2 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtro...@google.com>
Date: Thu, 22 Aug 2024 18:03:56 -0700
Subject: [PATCH] [ctx_prof] Add Inlining support

---
 llvm/include/llvm/Analysis/CtxProfAnalysis.h  |  10 +
 llvm/include/llvm/IR/IntrinsicInst.h          |   4 +
 .../llvm/ProfileData/PGOCtxProfReader.h       |   5 +
 llvm/include/llvm/Transforms/Utils/Cloning.h  |   9 +
 llvm/lib/Analysis/CtxProfAnalysis.cpp         |   1 -
 llvm/lib/Transforms/IPO/ModuleInliner.cpp     |   1 +
 llvm/lib/Transforms/Utils/InlineFunction.cpp  | 162 ++++++++++++++++
 .../unittests/Transforms/Utils/CMakeLists.txt |   1 +
 .../Transforms/Utils/InlineFunctionTest.cpp   | 174 ++++++++++++++++++
 9 files changed, 366 insertions(+), 1 deletion(-)
 create mode 100644 llvm/unittests/Transforms/Utils/InlineFunctionTest.cpp

diff --git a/llvm/include/llvm/Analysis/CtxProfAnalysis.h b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
index 10aef6f6067b6f..f630dceb8a644c 100644
--- a/llvm/include/llvm/Analysis/CtxProfAnalysis.h
+++ b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
@@ -62,6 +62,18 @@ class PGOContextualProfile {
   bool isFunctionKnown(const Function &F) const {
     return getDefinedFunctionGUID(F) != 0;
   }
+
+  /// Number of counter indices currently in use by \p F.
+  uint32_t getNrCounters(const Function &F) const {
+    assert(isFunctionKnown(F));
+    return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCounterIndex;
+  }
+
+  /// Number of callsite indices currently in use by \p F.
+  uint32_t getNrCallsites(const Function &F) const {
+    assert(isFunctionKnown(F));
+    return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCallsiteIndex;
+  }
 
   uint32_t allocateNextCounterIndex(const Function &F) {
     assert(isFunctionKnown(F));
diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
index 71a96e0671c2f1..2ebcee422eddfb 100644
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -1516,6 +1516,9 @@ class InstrProfInstBase : public IntrinsicInst {
     return const_cast<Value *>(getArgOperand(0))->stripPointerCasts();
   }
 
+  /// Replace operand 0, the operand getNameValue() reads, with \p V.
+  void setNameValue(Value *V) { setArgOperand(0, V); }
+
   // The hash of the CFG for the instrumented function.
   ConstantInt *getHash() const {
     return cast<ConstantInt>(const_cast<Value *>(getArgOperand(1)));
diff --git a/llvm/include/llvm/ProfileData/PGOCtxProfReader.h b/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
index f7f88966f7573f..c64e6e79f96c4c 100644
--- a/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
+++ b/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
@@ -74,6 +74,15 @@ class PGOCtxProfContext final {
     Iter->second.emplace(Other.guid(), std::move(Other));
   }
 
+  /// Move all the contexts in \p Other under the callsite \p CSId, which must
+  /// not already exist in this context (asserted).
+  void ingestAllContexts(uint32_t CSId, CallTargetMapTy &&Other) {
+    auto [_, Inserted] = callsites().try_emplace(CSId, std::move(Other));
+    (void)Inserted;
+    assert(Inserted &&
+           "CSId was expected to be newly created as result of e.g. inlining");
+  }
+
   void resizeCounters(uint32_t Size) { Counters.resize(Size); }
 
   bool hasCallsite(uint32_t I) const {
diff --git a/llvm/include/llvm/Transforms/Utils/Cloning.h b/llvm/include/llvm/Transforms/Utils/Cloning.h
index 6226062dd713f6..0b670bfb9ce806 100644
--- a/llvm/include/llvm/Transforms/Utils/Cloning.h
+++ b/llvm/include/llvm/Transforms/Utils/Cloning.h
@@ -20,6 +20,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/CtxProfAnalysis.h"
 #include "llvm/Analysis/InlineCost.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/ValueHandle.h"
@@ -270,6 +271,17 @@ InlineResult InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
                             bool InsertLifetime = true,
                             Function *ForwardVarArgsTo = nullptr);
 
+/// Same as above, but it will also update the contextual profile \p CtxProf,
+/// remapping the inlined callee's counters and callsites into the caller's
+/// index space. When \p CtxProf holds a profile, \p CB must be a direct call
+/// that carries callsite instrumentation.
+InlineResult InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
+                            CtxProfAnalysis::Result &CtxProf,
+                            bool MergeAttributes = false,
+                            AAResults *CalleeAAR = nullptr,
+                            bool InsertLifetime = true,
+                            Function *ForwardVarArgsTo = nullptr);
+
 /// Clones a loop \p OrigLoop. Returns the loop and the blocks in \p
 /// Blocks.
/// diff --git a/llvm/lib/Analysis/CtxProfAnalysis.cpp b/llvm/lib/Analysis/CtxProfAnalysis.cpp index 2cd3f2114397e5..77aefecf3ff18c 100644 --- a/llvm/lib/Analysis/CtxProfAnalysis.cpp +++ b/llvm/lib/Analysis/CtxProfAnalysis.cpp @@ -150,7 +150,6 @@ PGOContextualProfile CtxProfAnalysis::run(Module &M, // If we made it this far, the Result is valid - which we mark by setting // .Profiles. // Trim first the roots that aren't in this module. - DenseSet<GlobalValue::GUID> ProfiledGUIDs; for (auto &[RootGuid, _] : llvm::make_early_inc_range(*MaybeCtx)) if (!Result.FuncInfo.contains(RootGuid)) MaybeCtx->erase(RootGuid); diff --git a/llvm/lib/Transforms/IPO/ModuleInliner.cpp b/llvm/lib/Transforms/IPO/ModuleInliner.cpp index 5e91ab80d7505f..f0a6c771347f5e 100644 --- a/llvm/lib/Transforms/IPO/ModuleInliner.cpp +++ b/llvm/lib/Transforms/IPO/ModuleInliner.cpp @@ -20,6 +20,7 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/CtxProfAnalysis.h" #include "llvm/Analysis/InlineAdvisor.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/InlineOrder.h" diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp index 94e87656a192c7..73b59c52c9c12d 100644 --- a/llvm/lib/Transforms/Utils/InlineFunction.cpp +++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp @@ -23,6 +23,7 @@ #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/CtxProfAnalysis.h" #include "llvm/Analysis/IndirectCallVisitor.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryProfileInfo.h" @@ -46,6 +47,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/EHPersonalities.h" #include "llvm/IR/Function.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/InstrTypes.h" @@ 
-71,6 +73,7 @@ #include <algorithm> #include <cassert> #include <cstdint> +#include <deque> #include <iterator> #include <limits> #include <optional> @@ -2116,6 +2119,165 @@ inlineRetainOrClaimRVCalls(CallBase &CB, objcarc::ARCInstKind RVCallKind, } } +static const std::pair<std::vector<int64_t>, std::vector<int64_t>> +remapIndices(Function &Caller, BasicBlock *StartBB, + CtxProfAnalysis::Result &CtxProf, uint32_t CalleeCounters, + uint32_t CalleeCallsites) { + // We'll allocate a new ID to imported callsite counters and callsites. We're + // using -1 to indicate a counter we delete. Most likely the entry, for + // example, will be deleted - we don't want 2 IDs in the same BB, and the + // entry would have been cloned in the callsite's old BB. + std::vector<int64_t> CalleeCounterMap; + std::vector<int64_t> CalleeCallsiteMap; + CalleeCounterMap.resize(CalleeCounters, -1); + CalleeCallsiteMap.resize(CalleeCallsites, -1); + + auto RewriteInstrIfNeeded = [&](InstrProfIncrementInst &Ins) -> bool { + if (Ins.getNameValue() == &Caller) + return false; + const auto OldID = static_cast<uint32_t>(Ins.getIndex()->getZExtValue()); + if (CalleeCounterMap[OldID] == -1) + CalleeCounterMap[OldID] = CtxProf.allocateNextCounterIndex(Caller); + const auto NewID = static_cast<uint32_t>(CalleeCounterMap[OldID]); + + Ins.setNameValue(&Caller); + Ins.setIndex(NewID); + return true; + }; + + auto RewriteCallsiteInsIfNeeded = [&](InstrProfCallsite &Ins)-> bool { + if (Ins.getNameValue() == &Caller) + return false; + const auto OldID = static_cast<uint32_t>(Ins.getIndex()->getZExtValue()); + if (CalleeCallsiteMap[OldID] == -1) + CalleeCallsiteMap[OldID] = CtxProf.allocateNextCallsiteIndex(Caller); + const auto NewID = static_cast<uint32_t>(CalleeCallsiteMap[OldID]); + + Ins.setNameValue(&Caller); + Ins.setIndex(NewID); + return true; + }; + + std::deque<BasicBlock*> Worklist; + DenseSet<const BasicBlock*> Seen; + // We will traverse the BBs starting from the callsite BB. 
The callsite BB + // will have at least a BB ID - maybe its own, and in any case the one coming + // from the cloned function's entry BB. The other BBs we'll start seeing from + // there on may or may not have BB IDs. BBs with IDs belonging to our caller + // are definitely not coming from the imported function and form a boundary + // past which we don't need to traverse anymore. BBs may have no + // instrumentation, in which case we'll traverse past them. + // An invariant we'll keep is that a BB will have at most 1 BB ID. For + // example, the callsite BB will delete the callee BB's instrumentation. This + // doesn't result in information loss: the entry BB of the caller will have + // the same count as the callsite's BB. + // At the end of this traversal, all the callee's instrumentation would be + // mapped into the caller's instrumentation index space. Some of the callee's + // counters may be deleted (as mentioned, this should result in no loss of + // information). + Worklist.push_back(StartBB); + while (!Worklist.empty()) { + auto *BB = Worklist.front(); + Worklist.pop_front(); + bool Changed = false; + auto *BBID = CtxProfAnalysis::getBBInstrumentation(*BB); + if (BBID) { + Changed |= RewriteInstrIfNeeded(*BBID); + // this may be the entryblock from the inlined callee, coming into a BB + // that didn't have instrumentation because of MST decisions. Let's make + // sure it's placed accordingly. This is a noop elsewhere. 
+ BBID->moveBefore(&*BB->getFirstInsertionPt()); + } + for (auto &I : llvm::make_early_inc_range(*BB)) { + if (auto *Inc = dyn_cast<InstrProfIncrementInst>(&I)) { + if (Inc != BBID) { + Inc->eraseFromParent(); + Changed = true; + } + } else if (auto *CS = dyn_cast<InstrProfCallsite>(&I)) { + Changed |= RewriteCallsiteInsIfNeeded(*CS); + } + } + if (!BBID || Changed) + for (auto *Succ : successors(BB)) + if (Seen.insert(Succ).second) + Worklist.push_back(Succ); + } + return {std::move(CalleeCounterMap), std::move(CalleeCallsiteMap)}; +} + +llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, + CtxProfAnalysis::Result &CtxProf, + bool MergeAttributes, + AAResults *CalleeAAR, + bool InsertLifetime, + Function *ForwardVarArgsTo) { + auto &Caller = *CB.getCaller(); + auto &Callee = *CB.getCalledFunction(); + auto *StartBB = CB.getParent(); + + const auto CalleeGUID = AssignGUIDPass::getGUID(Callee); + auto *CallsiteIDIns = CtxProfAnalysis::getCallsiteInstrumentation(CB); + const auto CallsiteID = + static_cast<uint32_t>(CallsiteIDIns->getIndex()->getZExtValue()); + + const auto CalleeCounters = CtxProf.getNrCounters(Callee); + const auto CalleeCallsites = CtxProf.getNrCallsites(Callee); + + auto Ret = InlineFunction(CB, IFI, MergeAttributes, CalleeAAR, InsertLifetime, + ForwardVarArgsTo); + if (!Ret.isSuccess()) + return Ret; + + // We don't have that callsite anymore. + CallsiteIDIns->eraseFromParent(); + + // Assinging Maps and then capturing references into it in the lambda because + // captured structured bindings are a C++20 extension. We do also need a + // capture here, though. + const auto Maps = + remapIndices(Caller, StartBB, CtxProf, CalleeCounters, CalleeCallsites); + const auto &[CalleeCounterMap, _] = Maps; + // We'll have to grow the counters vector by this much. The callsites are a + // map, so we don't need to do that. 
+ const uint32_t NewCounters = + llvm::count_if(CalleeCounterMap, [](auto V) { return V != -1; }); + + auto Updater = [&](PGOCtxProfContext &Ctx) { + assert(Ctx.guid() == AssignGUIDPass::getGUID(Caller)); + Ctx.resizeCounters(Ctx.counters().size() + NewCounters); + // If the callsite wasn't exercised in this context, the value of the + // counters coming from it is 0 and so we're done. + auto CSIt = Ctx.callsites().find(CallsiteID); + if (CSIt == Ctx.callsites().end()) + return; + auto CalleeCtxIt = CSIt->second.find(CalleeGUID); + // The callsite was exercised, but not with this callee (so presumably this + // is an indirect callsite). Again we're done. + if (CalleeCtxIt == CSIt->second.end()) + return; + const auto &[CalleeCounterMap, CalleeCallsiteMap] = Maps; + auto &CalleeCtx = CalleeCtxIt->second; + assert(CalleeCtx.guid() == CalleeGUID); + + for (auto I = 0U; I < CalleeCtx.counters().size(); ++I) { + const int64_t NewIndex = CalleeCounterMap[I]; + if (NewIndex >= 0) + Ctx.counters()[NewIndex] = CalleeCtx.counters()[I]; + } + for (auto &[I, OtherSet] : CalleeCtx.callsites()) { + const int64_t NewCSIdx = CalleeCallsiteMap[I]; + if (NewCSIdx >= 0) + Ctx.ingestAllContexts(NewCSIdx, std::move(OtherSet)); + } + auto Deleted = Ctx.callsites().erase(CallsiteID); + assert(Deleted); + (void)Deleted; + }; + CtxProf.update(Updater, &Caller); + return Ret; +} + /// This function inlines the called function into the basic block of the /// caller. This returns false if it is not possible to inline this call. /// The program is still in a well defined state if this occurs though. 
diff --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
index 35055baa05ee99..5d86ba80ceb0c6 100644
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -19,6 +19,7 @@ add_llvm_unittest(UtilsTests
   CodeMoverUtilsTest.cpp
   DebugifyTest.cpp
   FunctionComparatorTest.cpp
+  InlineFunctionTest.cpp
   IntegerDivisionTest.cpp
   LocalTest.cpp
   LoopRotationUtilsTest.cpp
diff --git a/llvm/unittests/Transforms/Utils/InlineFunctionTest.cpp b/llvm/unittests/Transforms/Utils/InlineFunctionTest.cpp
new file mode 100644
index 00000000000000..930a7c76e341fd
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/InlineFunctionTest.cpp
@@ -0,0 +1,189 @@
+//===- InlineFunctionTest.cpp - InlineFunction unit tests -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Analysis/CtxProfAnalysis.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/NoFolder.h"
+#include "llvm/IR/PassInstrumentation.h"
+#include "llvm/ProfileData/PGOCtxProfReader.h"
+#include "llvm/Support/JSON.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Testing/Support/SupportHelpers.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+  if (!Mod)
+    Err.print("UtilsTests", errs());
+  return Mod;
+}
+
+class InlineFunctionTest : public testing::Test {
+protected:
+  LLVMContext C;
+  std::unique_ptr<Module> M;
+  llvm::unittest::TempFile ProfileFile;
+  ModuleAnalysisManager MAM;
+
+  const char *Profile = R"(
+  [
+    { "Guid": 1000,
+      "Counters": [10, 2, 8],
+      "Callsites": [
+        [ { "Guid": 1001,
+            "Counters": [2, 100],
+            "Callsites": [[{"Guid": 1002, "Counters": [100]}]]}
+        ],
+        [ { "Guid": 1001,
+            "Counters": [8, 500],
+            "Callsites": [[{"Guid": 1002, "Counters": [500]}]]}
+        ]
+      ]
+    }
+  ]
+  )";
+  const char *IR = R"IR(
+define i32 @entrypoint(i32 %x) !guid !0 {
+  call void @llvm.instrprof.increment(ptr @entrypoint, i64 0, i32 3, i32 0)
+  %t = icmp eq i32 %x, 0
+  br i1 %t, label %yes, label %no
+yes:
+  call void @llvm.instrprof.increment(ptr @entrypoint, i64 0, i32 3, i32 1)
+  call void @llvm.instrprof.callsite(ptr @entrypoint, i64 0, i32 2, i32 0, ptr @a)
+  %call1 = call i32 @a(i32 %x)
+  br label %exit
+no:
+  call void @llvm.instrprof.increment(ptr @entrypoint, i64 0, i32 3, i32 2)
+  call void @llvm.instrprof.callsite(ptr @entrypoint, i64 0, i32 2, i32 1, ptr @a)
+  %call2 = call i32 @a(i32 %x)
+  br label %exit
+exit:
+  %ret = phi i32 [%call1, %yes], [%call2, %no]
+  ret i32 %ret
+}
+
+define i32 @a(i32 %x) !guid !1 {
+entry:
+  call void @llvm.instrprof.increment(ptr @a, i64 0, i32 2, i32 0)
+  br label %loop
+loop:
+  %indvar = phi i32 [%indvar.next, %loop], [0, %entry]
+  call void @llvm.instrprof.increment(ptr @a, i64 0, i32 2, i32 1)
+  %b = add i32 %x, %indvar
+  call void @llvm.instrprof.callsite(ptr @a, i64 0, i32 1, i32 0, ptr @b)
+  %inc = call i32 @b()
+  %indvar.next = add i32 %indvar, %inc
+  %cond = icmp slt i32 %indvar.next, %x
+  br i1 %cond, label %loop, label %exit
+exit:
+  ret i32 8
+}
+
+define i32 @b() !guid !2 {
+  call void @llvm.instrprof.increment(ptr @b, i64 0, i32 1, i32 0)
+  ret i32 1
+}
+
+!0 = !{i64 1000}
+!1 = !{i64 1001}
+!2 = !{i64 1002}
+)IR";
+
+public:
+  InlineFunctionTest() : ProfileFile("ctx_profile", "", "", /*Unique*/ true) {}
+
+  void SetUp() override {
+    M = parseIR(C, IR);
+    ASSERT_TRUE(!!M);
+    std::error_code EC;
+    raw_fd_stream Out(ProfileFile.path(), EC);
+    ASSERT_FALSE(EC);
+    // "False" means no error.
+    ASSERT_FALSE(llvm::createCtxProfFromJSON(Profile, Out));
+    MAM.registerPass([&]() { return CtxProfAnalysis(ProfileFile.path()); });
+    MAM.registerPass([&]() { return PassInstrumentationAnalysis(); });
+  }
+};
+
+// Inline the call to @a in @entrypoint's `yes` block and check that both the
+// IR instrumentation and the contextual profile are updated accordingly.
+TEST_F(InlineFunctionTest, InlineWithCtxProf) {
+  auto &CtxProf = MAM.getResult<CtxProfAnalysis>(*M);
+  EXPECT_TRUE(!!CtxProf);
+  auto *Caller = M->getFunction("entrypoint");
+  CallBase *CB = [&]() -> CallBase * {
+    for (auto &BB : *Caller)
+      if (auto *Ins = CtxProfAnalysis::getBBInstrumentation(BB);
+          Ins && Ins->getIndex()->getZExtValue() == 1)
+        for (auto &I : BB)
+          if (auto *CB = dyn_cast<CallBase>(&I);
+              CB && CB->getCalledFunction() &&
+              !CB->getCalledFunction()->isIntrinsic())
+            return CB;
+    return nullptr;
+  }();
+  ASSERT_NE(CB, nullptr);
+  ASSERT_NE(CtxProfAnalysis::getCallsiteInstrumentation(*CB), nullptr);
+  EXPECT_EQ(CtxProfAnalysis::getCallsiteInstrumentation(*CB)
+                ->getIndex()
+                ->getZExtValue(),
+            0U);
+  InlineFunctionInfo IFI;
+  InlineResult Result = InlineFunction(*CB, IFI, CtxProf);
+  EXPECT_TRUE(Result.isSuccess());
+  // All instrumentation cloned from @a must have been rewritten to refer to
+  // the caller.
+  for (auto &BB : *Caller)
+    for (auto &I : BB) {
+      if (auto *Inc = dyn_cast<InstrProfIncrementInst>(&I))
+        EXPECT_EQ(Inc->getNameValue(), Caller);
+      if (auto *CS = dyn_cast<InstrProfCallsite>(&I))
+        EXPECT_EQ(CS->getNameValue(), Caller);
+    }
+  std::string Str;
+  raw_string_ostream OS(Str);
+  CtxProfAnalysisPrinterPass Printer(
+      OS, CtxProfAnalysisPrinterPass::PrintMode::JSON);
+  Printer.run(*M, MAM);
+
+  // @a's loop counter (100 in this context) becomes the caller's counter 3,
+  // and @a's callsite to @b becomes the caller's callsite 2. Callsite 0, the
+  // one we inlined, no longer has any contexts.
+  const char *Expected = R"(
+  [
+    { "Guid": 1000,
+      "Counters": [10, 2, 8, 100],
+      "Callsites": [
+        [],
+        [ { "Guid": 1001,
+            "Counters": [8, 500],
+            "Callsites": [[{"Guid": 1002, "Counters": [500]}]]}
+        ],
+        [{ "Guid": 1002, "Counters": [100]}]
+      ]
+    }
+  ]
+  )";
+
+  auto ExpectedJSON = json::parse(Expected);
+  ASSERT_TRUE(!!ExpectedJSON);
+  auto ProducedJSON = json::parse(Str);
+  ASSERT_TRUE(!!ProducedJSON);
+  EXPECT_EQ(*ProducedJSON, *ExpectedJSON);
+}

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits