Hello, everyone.

Please review the attached patch for LowerSwitch. It simply "merges"
adjacent cases with the same destination into one big "case cluster".
When the comparison tree is emitted, each such cluster becomes a single
range comparison. This is a cheap but useful transformation.

The same technique is planned for the SDISel variant of switch lowering.
-- 
With best regards, Anton Korobeynikov.

Faculty of Mathematics & Mechanics, Saint Petersburg State University.

diff -r 87cd8438fce8 lib/Transforms/Utils/LowerSwitch.cpp
--- a/lib/Transforms/Utils/LowerSwitch.cpp	Mon Feb 26 18:56:07 2007 +0000
+++ b/lib/Transforms/Utils/LowerSwitch.cpp	Wed Feb 28 02:34:32 2007 +0300
@@ -22,6 +22,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/Compiler.h"
 #include <algorithm>
+#include <list>
 using namespace llvm;
 
 namespace {
@@ -40,9 +41,19 @@ namespace {
       AU.addPreservedID(LowerInvokePassID);
       AU.addPreservedID(LowerAllocationsID);
     }
-        
-    typedef std::pair<Constant*, BasicBlock*> Case;
-    typedef std::vector<Case>::iterator       CaseItr;
+
+    struct Case {      // A cluster: every value in [Low, High] goes to BB.
+      Constant* Low;   // First (smallest, signed) value of the cluster, inclusive.
+      Constant* High;  // Last (largest, signed) value of the cluster, inclusive.
+      BasicBlock* BB;  // Destination shared by all values in the cluster.
+
+      Case(Constant* _Low = NULL, Constant* _High = NULL,
+           BasicBlock* _BB = NULL):
+        Low(_Low), High(_High), BB(_BB) { }
+    };
+
+    typedef std::vector<Case>           CaseVector;
+    typedef std::vector<Case>::iterator CaseItr;
   private:
     void processSwitchInst(SwitchInst *SI);
 
@@ -50,16 +61,18 @@ namespace {
                               BasicBlock* OrigBlock, BasicBlock* Default);
     BasicBlock* newLeafBlock(Case& Leaf, Value* Val,
                              BasicBlock* OrigBlock, BasicBlock* Default);
+    unsigned Clusterify(CaseVector& Cases, SwitchInst *SI);
   };
 
   /// The comparison function for sorting the switch case values in the vector.
+  /// WARNING: ranges must be disjoint! Orders by signed C1.High < C2.Low.
   struct CaseCmp {
     bool operator () (const LowerSwitch::Case& C1,
                       const LowerSwitch::Case& C2) {

-      const ConstantInt* CI1 = cast<const ConstantInt>(C1.first);
-      const ConstantInt* CI2 = cast<const ConstantInt>(C2.first);
-      return CI1->getZExtValue() < CI2->getZExtValue();
+      const ConstantInt* CI1 = cast<const ConstantInt>(C1.High);
+      const ConstantInt* CI2 = cast<const ConstantInt>(C2.Low);
+      return CI1->getSExtValue() < CI2->getSExtValue();
     }
   };
 
@@ -91,19 +104,20 @@ bool LowerSwitch::runOnFunction(Function
 
 // operator<< - Used for debugging purposes.
 //
-std::ostream& operator<<(std::ostream &O,
-                         const std::vector<LowerSwitch::Case> &C) {
+static std::ostream& operator<<(std::ostream &O,
+                                const LowerSwitch::CaseVector &C) {
   O << "[";

-  for (std::vector<LowerSwitch::Case>::const_iterator B = C.begin(),
+  for (LowerSwitch::CaseVector::const_iterator B = C.begin(),
          E = C.end(); B != E; ) {
-    O << *B->first;
+    O << *B->Low << " -" << *B->High;  // one "Low -High" entry per cluster
     if (++B != E) O << ", ";
   }

   return O << "]";
 }
-OStream& operator<<(OStream &O, const std::vector<LowerSwitch::Case> &C) {
+// Forwards to the std::ostream overload above when O wraps a real stream.
+static OStream& operator<<(OStream &O, const LowerSwitch::CaseVector &C) {
   if (O.stream()) *O.stream() << C;
   return O;
 }
@@ -128,7 +142,8 @@ BasicBlock* LowerSwitch::switchConvert(C
 
   Case& Pivot = *(Begin + Mid);
   DOUT << "Pivot ==> "
-       << cast<ConstantInt>(Pivot.first)->getSExtValue() << "\n";
+       << cast<ConstantInt>(Pivot.Low)->getSExtValue() << " -"
+       << cast<ConstantInt>(Pivot.High)->getSExtValue() << "\n";
 
   BasicBlock* LBranch = switchConvert(LHS.begin(), LHS.end(), Val,
                                       OrigBlock, Default);
@@ -141,7 +156,7 @@ BasicBlock* LowerSwitch::switchConvert(C
   BasicBlock* NewNode = new BasicBlock("NodeBlock");
   F->getBasicBlockList().insert(OrigBlock->getNext(), NewNode);

-  ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_ULT, Val, Pivot.first, "Pivot");
+  ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT, Val, Pivot.Low, "Pivot");
   NewNode->getInstList().push_back(Comp);
   new BranchInst(LBranch, RBranch, Comp, NewNode);
   return NewNode;
@@ -161,25 +176,93 @@ BasicBlock* LowerSwitch::newLeafBlock(Ca
   BasicBlock* NewLeaf = new BasicBlock("LeafBlock");
   F->getBasicBlockList().insert(OrigBlock->getNext(), NewLeaf);
 
-  // Make the seteq instruction...
-  ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_EQ, Val,
-                                Leaf.first, "SwitchLeaf");
-  NewLeaf->getInstList().push_back(Comp);
+  // Emit comparison
+  ICmpInst* Comp = NULL;
+  if (Leaf.Low == Leaf.High) {
+    // Make the seteq instruction...
+    Comp = new ICmpInst(ICmpInst::ICMP_EQ, Val, Leaf.Low,
+                        "SwitchLeaf", NewLeaf);
+  } else {
+    // Make range comparison
+    if (cast<ConstantInt>(Leaf.Low)->isMinValue(true /*isSigned*/)) {
+      // Val >= Min && Val <= Hi --> Val <= Hi
+      Comp = new ICmpInst(ICmpInst::ICMP_SLE, Val, Leaf.High,
+                          "SwitchLeaf", NewLeaf);
+    } else {
+      // Emit V-Lo <=u Hi-Lo
+      Constant* NegLo = ConstantExpr::getNeg(Leaf.Low);
+      Instruction* Add = BinaryOperator::createAdd(Val, NegLo,
+                                                   Val->getName()+".off",
+                                                   NewLeaf);
+      Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High);
+      Comp = new ICmpInst(ICmpInst::ICMP_ULE, Add, UpperBound,
+                          "SwitchLeaf", NewLeaf);
+    }
+  }
 
   // Make the conditional branch...
-  BasicBlock* Succ = Leaf.second;
+  BasicBlock* Succ = Leaf.BB;
   new BranchInst(Succ, Default, Comp, NewLeaf);
 
   // If there were any PHI nodes in this successor, rewrite one entry
   // from OrigBlock to come from NewLeaf.
   for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) {
     PHINode* PN = cast<PHINode>(I);
+    // Remove all but one incoming entries from the cluster
+    uint64_t Range = cast<ConstantInt>(Leaf.High)->getSExtValue() -
+                     cast<ConstantInt>(Leaf.Low)->getSExtValue();    
+    for (uint64_t j = 0; j < Range; ++j) {
+      PN->removeIncomingValue(OrigBlock);
+    }
+    
     int BlockIdx = PN->getBasicBlockIndex(OrigBlock);
     assert(BlockIdx != -1 && "Switch didn't go to this successor??");
     PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf);
   }
 
   return NewLeaf;
+}
+
+unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) {
+  // Fill Cases with sorted clusters of SI's cases; return the compare count.
+  std::list<Case> tmpCases;
+  unsigned numCmps = 0;
+
+  // Start with one single-value case per non-default successor.
+  for (unsigned i = 1; i < SI->getNumSuccessors(); ++i)
+    tmpCases.push_back(Case(SI->getSuccessorValue(i),
+                            SI->getSuccessorValue(i),
+                            SI->getSuccessor(i)));
+  tmpCases.sort(CaseCmp());
+  // Merge consecutive values with the same destination into clusters.
+  // Case values are unique and sorted, so currentValue + 1 can't overflow.
+  if (tmpCases.size()>=2)
+    for (std::list<Case>::iterator I=tmpCases.begin(), J=++(tmpCases.begin()),
+           E=tmpCases.end(); J!=E; ) {
+      int64_t nextValue = cast<ConstantInt>(J->Low)->getSExtValue();
+      int64_t currentValue = cast<ConstantInt>(I->High)->getSExtValue();
+      BasicBlock* nextBB = J->BB;
+      BasicBlock* currentBB = I->BB;
+
+      if ((nextValue == currentValue + 1) && (currentBB == nextBB)) {
+        I->High = J->High;
+        tmpCases.erase(J++);
+      } else {
+        I = J++;
+      }
+    }
+
+  Cases.clear();
+  for (std::list<Case>::iterator I=tmpCases.begin(), E=tmpCases.end();
+       I!=E; ++I, ++numCmps) {
+    if (I->Low != I->High)
+      // A range counts double, since it requires two compares.
+      ++numCmps;
+
+    Cases.push_back(*I);
+  }
+
+  return numCmps;
 }
 
 // processSwitchInst - Replace the specified switch instruction with a sequence
@@ -215,14 +298,14 @@ void LowerSwitch::processSwitchInst(Swit
     PN->setIncomingBlock((unsigned)BlockIdx, NewDefault);
   }
 
-  std::vector<Case> Cases;
-
-  // Expand comparisons for all of the non-default cases...
-  for (unsigned i = 1; i < SI->getNumSuccessors(); ++i)
-    Cases.push_back(Case(SI->getSuccessorValue(i), SI->getSuccessor(i)));
-
-  std::sort(Cases.begin(), Cases.end(), CaseCmp());
+  // Prepare cases vector.
+  CaseVector Cases;
+  unsigned numCmps = Clusterify(Cases, SI);
+
+  DOUT << "Clusterify finished. Total clusters: " << Cases.size()
+       << ". Total compares: " << numCmps << "\n";
   DOUT << "Cases: " << Cases << "\n";
+  
   BasicBlock* SwitchBlock = switchConvert(Cases.begin(), Cases.end(), Val,
                                           OrigBlock, NewDefault);
 
_______________________________________________
llvm-commits mailing list
llvm-commits@cs.uiuc.edu
http://lists.cs.uiuc.edu/mailman/listinfo/llvm-commits

Reply via email to