https://github.com/manman-ren updated https://github.com/llvm/llvm-project/pull/68235
>From 7e422cecbd5fc28fb0ec699b702d6bccf321f93f Mon Sep 17 00:00:00 2001 From: Manman Ren <m...@fb.com> Date: Mon, 2 Oct 2023 11:16:58 -0700 Subject: [PATCH] Preliminary patch for merging functions that differ in constants --- .../IPO/MergeFunctionsIgnoringConst.h | 34 + .../Transforms/Utils/FunctionComparator.h | 1 + .../Utils/FunctionComparatorIgnoringConst.h | 58 + .../Utils/FunctionHashIgnoringConst.h | 79 + .../Utils/MergeFunctionsIgnoringConst.h | 29 + llvm/lib/Passes/PassBuilder.cpp | 1 + llvm/lib/Passes/PassBuilderPipelines.cpp | 11 + llvm/lib/Passes/PassRegistry.def | 1 + llvm/lib/Transforms/IPO/CMakeLists.txt | 1 + .../IPO/MergeFunctionsIgnoringConst.cpp | 1430 +++++++++++++++++ llvm/lib/Transforms/Utils/CMakeLists.txt | 2 + .../Utils/FunctionComparatorIgnoringConst.cpp | 107 ++ .../Utils/FunctionHashIgnoringConst.cpp | 620 +++++++ .../unittests/Transforms/Utils/CMakeLists.txt | 1 + .../Utils/FunctionHashIgnoringConstTest.cpp | 120 ++ 15 files changed, 2495 insertions(+) create mode 100644 llvm/include/llvm/Transforms/IPO/MergeFunctionsIgnoringConst.h create mode 100644 llvm/include/llvm/Transforms/Utils/FunctionComparatorIgnoringConst.h create mode 100644 llvm/include/llvm/Transforms/Utils/FunctionHashIgnoringConst.h create mode 100644 llvm/include/llvm/Transforms/Utils/MergeFunctionsIgnoringConst.h create mode 100644 llvm/lib/Transforms/IPO/MergeFunctionsIgnoringConst.cpp create mode 100644 llvm/lib/Transforms/Utils/FunctionComparatorIgnoringConst.cpp create mode 100644 llvm/lib/Transforms/Utils/FunctionHashIgnoringConst.cpp create mode 100644 llvm/unittests/Transforms/Utils/FunctionHashIgnoringConstTest.cpp diff --git a/llvm/include/llvm/Transforms/IPO/MergeFunctionsIgnoringConst.h b/llvm/include/llvm/Transforms/IPO/MergeFunctionsIgnoringConst.h new file mode 100644 index 000000000000000..f9d55cc40873adc --- /dev/null +++ b/llvm/include/llvm/Transforms/IPO/MergeFunctionsIgnoringConst.h @@ -0,0 +1,34 @@ +//===- MergeFunctionsIgnoringConst.h - Merge
Functions ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares MergeFuncIgnoringConstPass, which merges functions that +// only differ by constants in certain instructions. The differing constants +// are turned into parameters of the merged function. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_IPO_MERGEFUNCTIONSIGNORINGCONST_H +#define LLVM_TRANSFORMS_IPO_MERGEFUNCTIONSIGNORINGCONST_H + +#include "llvm/IR/PassManager.h" + +namespace llvm { + +class Module; + +/// Merge functions that differ by constants. +class MergeFuncIgnoringConstPass + : public PassInfoMixin<MergeFuncIgnoringConstPass> { +public: + MergeFuncIgnoringConstPass() {} + PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM); +}; + +} // end namespace llvm + +#endif // LLVM_TRANSFORMS_IPO_MERGEFUNCTIONSIGNORINGCONST_H diff --git a/llvm/include/llvm/Transforms/Utils/FunctionComparator.h b/llvm/include/llvm/Transforms/Utils/FunctionComparator.h index c28f868039a1f7b..1a314b481c72c61 100644 --- a/llvm/include/llvm/Transforms/Utils/FunctionComparator.h +++ b/llvm/include/llvm/Transforms/Utils/FunctionComparator.h @@ -379,6 +379,7 @@ class FunctionComparator { /// But, we are still not able to compare operands of PHI nodes, since those /// could be operands from further BBs we didn't scan yet. /// So it's impossible to use dominance properties in general.
+protected: mutable DenseMap<const Value*, int> sn_mapL, sn_mapR; // The global state we will use diff --git a/llvm/include/llvm/Transforms/Utils/FunctionComparatorIgnoringConst.h b/llvm/include/llvm/Transforms/Utils/FunctionComparatorIgnoringConst.h new file mode 100644 index 000000000000000..a61e02fa41db762 --- /dev/null +++ b/llvm/include/llvm/Transforms/Utils/FunctionComparatorIgnoringConst.h @@ -0,0 +1,58 @@ +//===- FunctionComparatorIgnoringConst.h - Function Comparator --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the FunctionComparatorIgnoringConst class which is used by +// the MergeFuncIgnoringConst pass for comparing functions. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_UTILS_FUNCTIONCOMPARATORIGNORINGCONST_H +#define LLVM_TRANSFORMS_UTILS_FUNCTIONCOMPARATORIGNORINGCONST_H + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/ValueMap.h" +#include "llvm/Support/AtomicOrdering.h" +#include "llvm/Support/Casting.h" +#include "llvm/Transforms/Utils/FunctionComparator.h" +#include <set> + +namespace llvm { + +/// FunctionComparatorIgnoringConst - Compares two functions to determine +/// whether or not they will generate machine code with the same behavior.
+class FunctionComparatorIgnoringConst : public FunctionComparator { +public: + FunctionComparatorIgnoringConst(const Function *F1, const Function *F2, + GlobalNumberState *GN) + : FunctionComparator(F1, F2, GN) {} + + int cmpOperandsIgnoringConsts(const Instruction *L, const Instruction *R, + unsigned opIdx); + + int cmpBasicBlocksIgnoringConsts( + const BasicBlock *BBL, const BasicBlock *BBR, + const std::set<std::pair<int, int>> *InstOpndIndex = nullptr); + + int compareIgnoringConsts( + const std::set<std::pair<int, int>> *InstOpndIndex = nullptr); + + /// Exposes FunctionComparator::cmpConstants to users of this class. + int compareConstants(const Constant *L, const Constant *R) const { + return cmpConstants(L, R); + } + +private: + // Scratch index for instruction in order during cmpOperandsIgnoringConsts. + int index = 0; +}; + +} // end namespace llvm +#endif // LLVM_TRANSFORMS_UTILS_FUNCTIONCOMPARATORIGNORINGCONST_H diff --git a/llvm/include/llvm/Transforms/Utils/FunctionHashIgnoringConst.h b/llvm/include/llvm/Transforms/Utils/FunctionHashIgnoringConst.h new file mode 100644 index 000000000000000..d696ae8c2381128 --- /dev/null +++ b/llvm/include/llvm/Transforms/Utils/FunctionHashIgnoringConst.h @@ -0,0 +1,79 @@ +//===- FunctionHashIgnoringConst.h - Function Hash -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the FunctionHashIgnoringConst class, which is used to +// merge functions globally when they differ only by Constants. It provides a +// stable function hash that ignores Constants and, for each ignored Constant, +// records its location as an (instruction index, operand index) pair.
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_UTILS_FUNCTIONHASHIGNORINGCONST_H +#define LLVM_TRANSFORMS_UTILS_FUNCTIONHASHIGNORINGCONST_H + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/ValueMap.h" +#include "llvm/Support/AtomicOrdering.h" +#include "llvm/Support/Casting.h" +#include "llvm/Transforms/Utils/FunctionComparatorIgnoringConst.h" +#include <map> + +namespace llvm { + +/// Computes a stable hash of a function that ignores its Constant operands. +class FunctionHashIgnoringConst : public FunctionComparatorIgnoringConst { +public: + using FunctionHash = uint64_t; + + /// Get function hash by ignoring Constant operands. + /// This is different from FunctionComparator::functionHash, which just + /// hashes the opcode. If IdxToConstHash is non-null, it receives a map from + /// (instruction index, operand index) pairs to the hash of the constant at + /// that location. If IdxToInst is non-null, it receives a map from + /// instruction index to instruction.
+ static FunctionHash functionHash( + Function &F, std::map<int, Instruction *> *IdxToInst = nullptr, + std::map<std::pair<int, int>, uint64_t> *IdxToConstHash = nullptr); + +private: + // Both sides of the underlying comparator refer to F1: each instance hashes + // a single function. + FunctionHashIgnoringConst(const Function *F1, GlobalNumberState *GN) + : FunctionComparatorIgnoringConst(F1, F1, GN) {} + + FunctionHash + hashIgnoringConsts(std::map<int, Instruction *> &IdxToInst, + std::map<std::pair<int, int>, uint64_t> &IdxToConstHash); + + FunctionHash hashBasicBlocksIgnoringConsts( + const BasicBlock *BBL, std::map<int, Instruction *> &IdxToInst, + std::map<std::pair<int, int>, uint64_t> &IdxToConstHash); + + FunctionHash hashType(Type *TyL) const; + FunctionHash hashValue(const Value *v) const; + FunctionHash hashOperation(const Instruction *i, + bool &needToCmpOperands) const; + FunctionHash hashAttrs(const AttributeList L) const; + FunctionHash hashSignature() const; + + FunctionHash hashInlineAsm(const InlineAsm *L) const; + FunctionHash hashConstant(const Constant *L) const; + FunctionHash hashAPInt(const APInt &L) const; + FunctionHash hashAPFloat(const APFloat &L) const; + FunctionHash hashGlobalValue(const GlobalValue *L) const; + FunctionHash hashGEP(const GEPOperator *GEPL) const; + FunctionHash hashOperandBundlesSchema(const CallBase &LCS) const; + FunctionHash hashRangeMetadata(const MDNode *L) const; + +private: + // Scratch index for instruction in order during hashIgnoringConsts.
+ int index = 0; +}; + +} // end namespace llvm +#endif // LLVM_TRANSFORMS_UTILS_FUNCTIONHASHIGNORINGCONST_H diff --git a/llvm/include/llvm/Transforms/Utils/MergeFunctionsIgnoringConst.h b/llvm/include/llvm/Transforms/Utils/MergeFunctionsIgnoringConst.h new file mode 100644 index 000000000000000..e63afbb6bbf1718 --- /dev/null +++ b/llvm/include/llvm/Transforms/Utils/MergeFunctionsIgnoringConst.h @@ -0,0 +1,29 @@ +//===- MergeFunctionsIgnoringConst.h - Merge Functions ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares helpers used by the MergeFunctionsIgnoringConst pass. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_UTILS_MERGEFUNCTIONSIGNORINGCONST_H +#define LLVM_TRANSFORMS_UTILS_MERGEFUNCTIONSIGNORINGCONST_H + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Operator.h" + +// NOTE(review): a using-directive in a header leaks into every file that +// includes it; consider declaring these helpers inside namespace llvm. +using namespace llvm; + +/// Returns true if the constant operands of \p I may be parameterized. +/// ("Instrunction" is a typo; renaming it requires updating all callers.) +bool isEligibleInstrunctionForConstantSharing(const Instruction *I); + +/// Returns true if operand \p OpIdx of \p I is a constant that may be replaced +/// by a parameter of the merged function. +bool isEligibleOperandForConstantSharing(const Instruction *I, unsigned OpIdx); + +/// Returns true if function \p F is eligible for merging. +bool isEligibleFunction(Function *F); + +// NOTE(review): presumably casts \p V to \p DestTy using \p Builder -- confirm +// against its definition. 
+Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy); +#endif // LLVM_TRANSFORMS_UTILS_MERGEFUNCTIONSIGNORINGCONST_H diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp index 985ff88139323c6..14a0b62cb9a81c9 100644 --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -120,6 +120,7 @@ #include "llvm/Transforms/IPO/LowerTypeTests.h" #include "llvm/Transforms/IPO/MemProfContextDisambiguation.h" #include "llvm/Transforms/IPO/MergeFunctions.h" +#include "llvm/Transforms/IPO/MergeFunctionsIgnoringConst.h" #include
"llvm/Transforms/IPO/ModuleInliner.h" #include "llvm/Transforms/IPO/OpenMPOpt.h" #include "llvm/Transforms/IPO/PartialInlining.h" diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp index 78e0e6353056343..4a8051405f67025 100644 --- a/llvm/lib/Passes/PassBuilderPipelines.cpp +++ b/llvm/lib/Passes/PassBuilderPipelines.cpp @@ -59,6 +59,7 @@ #include "llvm/Transforms/IPO/LowerTypeTests.h" #include "llvm/Transforms/IPO/MemProfContextDisambiguation.h" #include "llvm/Transforms/IPO/MergeFunctions.h" +#include "llvm/Transforms/IPO/MergeFunctionsIgnoringConst.h" #include "llvm/Transforms/IPO/ModuleInliner.h" #include "llvm/Transforms/IPO/OpenMPOpt.h" #include "llvm/Transforms/IPO/PartialInlining.h" @@ -175,6 +176,10 @@ static cl::opt<bool> EnableMergeFunctions( "enable-merge-functions", cl::init(false), cl::Hidden, cl::desc("Enable function merging as part of the optimization pipeline")); +static cl::opt<bool> EnableMergeFuncIgnoringConst( + "enable-merge-func-ignoring-const", cl::init(false), cl::Hidden, + cl::desc("Enable function merger that ignores constants")); + static cl::opt<bool> EnablePostPGOLoopRotation( "enable-post-pgo-loop-rotation", cl::init(true), cl::Hidden, cl::desc("Run the loop rotation transformation after PGO instrumentation")); @@ -1628,6 +1633,9 @@ ModulePassManager PassBuilder::buildThinLTODefaultPipeline( MPM.addPass(buildModuleOptimizationPipeline( Level, ThinOrFullLTOPhase::ThinLTOPostLink)); + if (EnableMergeFuncIgnoringConst) + MPM.addPass(MergeFuncIgnoringConstPass()); + // Emit annotation remarks. addAnnotationRemarksPass(MPM); @@ -1953,6 +1961,9 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level, invokeFullLinkTimeOptimizationLastEPCallbacks(MPM, Level); + if (EnableMergeFuncIgnoringConst) + MPM.addPass(MergeFuncIgnoringConstPass()); + // Emit annotation remarks.
addAnnotationRemarksPass(MPM); diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def index df9f14920f29161..fe6837b3891aead 100644 --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -84,6 +84,7 @@ MODULE_PASS("lower-ifunc", LowerIFuncPass()) MODULE_PASS("lowertypetests", LowerTypeTestsPass()) MODULE_PASS("metarenamer", MetaRenamerPass()) MODULE_PASS("mergefunc", MergeFunctionsPass()) +MODULE_PASS("mergefunc-ignoring-const", MergeFuncIgnoringConstPass()) MODULE_PASS("name-anon-globals", NameAnonGlobalPass()) MODULE_PASS("no-op-module", NoOpModulePass()) MODULE_PASS("objc-arc-apelim", ObjCARCAPElimPass()) diff --git a/llvm/lib/Transforms/IPO/CMakeLists.txt b/llvm/lib/Transforms/IPO/CMakeLists.txt index 034f1587ae8df44..4dac04d3369950f 100644 --- a/llvm/lib/Transforms/IPO/CMakeLists.txt +++ b/llvm/lib/Transforms/IPO/CMakeLists.txt @@ -30,6 +30,7 @@ add_llvm_component_library(LLVMipo LowerTypeTests.cpp MemProfContextDisambiguation.cpp MergeFunctions.cpp + MergeFunctionsIgnoringConst.cpp ModuleInliner.cpp OpenMPOpt.cpp PartialInlining.cpp diff --git a/llvm/lib/Transforms/IPO/MergeFunctionsIgnoringConst.cpp b/llvm/lib/Transforms/IPO/MergeFunctionsIgnoringConst.cpp new file mode 100644 index 000000000000000..4c576123a11f365 --- /dev/null +++ b/llvm/lib/Transforms/IPO/MergeFunctionsIgnoringConst.cpp @@ -0,0 +1,1430 @@ +//===--- MergeFunctionsIgnoringConst.cpp - Merge functions ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass looks for similar functions that are mergeable and folds them. +// The implementation is similar to LLVM's MergeFunctions pass.
Instead of +// merging identical functions, it merges functions which only differ by a few +// constants in certain instructions. +// This is copied from Swift's implementation. +// TODO: We should generalize this pass and share it with Swift's +// implementation. +// +// This pass should run after LLVM's MergeFunctions pass, because it works best +// if there are no _identical_ functions in the module. +// Note: it would also work for identical functions but could produce more +// code overhead than the LLVM pass. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO/MergeFunctionsIgnoringConst.h" +// #include "llvm/Transforms/Utils/GlobalMergeFunctions.h" +// NOTE(review): the isEligible* helpers defined in this file are declared in +// llvm/Transforms/Utils/MergeFunctionsIgnoringConst.h, which is not included +// here, so a declaration/definition mismatch would not be diagnosed. +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/FoldingSet.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Transforms/Utils/FunctionComparatorIgnoringConst.h" +// #include "llvm/ADT/Triple.h" +#include "llvm/Analysis/ObjCARCUtil.h" +#include "llvm/CodeGen/StableHashing.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h" +// #include "llvm/IR/GlobalPtrAuthInfo.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InlineAsm.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/StructuralHash.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/IR/ValueMap.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FileSystem.h" +#include 
"llvm/Support/Regex.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO.h" +#include <vector> + +using namespace llvm; + +#define DEBUG_TYPE "mergefunc-ignoring-const" +
+STATISTIC(NumFunctionsMergedIgnoringConst, "Number of functions merged"); +STATISTIC(NumThunksWrittenIgnoringConst, "Number of thunks generated"); + +static cl::opt<bool> + EnableMergeFunc2("enable-merge-func2", cl::init(false), cl::Hidden, + cl::desc("Enable more aggressive function merger")); + +static cl::opt<unsigned> NumFunctionsIgnoringConstForSanityCheck( + "mergefunc-ignoringconst-sanity", + cl::desc("How many functions in module could be used for " + "MergeFunctionsIgnoringConst pass sanity check. " + "'0' disables this check. Works only with '-debug' key."), + cl::init(0), cl::Hidden); + +static cl::opt<unsigned> IgnoringConstMergeThreshold( + "mergefunc-ignoringconst-threshold", + cl::desc("Functions larger than the threshold are considered for merging." + " '0' disables function merging at all."), + cl::init(15), cl::Hidden); + +cl::opt<bool> UseLinkOnceODRLinkageMerging( + "use-linkonceodr-linkage-merging", cl::init(false), cl::Hidden, + cl::desc( + "Use LinkOnceODR linkage to deduplicate the identical merged function " + "(default = off)")); + +cl::opt<bool> NoInlineForMergedFunction( + "no-inline-merged-function", cl::init(false), cl::Hidden, + cl::desc("set noinline for merged function (default = off)")); + +static cl::opt<bool> + CastArrayType("merge-cast-array-type", cl::init(false), cl::Hidden, + cl::desc("support for casting array type (default = off)")); + +static cl::opt<bool> IgnoreMusttailFunction( + "ignore-musttail-function", cl::init(false), cl::Hidden, + cl::desc( + "ignore functions containing callsites with musttail (default = off)")); + +static cl::opt<bool> AlwaysCallThunk( + "merge-always-call-thunk", cl::init(false), cl::Hidden, + cl::desc( + "do not replace callsites and always emit a thunk (default = off)")); + +static cl::list<std::string> MergeBlockRegexFilters( + "merge-block-regex", cl::Optional, + cl::desc("Block functions from merging if they match the given " + "regular expression"), + cl::ZeroOrMore); + +static
cl::list<std::string> MergeAllowRegexFilters( + "merge-allow-regex", cl::Optional, + cl::desc("Only allow functions to be merged if they match the given " + "regular expression"), + cl::ZeroOrMore); + +/// Returns true if \p I is an instruction kind whose constant operands may be +/// parameterized: loads, stores and calls, plus invokes when +/// -enable-merge-func2 is set. +bool isEligibleInstrunctionForConstantSharing(const Instruction *I) { + switch (I->getOpcode()) { + case Instruction::Load: + case Instruction::Store: + case Instruction::Call: + return true; + default: { + if (EnableMergeFunc2 && I->getOpcode() == Instruction::Invoke) + return true; + return false; + } + } +} + +/// Returns true if the \p opIdx operand of \p CI is the callee operand. +static bool isCalleeOperand(const CallBase *CI, unsigned opIdx) { + return &CI->getCalledOperandUse() == &CI->getOperandUse(opIdx); +} + +/// Returns true if operand \p opIdx of \p CI may be replaced by a parameter. +/// Calls to inline asm, intrinsics and objc_msgSend$ stubs are never +/// parameterized, nor is the callee of a call with a ptrauth operand bundle. +static bool canParameterizeCallOperand(const CallBase *CI, unsigned opIdx) { + if (CI->isInlineAsm()) + return false; + Function *Callee = CI->getCalledOperand() + ? dyn_cast_or_null<Function>( + CI->getCalledOperand()->stripPointerCasts()) + : nullptr; + if (Callee) { + if (Callee->isIntrinsic()) + return false; + // objc_msgSend stubs must be called, and can't have their address taken. + if (Callee->getName().startswith("objc_msgSend$")) + return false; + } + if (isCalleeOperand(CI, opIdx) && + CI->getOperandBundle(LLVMContext::OB_ptrauth).has_value()) { + // The operand is the callee and it has already been signed. Ignore this + // because we cannot add another ptrauth bundle to the call instruction.
+ return false; + } + return true; +} + +/// Returns true if operand \p OpIdx of \p I is a constant that may be replaced +/// by a parameter of the merged function. +bool isEligibleOperandForConstantSharing(const Instruction *I, unsigned OpIdx) { + assert(OpIdx < I->getNumOperands() && "Invalid operand index"); + + if (!isEligibleInstrunctionForConstantSharing(I)) + return false; + + auto Opnd = I->getOperand(OpIdx); + if (!isa<Constant>(Opnd)) + return false; + + if (const auto *CI = dyn_cast<CallBase>(I)) + return canParameterizeCallOperand(CI, OpIdx); + + return true; +} + +namespace { + +/// MergeFuncIgnoringConst finds functions which only differ by constants in +/// certain instructions, e.g. resulting from specialized functions of layout +/// compatible types. +/// Such functions are merged by replacing the differing constants by a +/// parameter. The original functions are replaced by thunks which call the +/// merged function with the specific argument constants. +/// +class MergeFuncIgnoringConstImpl { // : public ModulePass { +public: + MergeFuncIgnoringConstImpl() : FnTree(FunctionNodeCmp(&GlobalNumbers)) {} + + MergeFuncIgnoringConstImpl(bool ptrAuthEnabled, unsigned ptrAuthKey) + : FnTree(FunctionNodeCmp(&GlobalNumbers)), ptrAuthOptionsSet(true), + ptrAuthEnabled(ptrAuthEnabled), ptrAuthKey(ptrAuthKey) {} + + bool runImpl(Module &M); + +private: + struct FunctionEntry; + + /// Describes the set of functions which are considered as "equivalent" (i.e. + /// only differing by some constants). + struct EquivalenceClass { + /// The single-linked list of all functions which are a member of this + /// equivalence class. + FunctionEntry *First; + + /// A very cheap hash, used to early exit if functions do not match. + llvm::IRHash Hash; + + public: + // Note the hash is recalculated potentially multiple times, but it is + // cheap.
+ EquivalenceClass(FunctionEntry *First) + : First(First), Hash(StructuralHash(*First->F)) { + assert(!First->Next); + } + }; + + /// The function comparison operator is provided here so that FunctionNodes do + /// not need to become larger with another pointer. + class FunctionNodeCmp { + GlobalNumberState *GlobalNumbers; + + public: + FunctionNodeCmp(GlobalNumberState *GN) : GlobalNumbers(GN) {} + bool operator()(const EquivalenceClass &LHS, + const EquivalenceClass &RHS) const { + // Order first by hashes, then full function comparison. + if (LHS.Hash != RHS.Hash) + return LHS.Hash < RHS.Hash; + FunctionComparatorIgnoringConst FCmp(LHS.First->F, RHS.First->F, + GlobalNumbers); + return FCmp.compareIgnoringConsts() == -1; + } + }; + using FnTreeType = std::set<EquivalenceClass, FunctionNodeCmp>; + + /// Per-function bookkeeping: equivalence-class membership and merge state. + struct FunctionEntry { + FunctionEntry(Function *F, FnTreeType::iterator I) + : F(F), Next(nullptr), numUnhandledCallees(0), TreeIter(I), + isMerged(false) {} + + /// Back-link to the function. + AssertingVH<Function> F; + + /// The next function in its equivalence class. + FunctionEntry *Next; + + /// The number of not-yet merged callees. Used to process the merging in + /// bottom-up call order. + /// This is only valid in the first entry of an equivalence class. The + /// counts of all functions in an equivalence class are accumulated in the + /// first entry. + int numUnhandledCallees; + + /// The iterator of the function's equivalence class in the FnTree. + /// It's FnTree.end() if the function is not in an equivalence class. + FnTreeType::iterator TreeIter; + + /// True if this function is already a thunk, calling the merged function. + bool isMerged; + }; + + /// Describes an operand of a specific instruction. + struct OpLocation { + Instruction *I; + unsigned OpIndex; + }; + + /// Information for a function. Used during merging.
+ struct FunctionInfo { + + FunctionInfo(Function *F) + : F(F), CurrentInst(nullptr), NumParamsNeeded(0) {} + + /// Positions CurrentInst at the first instruction of F and resets + /// NumParamsNeeded. + void init() { + CurrentInst = &*F->begin()->begin(); + NumParamsNeeded = 0; + } + + /// Advances the current instruction to the next instruction. + void nextInst() { + assert(CurrentInst); + if (CurrentInst->isTerminator()) { + auto BlockIter = std::next(CurrentInst->getParent()->getIterator()); + if (BlockIter == F->end()) { + CurrentInst = nullptr; + return; + } + CurrentInst = &*BlockIter->begin(); + return; + } + CurrentInst = &*std::next(CurrentInst->getIterator()); + } + + /// Returns true if the operand \p OpIdx of the current instruction is the + /// callee of a call, which needs to be signed if passed as a parameter. + bool needsPointerSigning(unsigned OpIdx) const { + // NOTE(review): only CallInst is handled, so this returns false for the + // callee of an invoke even though invokes are eligible for + // parameterization with -enable-merge-func2. + if (auto *CI = dyn_cast<CallInst>(CurrentInst)) + return isCalleeOperand(CI, OpIdx); + return false; + } + + Function *F; + + /// The current instruction while iterating over all instructions. + Instruction *CurrentInst; + + /// Roughly the number of parameters needed if this function would be + /// merged with the first function of the equivalence class. + int NumParamsNeeded; + }; + + using FunctionInfos = SmallVector<FunctionInfo, 8>; + + /// Describes a parameter which we create to parameterize the merged function. + struct ParamInfo { + /// The value of the parameter for all the functions in the equivalence + /// class. + SmallVector<Constant *, 8> Values; + + /// All uses of the parameter in the merged function. + SmallVector<OpLocation, 16> Uses; + + /// The discriminator for pointer signing. + /// Only not null if needsPointerSigning is true. 
+ ConstantInt *discriminator = nullptr; + + /// True if the value is a callee function, which needs to be signed if + /// passed as a parameter. + bool needsPointerSigning = false; + + /// Checks if this parameter can be used to describe an operand in all + /// functions of the equivalence class. Returns true if all values match
+ /// the specific instruction operands in all functions. + bool matches(const FunctionInfos &FInfos, unsigned OpIdx, + bool ptrAuthEnabled) const { + unsigned NumFuncs = FInfos.size(); + assert(Values.size() == NumFuncs); + if (ptrAuthEnabled && + needsPointerSigning != FInfos[0].needsPointerSigning(OpIdx)) { + return false; + } + for (unsigned Idx = 0; Idx < NumFuncs; ++Idx) { + const FunctionInfo &FI = FInfos[Idx]; + Constant *C = cast<Constant>(FI.CurrentInst->getOperand(OpIdx)); + if (Values[Idx] != C) + return false; + } + return true; + } + + /// Computes the discriminator for pointer signing. + void computeDiscriminator(LLVMContext &Context) { + assert(needsPointerSigning); + assert(!discriminator); + + /// Get a hash from the concatenated function names. + /// The hash is deterministic, because the order of values depends on the + /// order of functions in the module, which is itself deterministic. + /// Note that the hash is not part of the ABI, because it's purely used + /// for pointer authentication between a module-private caller-callee + /// pair. + std::string concatenatedCalleeNames; + for (Constant *value : Values) { + if (auto *GO = dyn_cast<GlobalObject>(value)) + concatenatedCalleeNames += GO->getName(); + } + uint64_t rawHash = stable_hash_combine_string(concatenatedCalleeNames); + IntegerType *discrTy = Type::getInt64Ty(Context); + discriminator = ConstantInt::get(discrTy, (rawHash % 0xFFFF) + 1); + } + }; + + using ParamInfos = SmallVector<ParamInfo, 16>; + + Module *module = nullptr; + ModuleSummaryIndex *ExportSummary = nullptr; + const ModuleSummaryIndex *ImportSummary = nullptr; + + GlobalNumberState GlobalNumbers; + + /// A work queue of functions that may have been modified and should be + /// analyzed again. + std::vector<WeakTrackingVH> Deferred; + + /// The set of all distinct functions. Use the insert() and remove() methods + /// to modify it. The map allows efficient lookup and deferring of Functions.
+ FnTreeType FnTree; + + ValueMap<Function *, FunctionEntry *> FuncEntries; + + // Maps a function-pointer / discriminator pair to a corresponding global in + // the llvm.ptrauth section. + // This map is used as a cache to not create ptrauth globals twice. + DenseMap<std::pair<Constant *, ConstantInt *>, Constant *> ptrAuthGlobals; + + /// If true, ptrAuthEnabled and ptrAuthKey are valid. + bool ptrAuthOptionsSet = false; + + /// True if the architecture has pointer authentication enabled. + bool ptrAuthEnabled = false; + + /// The key for pointer authentication. + unsigned ptrAuthKey = 0; + + FunctionEntry *getEntry(Function *F) const { return FuncEntries.lookup(F); } + + bool isInEquivalenceClass(FunctionEntry *FE) const { + if (FE->TreeIter != FnTree.end()) { + return true; + } + assert(!FE->Next); + assert(FE->numUnhandledCallees == 0); + return false; + } + + /// Checks the rules of order relation introduced among functions set. + /// Returns true, if sanity check has been passed, and false if failed. + bool doSanityCheck(std::vector<WeakTrackingVH> &Worklist); + + /// Updates the numUnhandledCallees of all user functions of the equivalence + /// class containing \p FE by \p Delta.
+ void updateUnhandledCalleeCount(FunctionEntry *FE, int Delta); + + bool tryMergeEquivalenceClass(FunctionEntry *FirstInClass); + + FunctionInfo removeFuncWithMostParams(FunctionInfos &FInfos); + + bool deriveParams(ParamInfos &Params, FunctionInfos &FInfos, + unsigned maxParams); + + bool numOperandsDiffer(FunctionInfos &FInfos); + + bool constsDiffer(const FunctionInfos &FInfos, unsigned OpIdx); + + bool tryMapToParameter(FunctionInfos &FInfos, unsigned OpIdx, + ParamInfos &Params, unsigned maxParams); + + void replaceCallWithAddedPtrAuth(CallInst *origCall, Value *newCallee, + ConstantInt *discriminator); + + void mergeWithParams(const FunctionInfos &FInfos, ParamInfos &Params); + static void dumpMergeInfo(const FunctionInfos &FInfos, unsigned); + + void removeEquivalenceClassFromTree(FunctionEntry *FE); + + void writeThunk(Function *ToFunc, Function *Thunk, const ParamInfos &Params, + unsigned FuncIdx); + + bool isPtrAuthEnabled() const { + // TODO: fix pointer authentication + // assert(ptrAuthOptionsSet); + return ptrAuthEnabled; + } + + ConstantInt *getPtrAuthKey() { + // TODO: fix pointer authentication + // assert(isPtrAuthEnabled()); + return ConstantInt::get(Type::getInt32Ty(module->getContext()), ptrAuthKey); + } + + /// Returns the value of function \p FuncIdx, and signs it if required. + Constant *getSignedValue(const ParamInfo &PI, unsigned FuncIdx) { + Constant *value = PI.Values[FuncIdx]; + if (!PI.needsPointerSigning) + return value; + + auto lookupKey = std::make_pair(value, PI.discriminator); + Constant *&ptrAuthGlobal = ptrAuthGlobals[lookupKey]; + if (!ptrAuthGlobal) { +#if 0 + ptrAuthGlobal = GlobalPtrAuthInfo::create( + *module, value, getPtrAuthKey(), + ConstantInt::get(PI.discriminator->getType(), 0), PI.discriminator); +#endif + } + // Creation of the signed global is compiled out above, so a signing + // request would otherwise silently return a null Constant. 
+ assert(ptrAuthGlobal && "pointer signing is not supported yet"); + return ptrAuthGlobal; + } + + /// Replace all direct calls of Old with calls of New. Will bitcast New if + /// necessary to make types match.
+ bool replaceDirectCallers(Function *Old, Function *New, + const ParamInfos &Params, unsigned FuncIdx); +}; + +#if 0 +class MergeFuncIgnoringConst : public ModulePass { +public: + static char ID; + /// True if the architecture has pointer authentication enabled. + bool ptrAuthEnabled = false; + + /// The key for pointer authentication. + unsigned ptrAuthKey = 0; + ModuleSummaryIndex *ExportSummary; + const ModuleSummaryIndex *ImportSummary; + + MergeFuncIgnoringConst() : ModulePass(ID) { + initializeMergeFuncIgnoringConstPass(*llvm::PassRegistry::getPassRegistry()); + } + MergeFuncIgnoringConst(bool ptrAuthEnabled, unsigned ptrAuthKey) + : ModulePass(ID), ptrAuthEnabled(ptrAuthEnabled), ptrAuthKey(ptrAuthKey) { + initializeMergeFuncIgnoringConstPass(*llvm::PassRegistry::getPassRegistry()); + } + bool runOnModule(Module &M) override; +}; +#endif + +} // end anonymous namespace + +#if 0 +char MergeFuncIgnoringConst::ID = 0; +INITIALIZE_PASS_BEGIN(MergeFuncIgnoringConst, "merge-func-ignoring-const", + "merge function pass ignoring const", false, false) +INITIALIZE_PASS_END(MergeFuncIgnoringConst, "merge-func-ignoring-const", + "merge function pass ignoring const", false, false) +#endif + +bool MergeFuncIgnoringConstImpl::doSanityCheck( + std::vector<WeakTrackingVH> &Worklist) { + if (const unsigned Max = NumFunctionsIgnoringConstForSanityCheck) { + unsigned TripleNumber = 0; + bool Valid = true; + + dbgs() << "MERGEFUNC-SANITY: Started for first " << Max << " functions.\n"; + + unsigned i = 0; + for (std::vector<WeakTrackingVH>::iterator I = Worklist.begin(), + E = Worklist.end(); + I != E && i < Max; ++I, ++i) { + unsigned j = i; + for (std::vector<WeakTrackingVH>::iterator J = I; J != E && j < Max; + ++J, ++j) { + Function *F1 = cast<Function>(*I); + Function *F2 = cast<Function>(*J); + int Res1 = FunctionComparatorIgnoringConst(F1, F2, &GlobalNumbers) + .compareIgnoringConsts(); + int Res2 = FunctionComparatorIgnoringConst(F2, F1, &GlobalNumbers)
+ .compareIgnoringConsts(); + + // If F1 <= F2, then F2 >= F1, otherwise report failure. + if (Res1 != -Res2) { + dbgs() << "MERGEFUNC-SANITY: Non-symmetric; triple: " << TripleNumber + << "\n"; + LLVM_DEBUG(F1->dump()); + LLVM_DEBUG(F2->dump()); + Valid = false; + } + + if (Res1 == 0) + continue; + + unsigned k = j; + for (std::vector<WeakTrackingVH>::iterator K = J; K != E && k < Max; + ++k, ++K, ++TripleNumber) { + if (K == J) + continue; + + Function *F3 = cast<Function>(*K); + int Res3 = FunctionComparatorIgnoringConst(F1, F3, &GlobalNumbers) + .compareIgnoringConsts(); + int Res4 = FunctionComparatorIgnoringConst(F2, F3, &GlobalNumbers) + .compareIgnoringConsts(); + + bool Transitive = true; + + if (Res1 != 0 && Res1 == Res4) { + // F1 > F2, F2 > F3 => F1 > F3 + Transitive = Res3 == Res1; + } else if (Res3 != 0 && Res3 == -Res4) { + // F1 > F3, F3 > F2 => F1 > F2 + Transitive = Res3 == Res1; + } else if (Res4 != 0 && -Res3 == Res4) { + // F2 > F3, F3 > F1 => F2 > F1 + Transitive = Res4 == -Res1; + } + + if (!Transitive) { + dbgs() << "MERGEFUNC-SANITY: Non-transitive; triple: " + << TripleNumber << "\n"; + dbgs() << "Res1, Res3, Res4: " << Res1 << ", " << Res3 << ", " + << Res4 << "\n"; + LLVM_DEBUG(F1->dump()); + LLVM_DEBUG(F2->dump()); + LLVM_DEBUG(F3->dump()); + Valid = false; + } + } + } + } + + dbgs() << "MERGEFUNC-SANITY: " << (Valid ? "Passed." : "Failed.") << "\n"; + return Valid; + } + return true; +} + +/// Returns true if functions containing calls to \p F may be merged together. +static bool mayMergeCallsToFunction(Function &F) { + StringRef Name = F.getName(); + + // Calls to dtrace probes must generate unique patchpoints. + if (Name.startswith("__dtrace")) + return false; + + return true; +} + +/// Returns the benefit, which is approximately the size of the function. +/// Return 0, if the function should not be merged.
+static unsigned getBenefit(Function *F) { + unsigned Benefit = 0; + + // We don't want to merge very small functions, because the overhead of + // adding creating thunks and/or adding parameters to the call sites + // outweighs the benefit. + for (BasicBlock &BB : *F) { + for (Instruction &I : BB) { + if (CallBase *CB = dyn_cast<CallBase>(&I)) { + Function *Callee = CB->getCalledFunction(); + if (Callee && !mayMergeCallsToFunction(*Callee)) + return 0; + if (!Callee || !Callee->isIntrinsic()) { + Benefit += 5; + continue; + } + } + Benefit += 1; + } + } + return Benefit; +} + +/// Returns true if function \p F is eligible for merging. +bool isEligibleFunction(Function *F) { + if (F->isDeclaration()) + return false; + + if (F->hasFnAttribute(llvm::Attribute::NoMerge)) + return false; + + if (F->hasAvailableExternallyLinkage()) { + return false; + } + + if (F->getFunctionType()->isVarArg()) { + return false; + } + + // Check against blocklist. + if (!MergeBlockRegexFilters.empty()) { + StringRef FuncName = F->getName(); + for (const auto &tRegex : MergeBlockRegexFilters) + if (Regex(tRegex).match(FuncName)) { + return false; + } + } + // Check against allowlist + if (!MergeAllowRegexFilters.empty()) { + StringRef FuncName = F->getName(); + bool found = false; + for (const auto &tRegex : MergeAllowRegexFilters) + if (Regex(tRegex).match(FuncName)) { + found = true; + break; + } + if (!found) + return false; + } + + if (F->getCallingConv() == CallingConv::SwiftTail) + return false; + + // if function contains callsites with musttail, if we merge + // it, the merged function will have the musttail callsite, but + // the number of parameters can change, thus the parameter count + // of the callsite will mismatch with the function itself. 
+ if (IgnoreMusttailFunction) { + for (const BasicBlock &BB : *F) { + for (const Instruction &I : BB) { + const auto *CB = dyn_cast<CallBase>(&I); + if (CB && CB->isMustTailCall()) + return false; + } + } + } + + unsigned Benefit = getBenefit(F); + if (Benefit < IgnoringConstMergeThreshold) { + return false; + } + + return true; +} + +static bool runInternal(Module &M) { + return MergeFuncIgnoringConstImpl().runImpl(M); +} + +// bool MergeFuncIgnoringConst::runOnModule(Module &M) { return runInternal(M); +// } + +bool MergeFuncIgnoringConstImpl::runImpl(Module &M) { + if (IgnoringConstMergeThreshold == 0) + return false; + + module = &M; + +#if 0 + // TODO: fix pointer authentication + if (!ptrAuthOptionsSet) { + // If invoked from IRGen in the compiler, those options are already set. + // If invoked from swift-llvm-opt, derive the options from the target triple. + Triple triple(M.getTargetTriple()); + ptrAuthEnabled = (triple.getSubArch() == Triple::AArch64SubArch_arm64e); + ptrAuthKey = (unsigned)clang::PointerAuthSchema::ARM8_3Key::ASIA; + ptrAuthOptionsSet = true; + } +#endif + + bool Changed = false; + + // All functions in the module, ordered by hash. Functions with a unique + // hash value are easily eliminated. 
+ std::vector<std::pair<llvm::IRHash, Function *>> HashedFuncs; + + for (Function &Func : M) { + if (isEligibleFunction(&Func)) { + HashedFuncs.push_back({StructuralHash(Func), &Func}); + } + } + + std::stable_sort(HashedFuncs.begin(), HashedFuncs.end(), + [](const std::pair<llvm::IRHash, Function *> &a, + const std::pair<llvm::IRHash, Function *> &b) { + return a.first < b.first; + }); + + std::vector<FunctionEntry> FuncEntryStorage; + FuncEntryStorage.reserve(HashedFuncs.size()); + + auto S = HashedFuncs.begin(); + for (auto I = HashedFuncs.begin(), IE = HashedFuncs.end(); I != IE; ++I) { + + Function *F = I->second; + FuncEntryStorage.push_back(FunctionEntry(F, FnTree.end())); + FunctionEntry &FE = FuncEntryStorage.back(); + FuncEntries[F] = &FE; + + // If the hash value matches the previous value or the next one, we must + // consider merging it. Otherwise it is dropped and never considered again. + if ((I != S && std::prev(I)->first == I->first) || + (std::next(I) != IE && std::next(I)->first == I->first)) { + Deferred.push_back(WeakTrackingVH(F)); + } + } + + do { + std::vector<WeakTrackingVH> Worklist; + Deferred.swap(Worklist); + + LLVM_DEBUG(dbgs() << "======\nbuild tree: worklist-size=" << Worklist.size() + << '\n'); + LLVM_DEBUG(doSanityCheck(Worklist)); + + SmallVector<FunctionEntry *, 8> FuncsToMerge; + + // Insert all candidates into the Worklist. + for (WeakTrackingVH &I : Worklist) { + if (!I) + continue; + Function *F = cast<Function>(I); + FunctionEntry *FE = getEntry(F); + assert(!isInEquivalenceClass(FE)); + + std::pair<FnTreeType::iterator, bool> Result = FnTree.insert(FE); + + FE->TreeIter = Result.first; + const EquivalenceClass &Eq = *Result.first; + + if (Result.second) { + assert(Eq.First == FE); + LLVM_DEBUG(dbgs() << " new in tree: " << F->getName() << '\n'); + } else { + assert(Eq.First != FE); + LLVM_DEBUG(dbgs() << " add to existing: " << F->getName() << '\n'); + // Add the function to the existing equivalence class. 
+ FE->Next = Eq.First->Next; + Eq.First->Next = FE; + // Schedule for merging if the function's equivalence class reaches the + // size of 2. + if (!FE->Next) + FuncsToMerge.push_back(Eq.First); + } + } + LLVM_DEBUG(dbgs() << "merge functions: tree-size=" << FnTree.size() + << '\n'); + + // Figure out the leaf functions. We want to do the merging in bottom-up + // call order. This ensures that we don't parameterize on callee function + // names if we don't have to (because the callee may be merged). + // Note that "leaf functions" refer to the sub-call-graph of functions which + // are in the FnTree. + for (FunctionEntry *ToMerge : FuncsToMerge) { + assert(isInEquivalenceClass(ToMerge)); + updateUnhandledCalleeCount(ToMerge, 1); + } + + // Check if there are any leaf functions at all. + bool LeafFound = false; + for (FunctionEntry *ToMerge : FuncsToMerge) { + if (ToMerge->numUnhandledCallees == 0) + LeafFound = true; + } + for (FunctionEntry *ToMerge : FuncsToMerge) { + if (isInEquivalenceClass(ToMerge)) { + // Only merge leaf functions (or all functions if all functions are in + // a call cycle). + if (ToMerge->numUnhandledCallees == 0 || !LeafFound) { + updateUnhandledCalleeCount(ToMerge, -1); + Changed |= tryMergeEquivalenceClass(ToMerge); + } else { + // Non-leaf functions (i.e. functions in a call cycle) may become + // leaf functions in the next iteration. + removeEquivalenceClassFromTree(ToMerge); + } + } + } + } while (!Deferred.empty()); + + FnTree.clear(); + GlobalNumbers.clear(); + FuncEntries.clear(); + ptrAuthGlobals.clear(); + + return Changed; +} + +void MergeFuncIgnoringConstImpl::updateUnhandledCalleeCount(FunctionEntry *FE, + int Delta) { + // Iterate over all functions of FE's equivalence class. 
+ do { + for (Use &U : FE->F->uses()) { + if (auto *I = dyn_cast<Instruction>(U.getUser())) { + FunctionEntry *CallerFE = getEntry(I->getFunction()); + if (CallerFE && CallerFE->TreeIter != FnTree.end()) { + // Accumulate the count in the first entry of the equivalence class. + FunctionEntry *Head = CallerFE->TreeIter->First; + Head->numUnhandledCallees += Delta; + } + } + } + FE = FE->Next; + } while (FE); +} + +bool MergeFuncIgnoringConstImpl::tryMergeEquivalenceClass( + FunctionEntry *FirstInClass) { + // Build the FInfos vector from all functions in the equivalence class. + FunctionInfos FInfos; + FunctionEntry *FE = FirstInClass; + do { + FInfos.push_back(FunctionInfo(FE->F)); + FE->isMerged = true; + FE = FE->Next; + } while (FE); + assert(FInfos.size() >= 2); + + // Merged or not: in any case we remove the equivalence class from the FnTree. + removeEquivalenceClassFromTree(FirstInClass); + + // Contains functions which differ too much from the first function (i.e. + // would need too many parameters). + FunctionInfos Removed; + + bool Changed = false; + int Try = 0; + + unsigned Benefit = getBenefit(FirstInClass->F); + + // The bigger the function, the more parameters are allowed. + unsigned maxParams = std::max(4u, Benefit / 100); + + // We need multiple tries if there are some functions in FInfos which differ + // too much from the first function in FInfos. But we limit the number of + // tries to a small number, because this is quadratic. + while (FInfos.size() >= 2 && Try++ < 4) { + ParamInfos Params; + bool Merged = deriveParams(Params, FInfos, maxParams); + if (Merged) { + mergeWithParams(FInfos, Params); + Changed = true; + } else { + // We ran out of parameters. Remove the function from the set which + // differs most from the first function. + Removed.push_back(removeFuncWithMostParams(FInfos)); + } + if (Merged || FInfos.size() < 2) { + // Try again with the functions which were removed from the original set. 
+ FInfos.swap(Removed); + Removed.clear(); + } + } + return Changed; +} + +/// Remove the function from \p FInfos which needs the most parameters. Add the +/// removed function to +MergeFuncIgnoringConstImpl::FunctionInfo +MergeFuncIgnoringConstImpl::removeFuncWithMostParams(FunctionInfos &FInfos) { + FunctionInfos::iterator MaxIter = FInfos.end(); + for (auto Iter = FInfos.begin(), End = FInfos.end(); Iter != End; ++Iter) { + if (MaxIter == FInfos.end() || + Iter->NumParamsNeeded > MaxIter->NumParamsNeeded) { + MaxIter = Iter; + } + } + FunctionInfo Removed = *MaxIter; + FInfos.erase(MaxIter); + return Removed; +} + +/// Finds the set of parameters which are required to merge the functions in +/// \p FInfos. +/// Returns true on success, i.e. the functions in \p FInfos can be merged with +/// the parameters returned in \p Params. +bool MergeFuncIgnoringConstImpl::deriveParams(ParamInfos &Params, + FunctionInfos &FInfos, + unsigned maxParams) { + for (FunctionInfo &FI : FInfos) + FI.init(); + + FunctionInfo &FirstFI = FInfos.front(); + + // Iterate over all instructions synchronously in all functions. + do { + if (isEligibleInstrunctionForConstantSharing(FirstFI.CurrentInst)) { + + // Here we handle a rare corner case which needs to be explained: + // Usually the number of operands match, because otherwise the functions + // in FInfos would not be in the same equivalence class. There is only one + // exception to that: If the current instruction is a call to a function, + // which was merged in the previous iteration (in + // tryMergeEquivalenceClass) then the call could be replaced and has more + // arguments than the original call. 
+ if (numOperandsDiffer(FInfos)) { + assert(isa<CallInst>(FirstFI.CurrentInst) && + "only calls are expected to differ in number of operands"); + return false; + } + + for (unsigned OpIdx = 0, NumOps = FirstFI.CurrentInst->getNumOperands(); + OpIdx != NumOps; ++OpIdx) { + + if (constsDiffer(FInfos, OpIdx)) { + // This instruction has operands which differ in at least some + // functions. So we need to parameterize it. + if (!tryMapToParameter(FInfos, OpIdx, Params, maxParams)) { + // We ran out of parameters. + return false; + } + } + } + } + // Go to the next instruction in all functions. + for (FunctionInfo &FI : FInfos) + FI.nextInst(); + } while (FirstFI.CurrentInst); + + return true; +} + +/// Returns true if the number of operands of the current instruction differs. +bool MergeFuncIgnoringConstImpl::numOperandsDiffer(FunctionInfos &FInfos) { + unsigned numOps = FInfos[0].CurrentInst->getNumOperands(); + for (const FunctionInfo &FI : ArrayRef<FunctionInfo>(FInfos).drop_front(1)) { + if (FI.CurrentInst->getNumOperands() != numOps) + return true; + } + return false; +} + +/// Returns true if the \p OpIdx's constant operand in the current instruction +/// does differ in any of the functions in \p FInfos. +bool MergeFuncIgnoringConstImpl::constsDiffer(const FunctionInfos &FInfos, + unsigned OpIdx) { + Constant *CommonConst = nullptr; + + for (const FunctionInfo &FI : FInfos) { + Value *Op = FI.CurrentInst->getOperand(OpIdx); + if (auto *C = dyn_cast<Constant>(Op)) { + if (!CommonConst) { + CommonConst = C; + } else if (EnableMergeFunc2 && isa<ConstantPointerNull>(CommonConst) && + isa<ConstantPointerNull>(C)) { + // if both are null pointer, and if they are different constants + // due to type, still treat them as the same. + } else if (C != CommonConst) { + return true; + } + } + } + return false; +} + +/// Create a new parameter for differing operands or try to reuse an existing +/// parameter. 
+/// Returns true if a parameter could be created or found without exceeding the +/// maximum number of parameters. +bool MergeFuncIgnoringConstImpl::tryMapToParameter(FunctionInfos &FInfos, + unsigned OpIdx, + ParamInfos &Params, + unsigned maxParams) { + ParamInfo *Matching = nullptr; + // Try to find an existing parameter which exactly matches the differing + // operands of the current instruction. + for (ParamInfo &PI : Params) { + if (PI.matches(FInfos, OpIdx, isPtrAuthEnabled())) { + Matching = &PI; + break; + } + } + if (!Matching) { + // We need a new parameter. + // Check if we are within the limit. + if (Params.size() >= maxParams) + return false; + + Params.resize(Params.size() + 1); + Matching = &Params.back(); + // Store the constant values into the new parameter. + Constant *FirstC = cast<Constant>(FInfos[0].CurrentInst->getOperand(OpIdx)); + for (FunctionInfo &FI : FInfos) { + Constant *C = cast<Constant>(FI.CurrentInst->getOperand(OpIdx)); + Matching->Values.push_back(C); + if (C != FirstC) + FI.NumParamsNeeded += 1; + } + if (isPtrAuthEnabled()) + Matching->needsPointerSigning = FInfos[0].needsPointerSigning(OpIdx); + } + /// Remember where the parameter is needed when we build our merged function. + Matching->Uses.push_back({FInfos[0].CurrentInst, OpIdx}); + return true; +} + +/// Copy \p origCall with a \p newCalle and add a ptrauth bundle with \p +/// discriminator. 
+void MergeFuncIgnoringConstImpl::replaceCallWithAddedPtrAuth( + CallInst *origCall, Value *newCallee, ConstantInt *discriminator) { + SmallVector<llvm::OperandBundleDef, 4> bundles; + origCall->getOperandBundlesAsDefs(bundles); + ConstantInt *key = getPtrAuthKey(); + llvm::Value *bundleArgs[] = {key, discriminator}; + bundles.emplace_back("ptrauth", bundleArgs); + + SmallVector<llvm::Value *, 4> copiedArgs; + for (Value *op : origCall->args()) { + copiedArgs.push_back(op); + } + + auto *newCall = + CallInst::Create(origCall->getFunctionType(), newCallee, copiedArgs, + bundles, origCall->getName(), origCall); + newCall->setAttributes(origCall->getAttributes()); + newCall->setTailCallKind(origCall->getTailCallKind()); + newCall->setCallingConv(origCall->getCallingConv()); + origCall->replaceAllUsesWith(newCall); + origCall->eraseFromParent(); +} + +void MergeFuncIgnoringConstImpl::dumpMergeInfo(const FunctionInfos &FInfos, + unsigned paramSize) { + std::set<llvm::IRHash> oHashes; + std::vector<std::string> funcLocs; + Function *OrigFunc = nullptr; + for (const auto &FInfo : FInfos) { + OrigFunc = FInfo.F; + + llvm::IRHash origHash = StructuralHash(*OrigFunc); + oHashes.insert(origHash); + + // Print debug location. + std::string Result; + raw_string_ostream DbgLocOS(Result); + if (DISubprogram *DIS = OrigFunc->getSubprogram()) { + DebugLoc FuncDbgLoc = + DILocation::get(DIS->getContext(), DIS->getScopeLine(), 0, DIS); + FuncDbgLoc.print(DbgLocOS); + DbgLocOS.flush(); + } + std::string singleLine = + "# functionLoc " + + std::to_string(GlobalValue::getGUID(OrigFunc->getName())) + " " + + Result + " " + std::string(OrigFunc->getName()) + "\n"; + funcLocs.push_back(singleLine); + } +} + +/// Merge all functions in \p FInfos by creating thunks which call the single +/// merged function with additional parameters. 
+void MergeFuncIgnoringConstImpl::mergeWithParams(const FunctionInfos &FInfos, + ParamInfos &Params) { + // We reuse the body of the first function for the new merged function. + Function *FirstF = FInfos.front().F; + + // Build the type for the merged function. This will be the type of the + // original function (FirstF) but with the additional parameter which are + // needed to parameterize the merged function. + FunctionType *OrigTy = FirstF->getFunctionType(); + SmallVector<Type *, 8> ParamTypes(OrigTy->param_begin(), OrigTy->param_end()); + + for (const ParamInfo &PI : Params) { + ParamTypes.push_back(PI.Values[0]->getType()); + } + + FunctionType *funcType = + FunctionType::get(OrigTy->getReturnType(), ParamTypes, false); + + // Create the new function. + Function *NewFunction = Function::Create(funcType, FirstF->getLinkage(), + FirstF->getName() + ".Tm"); + if (auto *SP = FirstF->getSubprogram()) + NewFunction->setSubprogram(SP); + NewFunction->copyAttributesFrom(FirstF); + // NOTE: this function is not externally available, do ensure that we reset + // the DLL storage + NewFunction->setDLLStorageClass(GlobalValue::DefaultStorageClass); + if (UseLinkOnceODRLinkageMerging) + NewFunction->setLinkage(GlobalValue::LinkOnceODRLinkage); + else + NewFunction->setLinkage(GlobalValue::InternalLinkage); + if (NoInlineForMergedFunction) + NewFunction->addFnAttr(Attribute::NoInline); + + // Insert the new function after the last function in the equivalence class. + FirstF->getParent()->getFunctionList().insert( + std::next(FInfos[1].F->getIterator()), NewFunction); + + LLVM_DEBUG(dbgs() << " Merge into " << NewFunction->getName() << '\n'); + + // Move the body of FirstF into the NewFunction. 
+ NewFunction->splice(NewFunction->begin(), FirstF); + + auto NewArgIter = NewFunction->arg_begin(); + for (Argument &OrigArg : FirstF->args()) { + Argument &NewArg = *NewArgIter++; + OrigArg.replaceAllUsesWith(&NewArg); + } + unsigned numOrigArgs = FirstF->arg_size(); + + SmallPtrSet<Function *, 8> SelfReferencingFunctions; + + // Replace all differing operands with a parameter. + for (unsigned paramIdx = 0; paramIdx < Params.size(); ++paramIdx) { + const ParamInfo &PI = Params[paramIdx]; + Argument *NewArg = NewFunction->getArg(numOrigArgs + paramIdx); + + if (!PI.needsPointerSigning) { + for (const OpLocation &OL : PI.Uses) { + OL.I->setOperand(OL.OpIndex, NewArg); + } + } + // Collect all functions which are referenced by any parameter. + for (Value *V : PI.Values) { + if (auto *F = dyn_cast<Function>(V)) + SelfReferencingFunctions.insert(F); + } + } + + // Replace all differing operands, which need pointer signing, with a + // parameter. + // We need to do that after all other parameters, because here we replace + // call instructions, which must be live in case it has another constant to + // be replaced. + for (unsigned paramIdx = 0; paramIdx < Params.size(); ++paramIdx) { + ParamInfo &PI = Params[paramIdx]; + if (PI.needsPointerSigning) { + PI.computeDiscriminator(NewFunction->getContext()); + for (const OpLocation &OL : PI.Uses) { + auto *origCall = cast<CallInst>(OL.I); + Argument *newCallee = NewFunction->getArg(numOrigArgs + paramIdx); + replaceCallWithAddedPtrAuth(origCall, newCallee, PI.discriminator); + } + } + } + + for (unsigned FIdx = 0, NumFuncs = FInfos.size(); FIdx < NumFuncs; ++FIdx) { + Function *OrigFunc = FInfos[FIdx].F; + // Don't try to replace all callers of functions which are used as + // parameters because we must not delete such functions. 
+ if (SelfReferencingFunctions.count(OrigFunc) == 0 && + replaceDirectCallers(OrigFunc, NewFunction, Params, FIdx)) { + // We could replace all uses (and the function is not externally visible), + // so we can delete the original function. + auto Iter = FuncEntries.find(OrigFunc); + assert(Iter != FuncEntries.end()); + assert(!isInEquivalenceClass(&*Iter->second)); + Iter->second->F = nullptr; + FuncEntries.erase(Iter); + LLVM_DEBUG(dbgs() << " Erase " << OrigFunc->getName() << '\n'); + OrigFunc->eraseFromParent(); + } else { + // Otherwise we need a thunk which calls the merged function. + writeThunk(NewFunction, OrigFunc, Params, FIdx); + } + ++NumFunctionsMergedIgnoringConst; + } +} + +/// Remove all functions of \p FE's equivalence class from FnTree. Add them to +/// Deferred so that we'll look at them in the next round. +void MergeFuncIgnoringConstImpl::removeEquivalenceClassFromTree( + FunctionEntry *FE) { + if (!isInEquivalenceClass(FE)) + return; + + FnTreeType::iterator Iter = FE->TreeIter; + FunctionEntry *Unlink = Iter->First; + Unlink->numUnhandledCallees = 0; + while (Unlink) { + LLVM_DEBUG(dbgs() << " remove from tree: " << Unlink->F->getName() + << '\n'); + if (!Unlink->isMerged) + Deferred.emplace_back(Unlink->F); + Unlink->TreeIter = FnTree.end(); + assert(Unlink->numUnhandledCallees == 0); + FunctionEntry *NextEntry = Unlink->Next; + Unlink->Next = nullptr; + Unlink = NextEntry; + } + FnTree.erase(Iter); +} + +// Helper for writeThunk, +// Selects proper bitcast operation, +// but a bit simpler then CastInst::getCastOpcode. 
+Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) { + Type *SrcTy = V->getType(); + if (SrcTy->isStructTy()) { + assert(DestTy->isStructTy()); + assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements()); + Value *Result = UndefValue::get(DestTy); + for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) { + Value *Element = + createCast(Builder, Builder.CreateExtractValue(V, makeArrayRef(I)), + DestTy->getStructElementType(I)); + + Result = Builder.CreateInsertValue(Result, Element, makeArrayRef(I)); + } + return Result; + } + assert(!DestTy->isStructTy()); + if (CastArrayType) { + if (auto *SrcAT = dyn_cast<ArrayType>(SrcTy)) { + auto *DestAT = dyn_cast<ArrayType>(DestTy); + assert(DestAT); + assert(SrcAT->getNumElements() == DestAT->getNumElements()); + Value *Result = UndefValue::get(DestTy); + for (unsigned int I = 0, E = SrcAT->getNumElements(); I < E; ++I) { + Value *Element = + createCast(Builder, Builder.CreateExtractValue(V, makeArrayRef(I)), + DestAT->getElementType()); + + Result = Builder.CreateInsertValue(Result, Element, makeArrayRef(I)); + } + return Result; + } + assert(!DestTy->isArrayTy()); + } + if (SrcTy->isIntegerTy() && DestTy->isPointerTy()) + return Builder.CreateIntToPtr(V, DestTy); + else if (SrcTy->isPointerTy() && DestTy->isIntegerTy()) + return Builder.CreatePtrToInt(V, DestTy); + else + return Builder.CreateBitCast(V, DestTy); +} + +/// Replace \p Thunk with a simple tail call to \p ToFunc. Also add parameters +/// to the call to \p ToFunc, which are defined by the FuncIdx's value in +/// \p Params. +void MergeFuncIgnoringConstImpl::writeThunk(Function *ToFunc, Function *Thunk, + const ParamInfos &Params, + unsigned FuncIdx) { + // Delete the existing content of Thunk. 
+ Thunk->dropAllReferences(); + + BasicBlock *BB = BasicBlock::Create(Thunk->getContext(), "", Thunk); + IRBuilder<> Builder(BB); + + SmallVector<Value *, 16> Args; + unsigned ParamIdx = 0; + FunctionType *ToFuncTy = ToFunc->getFunctionType(); + + // Add arguments which are passed through Thunk. + for (Argument &AI : Thunk->args()) { + Args.push_back(createCast(Builder, &AI, ToFuncTy->getParamType(ParamIdx))); + ++ParamIdx; + } + // Add new arguments defined by Params. + for (const ParamInfo &PI : Params) { + assert(ParamIdx < ToFuncTy->getNumParams()); + Constant *param = getSignedValue(PI, FuncIdx); + Args.push_back( + createCast(Builder, param, ToFuncTy->getParamType(ParamIdx))); + ++ParamIdx; + } + + CallInst *CI = Builder.CreateCall(ToFunc, Args); + bool isSwiftTailCall = ToFunc->getCallingConv() == CallingConv::SwiftTail && + Thunk->getCallingConv() == CallingConv::SwiftTail; + CI->setTailCallKind(isSwiftTailCall ? llvm::CallInst::TCK_MustTail + : llvm::CallInst::TCK_Tail); + CI->setCallingConv(ToFunc->getCallingConv()); + CI->setAttributes(ToFunc->getAttributes()); + if (Thunk->getReturnType()->isVoidTy()) { + Builder.CreateRetVoid(); + } else { + Builder.CreateRet(createCast(Builder, CI, Thunk->getReturnType())); + } + + LLVM_DEBUG(dbgs() << " writeThunk: " << Thunk->getName() << '\n'); + ++NumThunksWrittenIgnoringConst; +} + +/// Replace direct callers of Old with New. Also add parameters to the call to +/// \p New, which are defined by the FuncIdx's value in \p Params. 
+bool MergeFuncIgnoringConstImpl::replaceDirectCallers(Function *Old, + Function *New, + const ParamInfos &Params, + unsigned FuncIdx) { + bool AllReplaced = true; + + SmallVector<CallInst *, 8> Callers; + + for (Use &U : Old->uses()) { + auto *I = dyn_cast<Instruction>(U.getUser()); + if (!I) { + AllReplaced = false; + continue; + } + FunctionEntry *FE = getEntry(I->getFunction()); + if (FE) + removeEquivalenceClassFromTree(FE); + + auto *CI = dyn_cast<CallInst>(I); + if (!CI || CI->getCalledOperand() != Old) { + AllReplaced = false; + continue; + } + Callers.push_back(CI); + } + if (!AllReplaced) + return false; + + // When AlwaysCallThunk is true, return false so a thunk will be emitted, also + // do not replace callsites. + if (AlwaysCallThunk) + return false; + + for (CallInst *CI : Callers) { + auto &Context = New->getContext(); + auto NewPAL = New->getAttributes(); + + SmallVector<Type *, 8> OldParamTypes; + SmallVector<Value *, 16> NewArgs; + SmallVector<AttributeSet, 8> NewArgAttrs; + IRBuilder<> Builder(CI); + + FunctionType *NewFuncTy = New->getFunctionType(); + (void)NewFuncTy; + unsigned ParamIdx = 0; + + // Add the existing parameters. + for (Value *OldArg : CI->args()) { + NewArgAttrs.push_back(NewPAL.getParamAttrs(ParamIdx)); + NewArgs.push_back(OldArg); + OldParamTypes.push_back(OldArg->getType()); + ++ParamIdx; + } + // Add the new parameters. 
+ for (const ParamInfo &PI : Params) { + assert(ParamIdx < NewFuncTy->getNumParams()); + Constant *ArgValue = getSignedValue(PI, FuncIdx); + assert(ArgValue != Old && "should not try to replace all callers of self " + "referencing functions"); + NewArgs.push_back(ArgValue); + OldParamTypes.push_back(ArgValue->getType()); + ++ParamIdx; + } + + auto *FType = FunctionType::get(Old->getFunctionType()->getReturnType(), + OldParamTypes, false); + auto *FPtrType = PointerType::get( + FType, cast<PointerType>(New->getType())->getAddressSpace()); + + Value *Callee = ConstantExpr::getBitCast(New, FPtrType); + CallInst *NewCI; + if (objcarc::hasAttachedCallOpBundle(CI)) { + Value *BundleArgs[] = {*objcarc::getAttachedARCFunction(CI)}; + OperandBundleDef OB("clang.arc.attachedcall", BundleArgs); + NewCI = Builder.CreateCall(FType, Callee, NewArgs, {OB}); + } else { + NewCI = Builder.CreateCall(FType, Callee, NewArgs); + } + NewCI->setCallingConv(CI->getCallingConv()); + // Don't transfer attributes from the function to the callee. Function + // attributes typically aren't relevant to the calling convention or ABI. + NewCI->setAttributes(AttributeList::get(Context, /*FnAttrs=*/AttributeSet(), + NewPAL.getRetAttrs(), NewArgAttrs)); + if (IgnoreMusttailFunction && CI->isMustTailCall()) { + // replace a callsite with musttail. 
+ llvm::errs() << "callsite has musttail in newF " << New->getName() + << "\n"; + } + NewCI->copyMetadata(*CI); + CI->replaceAllUsesWith(NewCI); + CI->eraseFromParent(); + } + assert(Old->use_empty() && "should have replaced all uses of old function"); + return Old->hasLocalLinkage(); +} + +PreservedAnalyses MergeFuncIgnoringConstPass::run(Module &M, + ModuleAnalysisManager &MAM) { + if (runInternal(M)) + return PreservedAnalyses::none(); + return PreservedAnalyses::all(); +} diff --git a/llvm/lib/Transforms/Utils/CMakeLists.txt b/llvm/lib/Transforms/Utils/CMakeLists.txt index e971c638327bf05..80946cb06547551 100644 --- a/llvm/lib/Transforms/Utils/CMakeLists.txt +++ b/llvm/lib/Transforms/Utils/CMakeLists.txt @@ -27,6 +27,8 @@ add_llvm_component_library(LLVMTransformUtils FixIrreducible.cpp FlattenCFG.cpp FunctionComparator.cpp + FunctionComparatorIgnoringConst.cpp + FunctionHashIgnoringConst.cpp FunctionImportUtils.cpp GlobalStatus.cpp GuardUtils.cpp diff --git a/llvm/lib/Transforms/Utils/FunctionComparatorIgnoringConst.cpp b/llvm/lib/Transforms/Utils/FunctionComparatorIgnoringConst.cpp new file mode 100644 index 000000000000000..3b3567111f43034 --- /dev/null +++ b/llvm/lib/Transforms/Utils/FunctionComparatorIgnoringConst.cpp @@ -0,0 +1,107 @@ +//===--- FunctionComparatorIgnoringConst.cpp - Function Comparator --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/FunctionComparatorIgnoringConst.h" +#include "llvm/IR/Instructions.h" +#include "llvm/Transforms/Utils/MergeFunctionsIgnoringConst.h" + +using namespace llvm; + +int FunctionComparatorIgnoringConst::cmpOperandsIgnoringConsts( + const Instruction *L, const Instruction *R, unsigned opIdx) { + Value *OpL = L->getOperand(opIdx); + Value *OpR = R->getOperand(opIdx); + + int Res = cmpValues(OpL, OpR); + if (Res == 0) + return Res; + + if (!isa<Constant>(OpL) || !isa<Constant>(OpR)) + return Res; + + if (!isEligibleOperandForConstantSharing(L, opIdx) || + !isEligibleOperandForConstantSharing(R, opIdx)) + return Res; + + if (cmpTypes(OpL->getType(), OpR->getType())) + return Res; + + return 0; +} + +// Test whether two basic blocks have equivalent behavior. +int FunctionComparatorIgnoringConst::cmpBasicBlocksIgnoringConsts( + const BasicBlock *BBL, const BasicBlock *BBR, + const std::set<std::pair<int, int>> *InstOpndIndex) { + BasicBlock::const_iterator InstL = BBL->begin(), InstLE = BBL->end(); + BasicBlock::const_iterator InstR = BBR->begin(), InstRE = BBR->end(); + + do { + bool needToCmpOperands = true; + if (int Res = cmpOperations(&*InstL, &*InstR, needToCmpOperands)) + return Res; + if (needToCmpOperands) { + assert(InstL->getNumOperands() == InstR->getNumOperands()); + + for (unsigned i = 0, e = InstL->getNumOperands(); i != e; ++i) { + // When a set for (instruction, operand) index pairs is given, we only + // ignore constants located at such indices. Otherwise, we precisely + // compare the operands. 
+ if (InstOpndIndex && !InstOpndIndex->count(std::make_pair(index, i))) { + Value *OpL = InstL->getOperand(i); + Value *OpR = InstR->getOperand(i); + if (int Res = cmpValues(OpL, OpR)) + return Res; + } + if (int Res = cmpOperandsIgnoringConsts(&*InstL, &*InstR, i)) + return Res; + // cmpValues should ensure this is true. + assert(cmpTypes(InstL->getOperand(i)->getType(), + InstR->getOperand(i)->getType()) == 0); + } + } + ++index; + ++InstL, ++InstR; + } while (InstL != InstLE && InstR != InstRE); + + if (InstL != InstLE && InstR == InstRE) + return 1; + if (InstL == InstLE && InstR != InstRE) + return -1; + return 0; +} + +// Test whether the two functions have equivalent behavior. +int FunctionComparatorIgnoringConst::compareIgnoringConsts( + const std::set<std::pair<int, int>> *InstOpndIndex) { + beginCompare(); + index = 0; + + if (int Res = compareSignature()) + return Res; + + Function::const_iterator LIter = FnL->begin(), LEnd = FnL->end(); + Function::const_iterator RIter = FnR->begin(), REnd = FnR->end(); + + do { + const BasicBlock *BBL = &*LIter; + const BasicBlock *BBR = &*RIter; + + if (int Res = cmpValues(BBL, BBR)) + return Res; + + if (int Res = cmpBasicBlocksIgnoringConsts(BBL, BBR, InstOpndIndex)) + return Res; + + ++LIter, ++RIter; + } while (LIter != LEnd && RIter != REnd); + + return 0; +} diff --git a/llvm/lib/Transforms/Utils/FunctionHashIgnoringConst.cpp b/llvm/lib/Transforms/Utils/FunctionHashIgnoringConst.cpp new file mode 100644 index 000000000000000..b24d3ffdba93388 --- /dev/null +++ b/llvm/lib/Transforms/Utils/FunctionHashIgnoringConst.cpp @@ -0,0 +1,620 @@ +//===--- FunctionHashIgnoringConst.cpp - Function Hash --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements FunctionHashIgnoringConst, which hashes a function
+// while letting constant operands that are eligible for sharing (see
+// isEligibleOperandForConstantSharing) contribute only their type. The full
+// hash of each such constant is reported separately, keyed by
+// (instruction index, operand index).
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/FunctionHashIgnoringConst.h"
+#include "llvm/CodeGen/MachineStableHash.h"
+#include "llvm/IR/InlineAsm.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Utils/MergeFunctionsIgnoringConst.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "functionhash-ignoringconst"
+
+namespace {
+
+// Accumulate the hash of a sequence of 64-bit integers. This is similar to a
+// hash of a sequence of 64bit ints, but the entire input does not need to be
+// available at once. This interface is necessary for functionHash because it
+// needs to accumulate the hash as the structure of the function is traversed
+// without saving these values to an intermediate buffer. This form of hashing
+// is not often needed, as usually the object to hash is just read from a
+// buffer.
+class HashAccumulator64 {
+  // Initialize to random constant, so the state isn't zero.
+  uint64_t Hash = 0x6acaa36bef8325c5ULL;
+
+public:
+  void add(uint64_t V) { Hash = hashing::detail::hash_16_bytes(Hash, V); }
+
+  // No finishing is required, because the entire hash value is used.
+  uint64_t getHash() const { return Hash; }
+};
+
+} // end anonymous namespace
+
+/// Hash \p F while ignoring constant operands that are eligible for sharing.
+///
+/// \p IdxToInst, if non-null, receives every hashed instruction keyed by its
+/// function-wide index. \p IdxToConstHash, if non-null, receives the full hash
+/// of each ignored constant keyed by (instruction index, operand index).
+/// Returns 0 (conservatively "do not merge") when hashing gives up. It also
+/// returns 0 when two constants that cmpConstants considers different were
+/// given the same constant hash.
+FunctionHashIgnoringConst::FunctionHash FunctionHashIgnoringConst::functionHash(
+    Function &F, std::map<int, Instruction *> *IdxToInst,
+    std::map<std::pair<int, int>, uint64_t> *IdxToConstHash) {
+  GlobalNumberState tGlobalNumbers;
+  FunctionHashIgnoringConst pFC(&F, &tGlobalNumbers);
+
+  // Fall back to local tables when the caller does not want them. They are
+  // still needed for the consistency check below.
+  std::map<std::pair<int, int>, uint64_t> LIdxToConstHash;
+  std::map<int, Instruction *> LIdxToIns;
+  if (!IdxToConstHash)
+    IdxToConstHash = &LIdxToConstHash;
+  if (!IdxToInst)
+    IdxToInst = &LIdxToIns;
+
+  auto Hash = pFC.hashIgnoringConsts(*IdxToInst, *IdxToConstHash);
+  if (!Hash)
+    return Hash;
+
+  // Check that all constants with the same hash are identical. If not,
+  // conservatively return 0.
+  // FIXME: Should we just assert it?
+  std::map<uint64_t, Constant *> ConstHashToConst;
+  for (const auto &[Key, ConstHash] : *IdxToConstHash) {
+    const auto &[InstIndex, OpndIndex] = Key;
+    Instruction *Inst = (*IdxToInst)[InstIndex];
+    auto *Const = cast<Constant>(Inst->getOperand(OpndIndex));
+    auto [It, Inserted] = ConstHashToConst.try_emplace(ConstHash, Const);
+    if (!Inserted && pFC.cmpConstants(It->second, Const))
+      return 0;
+  }
+  return Hash;
+}
+
+/// Hash the signature and then every basic block of FnL in layout order.
+/// Returns 0 when any basic block hashes to 0.
+///
+/// NOTE(review): hash_combine is not documented as stable across processes
+/// or builds. If this hash is ever compared across modules compiled
+/// separately, consider stable_hash_combine instead.
+FunctionHashIgnoringConst::FunctionHash
+FunctionHashIgnoringConst::hashIgnoringConsts(
+    std::map<int, Instruction *> &IdxToInst,
+    std::map<std::pair<int, int>, uint64_t> &IdxToConstHash) {
+  beginCompare();
+  index = 0;
+
+  FunctionHash FnHash = hashSignature();
+  // A range-for (not do-while) so that a body-less function is never
+  // dereferenced.
+  for (const BasicBlock &BB : *FnL) {
+    FunctionHash BBValueHash = hashValue(&BB);
+    FunctionHash BBBodyHash =
+        hashBasicBlocksIgnoringConsts(&BB, IdxToInst, IdxToConstHash);
+    // Ignore 0 hash value conservatively.
+    if (BBBodyHash == 0)
+      return 0;
+    FnHash = hash_combine(FnHash, BBValueHash, BBBodyHash);
+  }
+  return FnHash;
+}
+
+/// Hash the instructions of \p BBL, numbering them with the function-wide
+/// `index` counter. Every instruction is recorded in \p IdxToInst.
+///
+/// For each constant operand eligible for sharing, the constant's full hash
+/// is stored in \p IdxToConstHash under (index, operand). Only the constant's
+/// type contributes to the returned hash. An empty block returns 0.
+FunctionHashIgnoringConst::FunctionHash
+FunctionHashIgnoringConst::hashBasicBlocksIgnoringConsts(
+    const BasicBlock *BBL, std::map<int, Instruction *> &IdxToInst,
+    std::map<std::pair<int, int>, uint64_t> &IdxToConstHash) {
+  FunctionHash BlockHash = 0;
+  bool IsFirst = true;
+  for (const Instruction &I : *BBL) {
+    bool needToCmpOperands = true;
+    FunctionHash InstHash = hashOperation(&I, needToCmpOperands);
+    if (needToCmpOperands) {
+      for (unsigned i = 0, e = I.getNumOperands(); i != e; ++i) {
+        // Declared as FunctionHash on purpose: `auto x = 0` deduces `int`,
+        // which would truncate the 64-bit hashes assigned below.
+        FunctionHash OpndHash = 0;
+        Value *Opnd = I.getOperand(i);
+        if (auto *C = dyn_cast<Constant>(Opnd)) {
+          OpndHash = hashConstant(C);
+          if (isEligibleOperandForConstantSharing(&I, i)) {
+            // Preserve the original constant hash in the map, but let only
+            // the constant's type contribute to the instruction hash.
+            IdxToConstHash[std::make_pair(index, i)] = OpndHash;
+            OpndHash = hashType(Opnd->getType());
+          }
+        } else {
+          OpndHash = hashValue(Opnd);
+        }
+        InstHash = hash_combine(InstHash, OpndHash);
+      }
+    }
+    IdxToInst[index] = const_cast<Instruction *>(&I);
+    ++index;
+    if (IsFirst) {
+      BlockHash = InstHash;
+      IsFirst = false;
+    } else {
+      BlockHash = hash_combine(BlockHash, InstHash);
+    }
+  }
+  return BlockHash;
+}
+
+/// Hash an arbitrary-precision integer by value.
+FunctionHashIgnoringConst::FunctionHash
+FunctionHashIgnoringConst::hashAPInt(const APInt &L) const {
+  HashAccumulator64 H;
+  H.add(hash_value(L));
+  return H.getHash();
+}
+
+/// Hash a floating-point value: its semantics parameters, then its bit
+/// pattern.
+FunctionHashIgnoringConst::FunctionHash
+FunctionHashIgnoringConst::hashAPFloat(const APFloat &L) const {
+  const fltSemantics &SL = L.getSemantics();
+  HashAccumulator64 H;
+  H.add(APFloat::semanticsPrecision(SL));
+  H.add(APFloat::semanticsMaxExponent(SL));
+  H.add(APFloat::semanticsMinExponent(SL));
+  H.add(APFloat::semanticsSizeInBits(SL));
+  H.add(hashAPInt(L.bitcastToAPInt()));
+  return H.getHash();
+}
+
+/// Hash an attribute list: the number of attribute sets, then for every
+/// attribute its kind and payload (type, integer value, or string key/value).
+FunctionHashIgnoringConst::FunctionHash
+FunctionHashIgnoringConst::hashAttrs(const AttributeList L) const {
+  HashAccumulator64 H;
+  H.add(L.getNumAttrSets());
+
+  for (unsigned i : L.indexes()) {
+    AttributeSet LAS = L.getAttributes(i);
+    for (Attribute LA : LAS) {
+      H.add(LA.isTypeAttribute());
+      if (LA.isTypeAttribute()) {
+        H.add(LA.getKindAsEnum());
+        if (Type *TyL = LA.getValueAsType())
+          H.add(hashType(TyL));
+        continue;
+      }
+      // See AttributeImpl::operator<.
+      if (!LA.isStringAttribute()) {
+        H.add(LA.getKindAsEnum());
+        if (LA.isIntAttribute())
+          H.add(LA.getValueAsInt());
+        continue;
+      }
+      H.add(stable_hash_combine_string(LA.getKindAsString()));
+      H.add(stable_hash_combine_string(LA.getValueAsString()));
+    }
+  }
+  return H.getHash();
+}
+
+/// Hash !range metadata. A missing node hashes to a fixed marker.
+FunctionHashIgnoringConst::FunctionHash
+FunctionHashIgnoringConst::hashRangeMetadata(const MDNode *L) const {
+  HashAccumulator64 H;
+  if (!L) {
+    H.add('N');
+    return H.getHash();
+  }
+  H.add(L->getNumOperands());
+  for (size_t I = 0; I < L->getNumOperands(); ++I) {
+    ConstantInt *LLow = mdconst::extract<ConstantInt>(L->getOperand(I));
+    H.add(hashAPInt(LLow->getValue()));
+  }
+  return H.getHash();
+}
+
+/// Hash the operand-bundle schema of a call: bundle count, tags, and input
+/// counts. The bundle inputs themselves are not hashed here.
+FunctionHashIgnoringConst::FunctionHash
+FunctionHashIgnoringConst::hashOperandBundlesSchema(const CallBase &LCS) const {
+  HashAccumulator64 H;
+  H.add(LCS.getNumOperandBundles());
+  for (unsigned I = 0, E = LCS.getNumOperandBundles(); I != E; ++I) {
+    auto OBL = LCS.getOperandBundleAt(I);
+    H.add(stable_hash_combine_string(OBL.getTagName()));
+    H.add(OBL.Inputs.size());
+  }
+  return H.getHash();
+}
+
+/// Hash a constant: a summary of its type, then its value.
+///
+/// Returns 0 to signal "give up" for unnamed globals and for constant kinds
+/// not handled below. Note that hashValue folds such a 0 into its own hash
+/// rather than propagating it.
+FunctionHashIgnoringConst::FunctionHash
+FunctionHashIgnoringConst::hashConstant(const Constant *L) const {
+  HashAccumulator64 H;
+
+  Type *TyL = L->getType();
+  // Vectors are summarized by total bit width and pointers by address space.
+  // Every other type is hashed structurally.
+  if (auto *VecTyL = dyn_cast<VectorType>(TyL)) {
+    H.add(stable_hash_combine_string("V"));
+    unsigned TyLWidth = VecTyL->getPrimitiveSizeInBits().getFixedSize();
+    H.add(TyLWidth);
+  } else if (auto *PTyL = dyn_cast<PointerType>(TyL)) {
+    H.add(stable_hash_combine_string("P"));
+    H.add(PTyL->getAddressSpace());
+  } else {
+    H.add(hashType(TyL));
+  }
+  if (L->isNullValue()) {
+    H.add(stable_hash_combine_string("N"));
+    return H.getHash();
+  }
+  if (const auto *GlobalValueL = dyn_cast<GlobalValue>(L)) {
+    FunctionHash GVHash = hashGlobalValue(GlobalValueL);
+    // Ignore 0 hash value conservatively.
+    if (GVHash == 0)
+      return 0;
+    H.add(GVHash);
+    return H.getHash();
+  }
+  if (const auto *SeqL = dyn_cast<ConstantDataSequential>(L)) {
+    H.add(stable_hash_combine_string(SeqL->getRawDataValues()));
+    return H.getHash();
+  }
+
+  // The value ID itself is intentionally not hashed, since the ID can vary.
+  // Each case hashes kind-specific contents instead.
+  switch (L->getValueID()) {
+  case Value::UndefValueVal:
+  case Value::PoisonValueVal:
+  case Value::ConstantTokenNoneVal:
+    return H.getHash();
+  case Value::ConstantIntVal:
+    H.add(hashAPInt(cast<ConstantInt>(L)->getValue()));
+    return H.getHash();
+  case Value::ConstantFPVal:
+    H.add(hashAPFloat(cast<ConstantFP>(L)->getValueAPF()));
+    return H.getHash();
+  case Value::ConstantArrayVal: {
+    const auto *LA = cast<ConstantArray>(L);
+    uint64_t NumElementsL = cast<ArrayType>(TyL)->getNumElements();
+    H.add(NumElementsL);
+    for (uint64_t i = 0; i < NumElementsL; ++i)
+      H.add(hashConstant(cast<Constant>(LA->getOperand(i))));
+    return H.getHash();
+  }
+  case Value::ConstantStructVal: {
+    const auto *LS = cast<ConstantStruct>(L);
+    unsigned NumElementsL = cast<StructType>(TyL)->getNumElements();
+    H.add(NumElementsL);
+    for (unsigned i = 0; i != NumElementsL; ++i)
+      H.add(hashConstant(cast<Constant>(LS->getOperand(i))));
+    return H.getHash();
+  }
+  case Value::ConstantVectorVal: {
+    const auto *LV = cast<ConstantVector>(L);
+    unsigned NumElementsL = cast<FixedVectorType>(TyL)->getNumElements();
+    H.add(NumElementsL);
+    for (unsigned i = 0; i < NumElementsL; ++i)
+      H.add(hashConstant(cast<Constant>(LV->getOperand(i))));
+    return H.getHash();
+  }
+  case Value::ConstantExprVal: {
+    const auto *LE = cast<ConstantExpr>(L);
+    // Mix in the opcode so that different expressions over the same
+    // operands (e.g. add vs. sub) do not collide.
+    H.add(LE->getOpcode());
+    unsigned NumOperandsL = LE->getNumOperands();
+    H.add(NumOperandsL);
+    for (unsigned i = 0; i < NumOperandsL; ++i)
+      H.add(hashConstant(cast<Constant>(LE->getOperand(i))));
+    return H.getHash();
+  }
+  case Value::BlockAddressVal: {
+    const auto *LBA = cast<BlockAddress>(L);
+    // TODO: Also hash the referenced basic block. Can a block in another
+    // translation unit be referenced?
+    H.add(hashGlobalValue(LBA->getFunction()));
+    return H.getHash();
+  }
+  default:
+    // A constant kind not handled above (e.g. DSOLocalEquivalent or
+    // NoCFIValue). Give up conservatively instead of aborting the compiler.
+    LLVM_DEBUG(dbgs() << "Looking at valueID " << L->getValueID() << "\n");
+    return 0;
+  }
+}
+
+/// Hash a global value by its GUID. Local-linkage globals also mix in the
+/// module identifier. Returns 0 for unnamed globals.
+FunctionHashIgnoringConst::FunctionHash
+FunctionHashIgnoringConst::hashGlobalValue(const GlobalValue *GV) const {
+  if (!GV->hasName())
+    return 0;
+
+  HashAccumulator64 H;
+  // For a local global, hash the module identifier to make it unique.
+  if (GV->hasLocalLinkage())
+    H.add(stable_hash_combine_string(GV->getParent()->getModuleIdentifier()));
+
+  // Use GUID to hash more consistently.
+  H.add(GV->getGUID());
+  return H.getHash();
+}
+
+/// Hash a type structurally. Pointers in address space 0 are hashed as the
+/// DataLayout's pointer-sized integer type.
+FunctionHashIgnoringConst::FunctionHash
+FunctionHashIgnoringConst::hashType(Type *TyL) const {
+  HashAccumulator64 H;
+  PointerType *PTyL = dyn_cast<PointerType>(TyL);
+  const DataLayout &DL = FnL->getParent()->getDataLayout();
+  if (PTyL && PTyL->getAddressSpace() == 0)
+    TyL = DL.getIntPtrType(TyL);
+  H.add(TyL->getTypeID());
+  switch (TyL->getTypeID()) {
+  default:
+    break;
+  case Type::IntegerTyID:
+    H.add(cast<IntegerType>(TyL)->getBitWidth());
+    break;
+  case Type::PointerTyID:
+    // Address-space-0 pointers were rewritten to integers above, so this is
+    // a non-zero address space and PTyL is non-null.
+    H.add(PTyL->getAddressSpace());
+    break;
+  case Type::StructTyID: {
+    auto *STyL = cast<StructType>(TyL);
+    H.add(STyL->getNumElements());
+    H.add(STyL->isPacked());
+    for (Type *ElemTy : STyL->elements())
+      H.add(hashType(ElemTy));
+    break;
+  }
+  case Type::FunctionTyID: {
+    auto *FTyL = cast<FunctionType>(TyL);
+    H.add(FTyL->getNumParams());
+    H.add(FTyL->isVarArg());
+    H.add(hashType(FTyL->getReturnType()));
+    for (Type *ParamTy : FTyL->params())
+      H.add(hashType(ParamTy));
+    break;
+  }
+  case Type::ArrayTyID: {
+    auto *ATyL = cast<ArrayType>(TyL);
+    H.add(ATyL->getNumElements());
+    H.add(hashType(ATyL->getElementType()));
+    break;
+  }
+  case Type::FixedVectorTyID:
+  case Type::ScalableVectorTyID: {
+    auto *VTyL = cast<VectorType>(TyL);
+    H.add(VTyL->getElementCount().isScalable());
+    H.add(VTyL->getElementCount().getKnownMinValue());
+    H.add(hashType(VTyL->getElementType()));
+    break;
+  }
+  }
+  return H.getHash();
+}
+
+/// Hash the operation performed by \p L: its value number, opcode, result and
+/// operand types, optional flags, and opcode-specific properties.
+///
+/// For GEPs, the pointer operand and hashGEP's summary of the indices are
+/// hashed here, and \p needToCmpOperands is set to false. For every other
+/// instruction it is left true so the caller hashes the operands.
+FunctionHashIgnoringConst::FunctionHash
+FunctionHashIgnoringConst::hashOperation(const Instruction *L,
+                                         bool &needToCmpOperands) const {
+  needToCmpOperands = true;
+  HashAccumulator64 H;
+  H.add(hashValue(L));
+  H.add(L->getOpcode());
+  if (const auto *GEPL = dyn_cast<GetElementPtrInst>(L)) {
+    needToCmpOperands = false;
+    H.add(hashValue(GEPL->getPointerOperand()));
+    H.add(hashGEP(cast<GEPOperator>(GEPL)));
+    return H.getHash();
+  }
+  H.add(L->getNumOperands());
+  H.add(hashType(L->getType()));
+  H.add(L->getRawSubclassOptionalData());
+
+  for (unsigned i = 0, e = L->getNumOperands(); i != e; ++i)
+    H.add(hashType(L->getOperand(i)->getType()));
+
+  if (const auto *AI = dyn_cast<AllocaInst>(L)) {
+    H.add(hashType(AI->getAllocatedType()));
+    H.add(AI->getAlign().value());
+    return H.getHash();
+  }
+  if (const auto *LI = dyn_cast<LoadInst>(L)) {
+    H.add(LI->isVolatile());
+    H.add(LI->getAlign().value());
+    H.add(static_cast<int>(LI->getOrdering()));
+    H.add(LI->getSyncScopeID());
+    H.add(hashRangeMetadata(LI->getMetadata(LLVMContext::MD_range)));
+    return H.getHash();
+  }
+  if (const auto *SI = dyn_cast<StoreInst>(L)) {
+    H.add(SI->isVolatile());
+    H.add(SI->getAlign().value());
+    H.add(static_cast<int>(SI->getOrdering()));
+    H.add(SI->getSyncScopeID());
+    return H.getHash();
+  }
+  if (const auto *CI = dyn_cast<CmpInst>(L)) {
+    H.add(CI->getPredicate());
+    return H.getHash();
+  }
+  if (const auto *CBL = dyn_cast<CallBase>(L)) {
+    H.add(CBL->getCallingConv());
+    H.add(hashAttrs(CBL->getAttributes()));
+    H.add(hashOperandBundlesSchema(*CBL));
+    if (const auto *CI = dyn_cast<CallInst>(L))
+      H.add(CI->getTailCallKind());
+    H.add(hashRangeMetadata(L->getMetadata(LLVMContext::MD_range)));
+    return H.getHash();
+  }
+  if (const auto *IVI = dyn_cast<InsertValueInst>(L)) {
+    ArrayRef<unsigned> LIndices = IVI->getIndices();
+    H.add(LIndices.size());
+    for (unsigned Idx : LIndices)
+      H.add(Idx);
+    return H.getHash();
+  }
+  if (const auto *EVI = dyn_cast<ExtractValueInst>(L)) {
+    ArrayRef<unsigned> LIndices = EVI->getIndices();
+    H.add(LIndices.size());
+    for (unsigned Idx : LIndices)
+      H.add(Idx);
+    return H.getHash();
+  }
+  if (const auto *FI = dyn_cast<FenceInst>(L)) {
+    H.add(static_cast<int>(FI->getOrdering()));
+    H.add(FI->getSyncScopeID());
+    return H.getHash();
+  }
+  if (const auto *CXI = dyn_cast<AtomicCmpXchgInst>(L)) {
+    H.add(CXI->isVolatile());
+    H.add(CXI->isWeak());
+    H.add(static_cast<int>(CXI->getSuccessOrdering()));
+    H.add(static_cast<int>(CXI->getFailureOrdering()));
+    H.add(CXI->getSyncScopeID());
+    return H.getHash();
+  }
+  if (const auto *RMWI = dyn_cast<AtomicRMWInst>(L)) {
+    H.add(RMWI->getOperation());
+    H.add(RMWI->isVolatile());
+    H.add(static_cast<int>(RMWI->getOrdering()));
+    H.add(RMWI->getSyncScopeID());
+    return H.getHash();
+  }
+  if (const auto *SVI = dyn_cast<ShuffleVectorInst>(L)) {
+    ArrayRef<int> LMask = SVI->getShuffleMask();
+    H.add(LMask.size());
+    for (int MaskElt : LMask)
+      H.add(MaskElt);
+    return H.getHash();
+  }
+  if (const auto *PNL = dyn_cast<PHINode>(L)) {
+    for (unsigned i = 0, e = PNL->getNumIncomingValues(); i != e; ++i)
+      H.add(hashValue(PNL->getIncomingBlock(i)));
+    return H.getHash();
+  }
+  return H.getHash();
+}
+
+/// Hash a GEP: its address space, then either its accumulated constant offset
+/// or, when the offset is not constant, its source element type and operands.
+FunctionHashIgnoringConst::FunctionHash
+FunctionHashIgnoringConst::hashGEP(const GEPOperator *GEPL) const {
+  unsigned int ASL = GEPL->getPointerAddressSpace();
+  HashAccumulator64 H;
+  H.add(ASL);
+
+  const DataLayout &DL = FnL->getParent()->getDataLayout();
+  unsigned BitWidth = DL.getPointerSizeInBits(ASL);
+  APInt OffsetL(BitWidth, 0);
+  if (GEPL->accumulateConstantOffset(DL, OffsetL)) {
+    H.add(hashAPInt(OffsetL));
+    // Return early, similar to how cmpGEPs handles constant offsets.
+    // TODO: Should the source element type and operands be hashed as well?
+    return H.getHash();
+  }
+
+  H.add(hashType(GEPL->getSourceElementType()));
+  H.add(GEPL->getNumOperands());
+  for (unsigned i = 0, e = GEPL->getNumOperands(); i != e; ++i)
+    H.add(hashValue(GEPL->getOperand(i)));
+  return H.getHash();
+}
+
+/// Hash an inline-asm value: its function type, asm and constraint strings,
+/// and its flags.
+FunctionHashIgnoringConst::FunctionHash
+FunctionHashIgnoringConst::hashInlineAsm(const InlineAsm *L) const {
+  HashAccumulator64 H;
+  H.add(hashType(L->getFunctionType()));
+  H.add(stable_hash_combine_string(L->getAsmString()));
+  H.add(stable_hash_combine_string(L->getConstraintString()));
+  H.add(L->hasSideEffects());
+  H.add(L->isAlignStack());
+  H.add(L->getDialect());
+  return H.getHash();
+}
+
+/// Hash a value used by FnL:
+/// - FnL itself gets a fixed marker.
+/// - Constants are hashed by content.
+/// - Inline asm is hashed by its contents.
+/// - Any other value is hashed by the serial number assigned on its first
+///   occurrence in sn_mapL.
+FunctionHashIgnoringConst::FunctionHash
+FunctionHashIgnoringConst::hashValue(const Value *L) const {
+  HashAccumulator64 H;
+  if (L == FnL) {
+    H.add(stable_hash_combine_string("S"));
+    return H.getHash();
+  }
+
+  if (const auto *ConstL = dyn_cast<Constant>(L)) {
+    H.add(stable_hash_combine_string("C"));
+    H.add(hashConstant(ConstL));
+    return H.getHash();
+  }
+
+  if (const auto *InlineAsmL = dyn_cast<InlineAsm>(L))
+    return hashInlineAsm(InlineAsmL);
+
+  // TODO: Hash the index at the first insertion into the map? What if we
+  // don't have a map?
+  auto LeftSN = sn_mapL.insert(std::make_pair(L, sn_mapL.size()));
+  H.add(LeftSN.first->second);
+  return H.getHash();
+}
+
+/// Hash the signature of FnL: attributes, GC, section, varargs, calling
+/// convention, function type, and the value numbers of its arguments.
+FunctionHashIgnoringConst::FunctionHash
+FunctionHashIgnoringConst::hashSignature() const {
+  HashAccumulator64 H;
+  H.add(hashAttrs(FnL->getAttributes()));
+  H.add(FnL->hasGC());
+  if (FnL->hasGC())
+    H.add(stable_hash_combine_string(FnL->getGC()));
+  H.add(FnL->hasSection());
+  if (FnL->hasSection())
+    H.add(stable_hash_combine_string(FnL->getSection()));
+  H.add(FnL->isVarArg());
+  H.add(FnL->getCallingConv());
+  H.add(hashType(FnL->getFunctionType()));
+  for (const Argument &Arg : FnL->args())
+    H.add(hashValue(&Arg));
+  return H.getHash();
+}
diff --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
index d1714a7532d5ff4..a40ec222a0bfb56 100644
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -19,6 +19,7 @@ add_llvm_unittest(UtilsTests
   CodeMoverUtilsTest.cpp
   DebugifyTest.cpp
   FunctionComparatorTest.cpp
+  FunctionHashIgnoringConstTest.cpp
   IntegerDivisionTest.cpp
   LocalTest.cpp
   LoopRotationUtilsTest.cpp
diff --git a/llvm/unittests/Transforms/Utils/FunctionHashIgnoringConstTest.cpp b/llvm/unittests/Transforms/Utils/FunctionHashIgnoringConstTest.cpp
new file mode 100644
index 000000000000000..64e7b96170979e6
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/FunctionHashIgnoringConstTest.cpp
@@ -0,0 +1,120 @@
+//===---- FunctionHashIgnoringConstTest.cpp - Unit tests -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/FunctionHashIgnoringConst.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+// Parse \p IR into a module, printing the parser diagnostic on failure.
+static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+  if (!Mod)
+    Err.print("FunctionHashIgnoringConstTest", errs());
+  return Mod;
+}
+
+// FunctionHashIgnoringConst ignores some constants, as decided by
+// isEligibleOperandForConstantSharing(). The following test checks globals
+// that are operands of load/store/call.
+TEST(FunctionHashIgnoringConst, HashPass) {
+  LLVMContext C;
+  const char *ModStr = "@g1 = external global i32\n"
+                       "@g2 = external global i32\n"
+                       "declare void @f1()\n"
+                       "declare void @f2()\n"
+                       "define i32 @foo() {\n"
+                       "  %1 = load i32, i32* @g1\n"
+                       "  store i32 %1, i32* @g1\n"
+                       "  call void @f1()\n"
+                       "  ret i32 %1\n"
+                       "}\n"
+                       "define i32 @goo() {\n"
+                       "  %1 = load i32, i32* @g2\n"
+                       "  store i32 %1, i32* @g2\n"
+                       "  call void @f2()\n"
+                       "  ret i32 %1\n"
+                       "}\n";
+  std::unique_ptr<Module> M = parseIR(C, ModStr);
+  ASSERT_TRUE(M);
+  auto *Foo = M->getFunction("foo");
+  auto *Goo = M->getFunction("goo");
+  ASSERT_NE(Foo, nullptr);
+  ASSERT_NE(Goo, nullptr);
+
+  auto HashFoo = FunctionHashIgnoringConst::functionHash(*Foo);
+  auto HashGoo = FunctionHashIgnoringConst::functionHash(*Goo);
+
+  // We expect the function hashes to match when constants are ignored.
+  EXPECT_EQ(HashFoo, HashGoo);
+
+  // The IR comparison is equal when ignoring constants.
+  GlobalNumberState tGlobalNumbers;
+  FunctionComparatorIgnoringHash FCmp(Foo, Goo, &tGlobalNumbers);
+  EXPECT_EQ(0, FCmp.compareIgnoringConsts());
+
+  // Get the locations and hashes of the constants that differ.
+  using FunctionHash = uint64_t;
+  using IdxHashMapTy = std::map<std::pair<int, int>, FunctionHash>;
+
+  IdxHashMapTy IdxToConstHashFoo, IdxToConstHashGoo;
+  FunctionHashIgnoringConst::functionHash(*Foo, nullptr, &IdxToConstHashFoo);
+  FunctionHashIgnoringConst::functionHash(*Goo, nullptr, &IdxToConstHashGoo);
+  EXPECT_EQ(3u, IdxToConstHashFoo.size());
+  EXPECT_EQ(3u, IdxToConstHashGoo.size());
+
+  // 0th instruction, 0th operand.
+  // "%1 = load i32, i32* @g1" vs. "%1 = load i32, i32* @g2"
+  EXPECT_EQ(1u, IdxToConstHashFoo.count({0, 0}));
+  EXPECT_EQ(1u, IdxToConstHashGoo.count({0, 0}));
+  EXPECT_NE((IdxToConstHashFoo[{0, 0}]), (IdxToConstHashGoo[{0, 0}]));
+
+  // 1st instruction, 1st operand.
+  // "store i32 %1, i32* @g1" vs. "store i32 %1, i32* @g2"
+  EXPECT_EQ(1u, IdxToConstHashFoo.count({1, 1}));
+  EXPECT_EQ(1u, IdxToConstHashGoo.count({1, 1}));
+  EXPECT_NE((IdxToConstHashFoo[{1, 1}]), (IdxToConstHashGoo[{1, 1}]));
+
+  // 2nd instruction, 0th operand.
+  // "call void @f1()" vs. "call void @f2()"
+  EXPECT_EQ(1u, IdxToConstHashFoo.count({2, 0}));
+  EXPECT_EQ(1u, IdxToConstHashGoo.count({2, 0}));
+  EXPECT_NE((IdxToConstHashFoo[{2, 0}]), (IdxToConstHashGoo[{2, 0}]));
+
+  // Expect the hash of g1 to be the same in the 0th and 1st instructions of
+  // Foo.
+  EXPECT_EQ((IdxToConstHashFoo[{0, 0}]), (IdxToConstHashFoo[{1, 1}]));
+  // Expect the hash of g2 to be the same in the 0th and 1st instructions of
+  // Goo.
+  EXPECT_EQ((IdxToConstHashGoo[{0, 0}]), (IdxToConstHashGoo[{1, 1}]));
+  // Expect the hash of g1 to differ from that of f1.
+  EXPECT_NE((IdxToConstHashFoo[{0, 0}]), (IdxToConstHashFoo[{2, 0}]));
+}
+
+// Here the constants differ as operands of `add`, which is not expected to
+// be eligible for constant sharing (isEligibleOperandForConstantSharing), so
+// the hashes must differ.
+TEST(FunctionHashIgnoringConst, HashFail) {
+  LLVMContext C;
+  const char *ModStr = "define i32 @foo(i32 %a) {\n"
+                       "  %1 = add i32 %a, 1\n"
+                       "  ret i32 %1\n"
+                       "}\n"
+                       "define i32 @goo(i32 %a) {\n"
+                       "  %1 = add i32 %a, 2\n"
+                       "  ret i32 %1\n"
+                       "}\n";
+  std::unique_ptr<Module> M = parseIR(C, ModStr);
+  ASSERT_TRUE(M);
+  auto *Foo = M->getFunction("foo");
+  auto *Goo = M->getFunction("goo");
+  ASSERT_NE(Foo, nullptr);
+  ASSERT_NE(Goo, nullptr);
+
+  auto HashFoo = FunctionHashIgnoringConst::functionHash(*Foo);
+  auto HashGoo = FunctionHashIgnoringConst::functionHash(*Goo);
+
+  EXPECT_NE(HashFoo, HashGoo);
+}
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits