llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-llvm-transforms Author: Vitaly Buka (vitalybuka) <details> <summary>Changes</summary> --- Full diff: https://github.com/llvm/llvm-project/pull/101757.diff 3 Files Affected: - (modified) llvm/include/llvm/Transforms/Utils/ModuleUtils.h (+9) - (modified) llvm/lib/Transforms/Utils/ModuleUtils.cpp (+44) - (modified) llvm/unittests/Transforms/Utils/ModuleUtilsTest.cpp (+43-3) ``````````diff diff --git a/llvm/include/llvm/Transforms/Utils/ModuleUtils.h b/llvm/include/llvm/Transforms/Utils/ModuleUtils.h index 1ec87505544f8..37d6a3e33315a 100644 --- a/llvm/include/llvm/Transforms/Utils/ModuleUtils.h +++ b/llvm/include/llvm/Transforms/Utils/ModuleUtils.h @@ -30,6 +30,7 @@ class FunctionCallee; class GlobalIFunc; class GlobalValue; class Constant; +class ConstantStruct; class Value; class Type; @@ -44,6 +45,14 @@ void appendToGlobalCtors(Module &M, Function *F, int Priority, void appendToGlobalDtors(Module &M, Function *F, int Priority, Constant *Data = nullptr); +/// Apply 'Fn' to the list of global ctors of module M and replace constructor +/// record with the one returned by `Fn`. If `nullptr` was returned, the +/// corresponding constructor will be removed from the array. For details see +/// https://llvm.org/docs/LangRef.html#the-llvm-global-ctors-global-variable +using GlobalCtorUpdateFn = llvm::function_ref<Constant *(Constant *)>; +void updateGlobalCtors(Module &M, const GlobalCtorUpdateFn &Fn); +void updateGlobalDtors(Module &M, const GlobalCtorUpdateFn &Fn); + /// Sets the KCFI type for the function. Used for compiler-generated functions /// that are indirectly called in instrumented code.
void setKCFIType(Module &M, Function &F, StringRef MangledType); diff --git a/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/llvm/lib/Transforms/Utils/ModuleUtils.cpp index 122279160cc7e..e443d820a2256 100644 --- a/llvm/lib/Transforms/Utils/ModuleUtils.cpp +++ b/llvm/lib/Transforms/Utils/ModuleUtils.cpp @@ -79,6 +79,50 @@ void llvm::appendToGlobalDtors(Module &M, Function *F, int Priority, Constant *D appendToGlobalArray("llvm.global_dtors", M, F, Priority, Data); } +static void updateGlobalArray(StringRef ArrayName, Module &M, + const GlobalCtorUpdateFn &Fn) { + GlobalVariable *GVCtor = M.getNamedGlobal(ArrayName); + if (!GVCtor) + return; + + IRBuilder<> IRB(M.getContext()); + SmallVector<Constant *, 16> CurrentCtors; + bool Changed = false; + StructType *EltTy = + cast<StructType>(GVCtor->getValueType()->getArrayElementType()); + if (Constant *Init = GVCtor->getInitializer()) { + CurrentCtors.reserve(Init->getNumOperands()); + for (Value *OP : Init->operands()) { + Constant *C = cast<Constant>(OP); + Constant *NewC = Fn(C); + Changed |= (!NewC || NewC != C); + if (NewC) + CurrentCtors.push_back(NewC); + } + } + if (!Changed) + return; + + GVCtor->eraseFromParent(); + + // Create a new initializer. + ArrayType *AT = ArrayType::get(EltTy, CurrentCtors.size()); + Constant *NewInit = ConstantArray::get(AT, CurrentCtors); + + // Create the replacement global variable under the same name; the old + // one has already been erased above.
+ (void)new GlobalVariable(M, NewInit->getType(), false, + GlobalValue::AppendingLinkage, NewInit, ArrayName); +} + +void llvm::updateGlobalCtors(Module &M, const GlobalCtorUpdateFn &Fn) { + updateGlobalArray("llvm.global_ctors", M, Fn); +} + +void llvm::updateGlobalDtors(Module &M, const GlobalCtorUpdateFn &Fn) { + updateGlobalArray("llvm.global_dtors", M, Fn); +} + static void collectUsedGlobals(GlobalVariable *GV, SmallSetVector<Constant *, 16> &Init) { if (!GV || !GV->hasInitializer()) diff --git a/llvm/unittests/Transforms/Utils/ModuleUtilsTest.cpp b/llvm/unittests/Transforms/Utils/ModuleUtilsTest.cpp index 0ed7be9620a6f..582448a14ba8a 100644 --- a/llvm/unittests/Transforms/Utils/ModuleUtilsTest.cpp +++ b/llvm/unittests/Transforms/Utils/ModuleUtilsTest.cpp @@ -70,17 +70,21 @@ TEST(ModuleUtils, AppendToUsedList2) { } using AppendFnType = decltype(&appendToGlobalCtors); -using ParamType = std::tuple<StringRef, AppendFnType>; +using UpdateFnType = decltype(&updateGlobalCtors); +using ParamType = std::tuple<StringRef, AppendFnType, UpdateFnType>; class ModuleUtilsTest : public testing::TestWithParam<ParamType> { public: StringRef arrayName() const { return std::get<0>(GetParam()); } AppendFnType appendFn() const { return std::get<AppendFnType>(GetParam()); } + UpdateFnType updateFn() const { return std::get<UpdateFnType>(GetParam()); } }; INSTANTIATE_TEST_SUITE_P( ModuleUtilsTestCtors, ModuleUtilsTest, - ::testing::Values(ParamType{"llvm.global_ctors", &appendToGlobalCtors}, - ParamType{"llvm.global_dtors", &appendToGlobalDtors})); + ::testing::Values(ParamType{"llvm.global_ctors", &appendToGlobalCtors, + &updateGlobalCtors}, + ParamType{"llvm.global_dtors", &appendToGlobalDtors, + &updateGlobalDtors})); TEST_P(ModuleUtilsTest, AppendToMissingArray) { LLVMContext C; @@ -124,3 +128,39 @@ TEST_P(ModuleUtilsTest, AppendToArray) { 11, nullptr); EXPECT_EQ(3, getListSize(*M, arrayName())); } + +TEST_P(ModuleUtilsTest, UpdateArray) { + LLVMContext C; + + 
std::unique_ptr<Module> M = + parseIR(C, (R"(@)" + arrayName() + + R"( = appending global [2 x { i32, ptr, ptr }] [ + { i32, ptr, ptr } { i32 65535, ptr null, ptr null }, + { i32, ptr, ptr } { i32 0, ptr null, ptr null }] + )") + .str()); + + EXPECT_EQ(2, getListSize(*M, arrayName())); + updateFn()(*M, [](Constant *C) -> Constant * { + ConstantStruct *CS = dyn_cast<ConstantStruct>(C); + if (!CS) + return nullptr; + StructType *EltTy = cast<StructType>(C->getType()); + Constant *CSVals[3] = { + ConstantInt::getSigned(CS->getOperand(0)->getType(), 12), + CS->getOperand(1), + CS->getOperand(2), + }; + return ConstantStruct::get(EltTy, + ArrayRef(CSVals, EltTy->getNumElements())); + }); + EXPECT_EQ(1, getListSize(*M, arrayName())); + ConstantArray *CA = dyn_cast<ConstantArray>( + M->getGlobalVariable(arrayName())->getInitializer()); + ASSERT_NE(nullptr, CA); + ConstantStruct *CS = dyn_cast<ConstantStruct>(CA->getOperand(0)); + ASSERT_NE(nullptr, CS); + ConstantInt *Pri = dyn_cast<ConstantInt>(CS->getOperand(0)); + ASSERT_NE(nullptr, Pri); + EXPECT_EQ(12u, Pri->getLimitedValue()); +} `````````` </details> https://github.com/llvm/llvm-project/pull/101757 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits