Hello, everyone. Please review the attached patch for LowerSwitch. It simply "merges" adjacent cases with the same destination into one big "case cluster". When the comparison tree is emitted, such clusters result in range comparisons. This is a pretty cheap but useful transformation.
The same technique is planned for the SDISel variant of switch lowering. -- With best regards, Anton Korobeynikov. Faculty of Mathematics & Mechanics, Saint Petersburg State University.
diff -r 87cd8438fce8 lib/Transforms/Utils/LowerSwitch.cpp --- a/lib/Transforms/Utils/LowerSwitch.cpp Mon Feb 26 18:56:07 2007 +0000 +++ b/lib/Transforms/Utils/LowerSwitch.cpp Wed Feb 28 02:34:32 2007 +0300 @@ -22,6 +22,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/Compiler.h" #include <algorithm> +#include <list> using namespace llvm; namespace { @@ -40,9 +41,19 @@ namespace { AU.addPreservedID(LowerInvokePassID); AU.addPreservedID(LowerAllocationsID); } - - typedef std::pair<Constant*, BasicBlock*> Case; - typedef std::vector<Case>::iterator CaseItr; + + struct Case { + Constant* Low; + Constant* High; + BasicBlock* BB; + + Case(Constant* _Low = NULL, Constant* _High = NULL, + BasicBlock* _BB = NULL): + Low(_Low), High(_High), BB(_BB) { } + }; + + typedef std::vector<Case> CaseVector; + typedef std::vector<Case>::iterator CaseItr; private: void processSwitchInst(SwitchInst *SI); @@ -50,16 +61,18 @@ namespace { BasicBlock* OrigBlock, BasicBlock* Default); BasicBlock* newLeafBlock(Case& Leaf, Value* Val, BasicBlock* OrigBlock, BasicBlock* Default); + unsigned Clusterify(CaseVector& Cases, SwitchInst *SI); }; /// The comparison function for sorting the switch case values in the vector. + /// WARNING: Case ranges should be disjoint! struct CaseCmp { bool operator () (const LowerSwitch::Case& C1, const LowerSwitch::Case& C2) { - const ConstantInt* CI1 = cast<const ConstantInt>(C1.first); - const ConstantInt* CI2 = cast<const ConstantInt>(C2.first); - return CI1->getZExtValue() < CI2->getZExtValue(); + const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low); + const ConstantInt* CI2 = cast<const ConstantInt>(C2.High); + return CI1->getSExtValue() < CI2->getSExtValue(); } }; @@ -91,19 +104,20 @@ bool LowerSwitch::runOnFunction(Function // operator<< - Used for debugging purposes. 
// -std::ostream& operator<<(std::ostream &O, - const std::vector<LowerSwitch::Case> &C) { +static std::ostream& operator<<(std::ostream &O, + const LowerSwitch::CaseVector &C) { O << "["; - for (std::vector<LowerSwitch::Case>::const_iterator B = C.begin(), + for (LowerSwitch::CaseVector::const_iterator B = C.begin(), E = C.end(); B != E; ) { - O << *B->first; + O << *B->Low << " -" << *B->High; if (++B != E) O << ", "; } return O << "]"; } -OStream& operator<<(OStream &O, const std::vector<LowerSwitch::Case> &C) { + +static OStream& operator<<(OStream &O, const LowerSwitch::CaseVector &C) { if (O.stream()) *O.stream() << C; return O; } @@ -128,7 +142,8 @@ BasicBlock* LowerSwitch::switchConvert(C Case& Pivot = *(Begin + Mid); DOUT << "Pivot ==> " - << cast<ConstantInt>(Pivot.first)->getSExtValue() << "\n"; + << cast<ConstantInt>(Pivot.Low)->getSExtValue() << " -" + << cast<ConstantInt>(Pivot.High)->getSExtValue() << "\n"; BasicBlock* LBranch = switchConvert(LHS.begin(), LHS.end(), Val, OrigBlock, Default); @@ -141,7 +156,7 @@ BasicBlock* LowerSwitch::switchConvert(C BasicBlock* NewNode = new BasicBlock("NodeBlock"); F->getBasicBlockList().insert(OrigBlock->getNext(), NewNode); - ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_ULT, Val, Pivot.first, "Pivot"); + ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_ULT, Val, Pivot.Low, "Pivot"); NewNode->getInstList().push_back(Comp); new BranchInst(LBranch, RBranch, Comp, NewNode); return NewNode; @@ -161,25 +176,93 @@ BasicBlock* LowerSwitch::newLeafBlock(Ca BasicBlock* NewLeaf = new BasicBlock("LeafBlock"); F->getBasicBlockList().insert(OrigBlock->getNext(), NewLeaf); - // Make the seteq instruction... - ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_EQ, Val, - Leaf.first, "SwitchLeaf"); - NewLeaf->getInstList().push_back(Comp); + // Emit comparison + ICmpInst* Comp = NULL; + if (Leaf.Low == Leaf.High) { + // Make the seteq instruction... 
+ Comp = new ICmpInst(ICmpInst::ICMP_EQ, Val, Leaf.Low, + "SwitchLeaf", NewLeaf); + } else { + // Make range comparison + if (cast<ConstantInt>(Leaf.Low)->isMinValue(true /*isSigned*/)) { + // Val >= Min && Val <= Hi --> Val <= Hi + Comp = new ICmpInst(ICmpInst::ICMP_SLE, Val, Leaf.High, + "SwitchLeaf", NewLeaf); + } else { + // Emit V-Lo <=u Hi-Lo + Constant* NegLo = ConstantExpr::getNeg(Leaf.Low); + Instruction* Add = BinaryOperator::createAdd(Val, NegLo, + Val->getName()+".off", + NewLeaf); + Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High); + Comp = new ICmpInst(ICmpInst::ICMP_ULE, Add, UpperBound, + "SwitchLeaf", NewLeaf); + } + } // Make the conditional branch... - BasicBlock* Succ = Leaf.second; + BasicBlock* Succ = Leaf.BB; new BranchInst(Succ, Default, Comp, NewLeaf); // If there were any PHI nodes in this successor, rewrite one entry // from OrigBlock to come from NewLeaf. for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) { PHINode* PN = cast<PHINode>(I); + // Remove all but one incoming entries from the cluster + uint64_t Range = cast<ConstantInt>(Leaf.High)->getSExtValue() - + cast<ConstantInt>(Leaf.Low)->getSExtValue(); + for (uint64_t j = 0; j < Range; ++j) { + PN->removeIncomingValue(OrigBlock); + } + int BlockIdx = PN->getBasicBlockIndex(OrigBlock); assert(BlockIdx != -1 && "Switch didn't go to this successor??"); PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf); } return NewLeaf; +} + +unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) +{ + std::list<Case> tmpCases; + unsigned numCmps = 0; + + // Start with "simple" cases + for (unsigned i = 1; i < SI->getNumSuccessors(); ++i) + tmpCases.push_back(Case(SI->getSuccessorValue(i), + SI->getSuccessorValue(i), + SI->getSuccessor(i))); + tmpCases.sort(CaseCmp()); + + // Merge case into clusters + if (tmpCases.size()>=2) + for (std::list<Case>::iterator I=tmpCases.begin(), J=++(tmpCases.begin()), + E=tmpCases.end(); J!=E; ) { + int64_t nextValue = 
cast<ConstantInt>(J->Low)->getSExtValue(); + int64_t currentValue = cast<ConstantInt>(I->High)->getSExtValue(); + BasicBlock* nextBB = J->BB; + BasicBlock* currentBB = I->BB; + + if ((nextValue-currentValue==1) && (currentBB == nextBB)) { + I->High = J->High; + tmpCases.erase(J++); + } else { + I = J++; + } + } + + Cases.clear(); + for (std::list<Case>::iterator I=tmpCases.begin(), E=tmpCases.end(); + I!=E; ++I, ++numCmps) { + if (I->Low != I->High) + // A range counts double, since it requires two compares. + ++numCmps; + + Cases.push_back(*I); + } + + return numCmps; } // processSwitchInst - Replace the specified switch instruction with a sequence @@ -215,14 +298,14 @@ void LowerSwitch::processSwitchInst(Swit PN->setIncomingBlock((unsigned)BlockIdx, NewDefault); } - std::vector<Case> Cases; - - // Expand comparisons for all of the non-default cases... - for (unsigned i = 1; i < SI->getNumSuccessors(); ++i) - Cases.push_back(Case(SI->getSuccessorValue(i), SI->getSuccessor(i))); - - std::sort(Cases.begin(), Cases.end(), CaseCmp()); + // Prepare cases vector. + CaseVector Cases; + unsigned numCmps = Clusterify(Cases, SI); + + DOUT << "Clusterify finished. Total clusters: " << Cases.size() + << ". Total compares: " << numCmps << "\n"; DOUT << "Cases: " << Cases << "\n"; + BasicBlock* SwitchBlock = switchConvert(Cases.begin(), Cases.end(), Val, OrigBlock, NewDefault);
_______________________________________________ llvm-commits mailing list llvm-commits@cs.uiuc.edu http://lists.cs.uiuc.edu/mailman/listinfo/llvm-commits