Hello, Everyone.

Please find a new optimization pass for LLVM attached. It implements an
llvm-to-llvm transformation of switches that lowers their "complexity"
(in terms of the comparisons/jumps needed). It implements the techniques
listed in PR926:
1. MRST for sparse switches (they are converted into a series of
switches suitable for jump table emission)
2. Shift/And technique for small switches with few destinations.

Tested with:

1. llvm-test
2. 1) llvm-gcc bootstrapped
   2) Mozilla build with bootstrapped llvm-gcc

-- 
With best regards, Anton Korobeynikov.

Faculty of Mathematics & Mechanics, Saint Petersburg State University.

diff -r 87cd8438fce8 include/llvm/LinkAllPasses.h
--- a/include/llvm/LinkAllPasses.h	Mon Feb 26 18:56:07 2007 +0000
+++ b/include/llvm/LinkAllPasses.h	Sat Feb 24 23:31:48 2007 +0300
@@ -109,6 +109,7 @@ namespace {
       (void) llvm::createIndMemRemPass();
       (void) llvm::createInstCountPass();
       (void) llvm::createPredicateSimplifierPass();
+      (void) llvm::createSwitchStrengthReducePass();
 
       (void)new llvm::IntervalPartition();
       (void)new llvm::ImmediateDominators();
diff -r 87cd8438fce8 include/llvm/Transforms/Scalar.h
--- a/include/llvm/Transforms/Scalar.h	Mon Feb 26 18:56:07 2007 +0000
+++ b/include/llvm/Transforms/Scalar.h	Sat Feb 24 23:39:49 2007 +0300
@@ -133,6 +133,13 @@ FunctionPass *createLoopUnswitchPass();
 // LoopUnroll - This pass is a simple loop unrolling pass.
 //
 FunctionPass *createLoopUnrollPass();
+
+//===----------------------------------------------------------------------===//
+//
+// SwitchStrengthReduce - This pass performs a strength reduction on switch
+// instructions, replacing them by code which has lower amount of jumps.
+FunctionPass *createSwitchStrengthReducePass();
+extern const PassInfo *SwitchStrengthReduceID;
 
 //===----------------------------------------------------------------------===//
 //
diff -r 87cd8438fce8 lib/CodeGen/LLVMTargetMachine.cpp
--- a/lib/CodeGen/LLVMTargetMachine.cpp	Mon Feb 26 18:56:07 2007 +0000
+++ b/lib/CodeGen/LLVMTargetMachine.cpp	Sun Feb 25 16:28:35 2007 +0300
@@ -26,8 +26,11 @@ LLVMTargetMachine::addPassesToEmitFile(F
                                        bool Fast) {
   // Standard LLVM-Level Passes.
   
-  // Run loop strength reduction before anything else.
-  if (!Fast) PM.add(createLoopStrengthReducePass(getTargetLowering()));
+  // Run loop & switch strength reduction before anything else.
+  if (!Fast) {
+    PM.add(createLoopStrengthReducePass(getTargetLowering()));
+    PM.add(createSwitchStrengthReducePass());
+  }
   
   // FIXME: Implement efficient support for garbage collection intrinsics.
   PM.add(createLowerGCPass());
//===- SwitchStrengthReduce.cpp - Reduce switch strength ------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file was developed by Anton Korobeynikov and is distributed under
// the University of Illinois Open Source License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass performs a strength reduction on switch instructions,
// replacing them with code that needs fewer jumps.
//
// This pass implements 2 techniques:
// 1. Shift/And switch lowering for "dense" switches with few destinations
// 2. MRST algorithm for sparse switches described in this paper:
//    "Efficient Multiway Radix Search Trees"
//    Ulfar Erlingsson, Mukkai Krishnamoorthy and T.V. Raman.
//    Information Processing Letters 60(1996), pp. 115-120.
//    http://www.cs.cornell.edu/home/ulfar/mrst.ps.gz
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
#include "llvm/Constants.h"
#include "llvm/DerivedTypes.h"
#include "llvm/Function.h"
#include "llvm/Instructions.h"
#include "llvm/Pass.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Target/TargetData.h"
#include <algorithm>
#include <iostream>
#include <list>
#include <vector>
using namespace llvm;

// Upper bound on the number of distinct case destinations for which
// EmitBitTests will build a bit-test chain (one mask test per destination).
#define MAX_CASE_BIT_TESTS 3

namespace {    
  /// SwitchStrengthReduce Pass - Function pass that rewrites switch
  /// instructions either into a chain of bit tests (EmitBitTests) or into a
  /// tree of smaller switches and compare chains (AddMRST).
  class VISIBILITY_HIDDEN SwitchStrengthReduce : public FunctionPass {
  public:
    /// Target data; assigned at the start of every runOnFunction call.
    const TargetData *TD;
    /// Pointer-sized integer type obtained from TD->getIntPtrType(); the bit
    /// tests are built in this type.
    const Type *UIntPtrTy;

    virtual bool runOnFunction(Function &F);

    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
      // This is a cluster of orthogonal Transforms
      AU.addPreserved<UnifyFunctionExitNodes>();
      AU.addPreservedID(PromoteMemoryToRegisterID);
      AU.addPreservedID(LowerSelectID);
      AU.addPreservedID(LowerInvokePassID);
      AU.addPreservedID(LowerAllocationsID);

      // TargetData supplies the pointer-sized integer type.
      AU.addRequired<TargetData>();
    }

    /// A case (or a cluster of merged cases): every value in the inclusive
    /// range [Low, High] branches to BB. "Simple" cases have Low == High.
    struct Case {
      Constant* Low;
      Constant* High;
      BasicBlock* BB;

      Case(Constant* _Low = NULL, Constant* _High = NULL,
           BasicBlock* _BB = NULL):
        Low(_Low), High(_High), BB(_BB) { }
    };

    /// Per-destination record used by EmitBitTests: Mask has one bit set for
    /// each case value (relative to the low bound) that branches to BB, and
    /// Bits counts how many values were added to the mask.
    struct CaseBits {
      uint64_t Mask;
      BasicBlock* BB;
      unsigned Bits;

      CaseBits(uint64_t _Mask, BasicBlock* _BB, unsigned _Bits):
        Mask(_Mask), BB(_BB), Bits(_Bits) { }
    };

    typedef std::vector<Case>::iterator               CaseItr;
    typedef std::vector<Case>                         CaseVector;
    typedef std::vector<CaseBits>                     CaseBitsVector;
  private:
    // Statistics for the MRST lowering of one switch; they are reset in
    // AddMRST, accumulated in MRSTTransformSwitch and printed via DOUT.
    unsigned totalCases;
    unsigned totalSteps;
    unsigned totalSpace;

    // Helpers operating on a bit window W of length L whose lowest bit is
    // 'lowbit'. All of them expect "simple" cases (Low == High).
    uint64_t card(CaseVector& C, uint64_t W, unsigned L, unsigned lowbit);
    bool isCritical(CaseVector& C, uint64_t W, unsigned L, unsigned lowbit);
    bool isEntryUsed(CaseVector& C, uint64_t entry, uint64_t Wmax, unsigned lowbit);
    void getEntries(CaseVector& Cout, CaseVector& C, uint64_t entry, uint64_t Wmax,
                    unsigned lowbit);

    // MRST construction.
    uint64_t MRSTFindWindow(CaseVector& C, unsigned no_b,
                            unsigned& len, unsigned& lowbit);
    BasicBlock* MRSTTransformSwitch(CaseVector& C, unsigned steps, Value* Val,
                                    BasicBlock* OrigBlock, BasicBlock* Default);

    // Builds a chain of equality compares for a small group of cases.
    BasicBlock* newBunchBlock(CaseVector& C, Value* Val, BasicBlock* OrigBlock,
                              BasicBlock* Default);

    // Per-switch drivers.
    unsigned Clusterify(CaseVector& Cases, SwitchInst *SI);
    void processSwitchInst(SwitchInst *SI);
    void AddMRST(SwitchInst *SI);
    bool EmitBitTests(SwitchInst *SI);
  };

  /// The comparison function for sorting the switch cases in a vector or
  /// list: cases are ordered by their (signed) lower bound.
  ///
  /// Comparing lower bounds on both sides keeps this a strict weak ordering
  /// even for real ranges (a range never compares less than itself); for
  /// "simple" cases with Low == High it yields the same order as comparing
  /// Low against High.
  /// WARNING: Case ranges should be disjoint!
  struct CaseCmp {
    bool operator () (const SwitchStrengthReduce::Case& C1,
                      const SwitchStrengthReduce::Case& C2) const {

      const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low);
      const ConstantInt* CI2 = cast<const ConstantInt>(C2.Low);
      return CI1->getSExtValue() < CI2->getSExtValue();
    }
  };

  /// Comparator that orders CaseBits records by descending Bits count, so
  /// the destination covering the most case values comes first.
  struct CaseBitsCmp {
    bool operator () (const SwitchStrengthReduce::CaseBits& C1,
                      const SwitchStrengthReduce::CaseBits& C2) {
      const unsigned LhsBits = C1.Bits;
      const unsigned RhsBits = C2.Bits;
      return RhsBits < LhsBits;
    }
  };

  // Pass registration object: registers SwitchStrengthReduce under the name
  // "switch-reduce" with the description "Reduce switch strength".
  RegisterPass<SwitchStrengthReduce>
  X("switch-reduce", "Reduce switch strength");
}

// Publicly exposed interface to the pass: the PassInfo obtained from the
// registration object X.
const PassInfo *llvm::SwitchStrengthReduceID = X.getPassInfo();

// operator<< - Used for debugging purposes. Prints the case vector as a
// bracketed, comma-separated list of "Low -High" pairs.
//
static std::ostream& operator<<(std::ostream &O,
                                const SwitchStrengthReduce::CaseVector &C) {
  typedef SwitchStrengthReduce::CaseVector::const_iterator ConstCaseItr;

  O << "[";

  const ConstCaseItr End = C.end();
  ConstCaseItr Cur = C.begin();
  while (Cur != End) {
    O << *Cur->Low << " -" << *Cur->High;
    ++Cur;
    // Separator goes between elements only, never after the last one.
    if (Cur != End)
      O << ", ";
  }

  O << "]";
  return O;
}

// operator<< - Forwards a case vector to the std::ostream wrapped by the
// OStream; does nothing when the OStream has no underlying stream.
static OStream& operator<<(OStream &O, const SwitchStrengthReduce::CaseVector &C) {
  std::ostream *Out = O.stream();
  if (Out)
    *Out << C;
  return O;
}

// createSwitchStrengthReducePass - Interface to this file: returns a newly
// allocated SwitchStrengthReduce pass instance.
FunctionPass *llvm::createSwitchStrengthReducePass() {
  FunctionPass *NewPass = new SwitchStrengthReduce();
  return NewPass;
}

// card - Compute the cardinality of the "windowed" set, i.e. the number of
// distinct values of ((case value & W) >> lowbit) over all cases in C.
//
//   C      - "simple" cases (Low == High for every element).
//   W      - window bit mask; an empty window (W == 0) yields 0.
//   L      - number of bits in the window; the scratch table has 2^L slots.
//   lowbit - position of the lowest bit of the window.
uint64_t SwitchStrengthReduce::card(CaseVector& C, uint64_t W, 
                                    unsigned L, unsigned lowbit) 
{
  if (W == 0)
    return 0;

  // Use a 64-bit shift so the table size cannot silently wrap, and let the
  // vector own the scratch table so it is released on every exit path.
  assert(L < 64 && "Window length exceeds 64 bits!");
  const uint64_t tblsize = uint64_t(1) << L;
  std::vector<char> seen(static_cast<std::vector<char>::size_type>(tblsize),
                         char(0));

  uint64_t cnt = 0;
  for (CaseItr I = C.begin(), E = C.end(); I != E; ++I) {
    uint64_t Low = cast<ConstantInt>(I->Low)->getSExtValue();
    uint64_t High = cast<ConstantInt>(I->High)->getSExtValue();
    assert((Low == High) && "Should be simple CaseVector!");

    uint64_t index = ((Low & W) >> lowbit);
    assert(index < tblsize && "Window does not fit into the table!");

    // Count each table slot only the first time it is hit.
    if (seen[index] == 0) {
      seen[index] = 1;
      ++cnt;
    }
  }

  return cnt;
}

// isCritical - Checks whether the given window is critical, i.e. whether the
// cases occupy more than half of the 2^L slots addressed by the window.
bool SwitchStrengthReduce::isCritical(CaseVector& C, uint64_t W, 
                                      unsigned L, unsigned lowbit)
{
  const uint64_t Capacity = (1u << L);
  const uint64_t Occupied = card(C, W, L, lowbit);

  return Capacity < Occupied * 2;
}

// MRSTFindWindow - Find the maximum (critical) window.
//
// Walks the bits from position no_b down to bit 0. Each step extends the
// current window by the bit under inspection; if the extended window is not
// critical, its topmost bit is dropped again. The best window seen so far is
// remembered and returned as a bit mask; its length and lowest bit position
// are stored into 'len' and 'lowbit'.
uint64_t SwitchStrengthReduce::MRSTFindWindow(CaseVector& C, unsigned no_b,
                                              unsigned& len, unsigned& lowbit)
{
  uint64_t CurWindow = 0, BestWindow = 0;
  uint64_t BestCard = 0;
  uint64_t TopBit = 0;            // highest bit of the current window
  unsigned CurLen = 0;
  unsigned BestLen = 0, BestLow = 0;

  uint64_t Bit = (0x1ull << no_b);
  while (Bit > 0) {
    // A fresh window starts at the bit under inspection.
    if (CurWindow == 0)
      TopBit = Bit;
    CurWindow |= Bit;
    ++CurLen;

    if (isCritical(C, CurWindow, CurLen, no_b)) {
      // Critical windows always become the new best candidate.
      BestWindow = CurWindow;
      BestLen = CurLen;
      BestLow = no_b;
      BestCard = card(C, BestWindow, BestLen, BestLow);
    } else {
      // Slide: drop the topmost bit and keep the shortened window only if
      // it distinguishes more cases than the best one so far.
      assert(CurWindow & TopBit); 
      CurWindow ^= TopBit;
      TopBit >>= 1;
      --CurLen;

      const uint64_t CurCard = card(C, CurWindow, CurLen, no_b);
      if (CurCard > BestCard) {
        BestWindow = CurWindow;
        BestLen = CurLen;
        BestLow = no_b;
        BestCard = CurCard;
      }
    }

    --no_b;
    Bit >>= 1;
  }

  len = BestLen;
  lowbit = BestLow;
  return BestWindow;
}

// isEntryUsed - Returns true if at least one case in C selects table slot
// 'entry', i.e. its value shifted right by 'lowbit' and masked with the
// shifted window equals 'entry'.
bool SwitchStrengthReduce::isEntryUsed(CaseVector& C, uint64_t entry,
                                       uint64_t Wmax, unsigned lowbit)
{
  const uint64_t mask = Wmax >> lowbit;

  bool Found = false;
  for (CaseVector::size_type Idx = 0, N = C.size(); Idx != N && !Found; ++Idx) {
    const Case &Cur = C[Idx];
    uint64_t Low = cast<ConstantInt>(Cur.Low)->getSExtValue();
    uint64_t High = cast<ConstantInt>(Cur.High)->getSExtValue();
    assert((Low == High) && "Should be simple CaseVector!");

    const uint64_t Slot = (Low >> lowbit) & mask;
    Found = (Slot == entry);
  }

  return Found;
}

// getEntries - Appends to Cout every case of C that selects table slot
// 'entry' for the window Wmax whose lowest bit is 'lowbit'.
void SwitchStrengthReduce::getEntries(CaseVector& Cout, CaseVector& C,
                                      uint64_t entry, uint64_t Wmax,
                                      unsigned lowbit)
{
  const uint64_t mask = Wmax >> lowbit;

  for (CaseVector::size_type Idx = 0, N = C.size(); Idx != N; ++Idx) {
    const Case &Cur = C[Idx];
    uint64_t Low = cast<ConstantInt>(Cur.Low)->getSExtValue();
    uint64_t High = cast<ConstantInt>(Cur.High)->getSExtValue();
    assert((Low == High) && "Should be simple CaseVector!");

    // The slot is derived from the zero-extended case value.
    const uint64_t Slot =
      ((cast<ConstantInt>(Cur.Low)->getZExtValue()) >> lowbit) & mask;
    if (Slot != entry)
      continue;

    Cout.push_back(Cur);
  }
}

// Case groups with at most this many cases are lowered to a chain of
// equality compares instead of another table level.
#define MRST_CTREE 4

// MRSTTransformSwitch - Recursively build the MRST for the cases in C and
// return the entry block of the generated code.
//
//   C         - "simple" cases (Low == High) still to be dispatched.
//   steps     - depth of the recursion so far; used for statistics only.
//   Val       - the value being switched on.
//   OrigBlock - the block that contained the original switch; new blocks are
//               inserted right after it.
//   Default   - destination for values that match no case.
//
// Small groups become a compare chain (newBunchBlock). Larger groups get a
// node block that extracts the window found by MRSTFindWindow with and/lshr
// and switches on it, with one table entry per window value.
BasicBlock*
SwitchStrengthReduce::MRSTTransformSwitch(CaseVector& C, unsigned steps,
                                          Value* Val, BasicBlock* OrigBlock,
                                          BasicBlock* Default)
{

  // Step 1: Handle switches with <= MRST_CTREE cases with a comparison tree.
  DOUT << "Entering MRSTTransformSwitch\n";
  DOUT << "Cases: " << C << "\n";
  if (C.size() <= MRST_CTREE) {
    BasicBlock* NewNode;
    
    NewNode = newBunchBlock(C, Val, OrigBlock, Default);
    totalSteps += steps + 1 + ((1+C.size())*C.size())/2;
    totalCases += C.size();
    DOUT << "Exit\n"; 
    return NewNode;
  }

  // Step 2: Find the maximum window
  unsigned len, shiftbits;
  unsigned no_b = Val->getType()->getPrimitiveSizeInBits() - 1;
  uint64_t Wmax = MRSTFindWindow(C, no_b, len, shiftbits);
  // Compute the table size with a 64-bit shift so it cannot wrap.
  assert(len < 64 && "Window length exceeds 64 bits!");
  uint64_t tblsize = (uint64_t(1) << len);

  DOUT << "Width: " << no_b << "\n";
  DOUT << "Wmax: " << Wmax << "\n";
  DOUT << "bits " << len+shiftbits-1 << " to " << shiftbits << "\n";
  
  // Step 3: Generate the table
  totalSpace += tblsize;
#ifndef NDEBUG
  // The dump scans all cases once per table entry, so keep it out of
  // builds where debug output is compiled out anyway.
  DOUT << "Entries: [ ";
  for (uint64_t j = 0; j<tblsize; ++j ) {
    if (isEntryUsed(C, j, Wmax, shiftbits))
      DOUT << j << ", ";
    else
      DOUT << "dflt, ";
  }
  DOUT << "]\n";
#endif
  
  Function* F = OrigBlock->getParent();
  BasicBlock* NewNode = new BasicBlock("NodeBlock");
  F->getBasicBlockList().insert(OrigBlock->getNext(), NewNode);

  // Extract the window: (Val & Wmax) >> shiftbits.
  BinaryOperator* AndOp =
    BinaryOperator::create(Instruction::And,
                           Val,
                           ConstantInt::get(Val->getType(), Wmax));

  BinaryOperator* ShiftOp = NULL;
  if (shiftbits) // Don't create redundant shift
    ShiftOp= BinaryOperator::create(Instruction::LShr,
                                    AndOp,
                                    ConstantInt::get(Val->getType(), shiftbits));
  
  SwitchInst* SI = new SwitchInst((ShiftOp ? ShiftOp : AndOp), Default, tblsize);
  NewNode->getInstList().push_back(AndOp);
  if (ShiftOp)
    NewNode->getInstList().push_back(ShiftOp);
  NewNode->getInstList().push_back(SI);

  // Step 4: Recurse. Every table entry gets an explicit case; unused
  // entries go straight to the default destination.
  for (uint64_t j = 0; j < tblsize; ++j) {
    BasicBlock* NodeBlock = Default;
    if (isEntryUsed(C, j, Wmax, shiftbits)) {
      CaseVector Cnew;
      getEntries(Cnew, C, j, Wmax, shiftbits);
      NodeBlock = MRSTTransformSwitch(Cnew, steps+1, Val, OrigBlock, Default);
    }
    
    SI->addCase(ConstantInt::get(Val->getType(), j), NodeBlock);
  }

  return NewNode;
}

// runOnFunction - Process every switch that terminates a block of F.
// Returns true only if at least one switch was actually rewritten.
bool SwitchStrengthReduce::runOnFunction(Function &F) {
  TD = &getAnalysis<TargetData>();
  UIntPtrTy = TD->getIntPtrType();

  bool Changed = false;

  for (Function::iterator I = F.begin(), E = F.end(); I != E; ) {
    BasicBlock *Cur = I++; // Advance over block so we don't traverse new blocks

    if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) {
      processSwitchInst(SI);

      // Every rewrite performed by processSwitchInst replaces the switch
      // terminating Cur with an unconditional branch. If the terminator is
      // still a switch, nothing was modified. SI itself may be deleted by
      // now, so only the block's current terminator is inspected.
      if (!dyn_cast<SwitchInst>(Cur->getTerminator()))
        Changed = true;
    }
  }

  return Changed;
}

// newBunchBlock - Emit a chain of "LeafBlock"s, one per case in C. Each leaf
// compares Val for equality with the case value and branches to the case
// destination on a match, otherwise to the next leaf; the last leaf falls
// through to Default. Returns the first leaf of the chain.
BasicBlock* SwitchStrengthReduce::newBunchBlock(CaseVector& C, Value* Val,
                                                BasicBlock* OrigBlock,
                                                BasicBlock* Default)
{
  Function* Parent = OrigBlock->getParent();
  BasicBlock* const Head = new BasicBlock("LeafBlock");

  BasicBlock* Leaf = Head;
  const CaseItr End = C.end();
  for (CaseItr Cur = C.begin(); Cur != End; ++Cur) {
    // The last leaf falls through to the default destination.
    const bool IsLast = (Cur + 1 == End);
    BasicBlock* Fallthrough = IsLast ? Default : new BasicBlock("LeafBlock");

    // Each leaf is placed directly after the original block.
    Parent->getBasicBlockList().insert(OrigBlock->getNext(), Leaf);

    // Make the seteq instruction...
    ICmpInst* IsEqual = new ICmpInst(ICmpInst::ICMP_EQ, Val,
                                     Cur->Low, "SwitchLeaf");
    Leaf->getInstList().push_back(IsEqual);

    // Make the conditional branch...
    BasicBlock* Target = Cur->BB;
    new BranchInst(Target, Fallthrough, IsEqual, Leaf);

    // If there were any PHI nodes in this successor, rewrite one entry
    // from OrigBlock to come from the current leaf.
    for (BasicBlock::iterator Inst = Target->begin(); isa<PHINode>(Inst);
         ++Inst) {
      PHINode* Phi = cast<PHINode>(Inst);
      int BlockIdx = Phi->getBasicBlockIndex(OrigBlock);
      assert(BlockIdx != -1 && "Switch didn't go to this successor??");
      Phi->setIncomingBlock((unsigned)BlockIdx, Leaf);
    }

    Leaf = Fallthrough;
  }

  return Head;
}

// Clusterify - Fill Cases with the cases of SI, sorted by value, with runs of
// consecutive values that branch to the same block merged into one range.
// Any previous contents of Cases are discarded.
//
// Returns the number of compares needed to test all clusters: one for a
// single value, two for a range.
unsigned SwitchStrengthReduce::Clusterify(CaseVector& Cases, SwitchInst *SI) 
{
  // Start with "simple" cases. Successor 0 is skipped, as only successors
  // 1..N-1 carry a case value here.
  Cases.clear();
  const unsigned NumSucc = SI->getNumSuccessors();
  for (unsigned i = 1; i < NumSucc; ++i)
    Cases.push_back(Case(SI->getSuccessorValue(i),
                         SI->getSuccessorValue(i),
                         SI->getSuccessor(i)));
  std::sort(Cases.begin(), Cases.end(), CaseCmp());

  // Merge cases into clusters, in place: Last is the cluster currently being
  // grown, Next the case being examined.
  if (!Cases.empty()) {
    CaseVector::size_type Last = 0;
    for (CaseVector::size_type Next = 1, E = Cases.size(); Next != E; ++Next) {
      int64_t currentValue =
        cast<ConstantInt>(Cases[Last].High)->getSExtValue();
      int64_t nextValue = cast<ConstantInt>(Cases[Next].Low)->getSExtValue();

      // Subtract as unsigned: the distance between two far apart case
      // values does not fit into int64_t.
      const bool Adjacent = ((uint64_t)nextValue - (uint64_t)currentValue) == 1;

      if (Adjacent && Cases[Last].BB == Cases[Next].BB) {
        Cases[Last].High = Cases[Next].High;
      } else {
        ++Last;
        if (Last != Next)
          Cases[Last] = Cases[Next];
      }
    }
    Cases.erase(Cases.begin() + (Last + 1), Cases.end());
  }

  unsigned numCmps = 0;
  for (CaseItr I = Cases.begin(), E = Cases.end(); I != E; ++I) {
    // A range counts double, since it requires two compares.
    numCmps += (I->Low != I->High) ? 2 : 1;
  }

  return numCmps;
}

// AddMRST - Replace SI with a multiway radix search tree if the switch is
// sparse enough (density below 0.3125) and has more than 4 cases; otherwise
// leave the switch untouched.
void SwitchStrengthReduce::AddMRST(SwitchInst *SI) 
{
  BasicBlock *CurBlock = SI->getParent();
  Function *F = CurBlock->getParent();
  Value *Val = SI->getOperand(0);  // The value we are switching on...
  BasicBlock* Default = SI->getDefaultDest();

  // Prepare cases vector. We need "simple" cases for MRST
  CaseVector Cases;
  for (unsigned i = 1; i < SI->getNumSuccessors(); ++i)
    Cases.push_back(Case(SI->getSuccessorValue(i),
                         SI->getSuccessorValue(i),
                         SI->getSuccessor(i)));
  std::sort(Cases.begin(), Cases.end(), CaseCmp());

  // Nothing to do for a switch without explicit cases.
  if (Cases.empty())
    return;
 
  // Should we do MRST?
  int64_t First = cast<ConstantInt>(Cases.front().Low)->getSExtValue();
  int64_t Last  = cast<ConstantInt>(Cases.back().High)->getSExtValue();
  // The distance between the smallest and the largest case may not fit into
  // int64_t, so subtract as unsigned and add the final 1 in floating point.
  const uint64_t Span = (uint64_t)Last - (uint64_t)First;
  double Density = (double)Cases.size() / ((double)Span + 1.0);
  DOUT << "First: " << First << " Last: " << Last << "\n";
  DOUT << "Total cases: " << Cases.size() << "\n";
  DOUT << "Switch density: " << Density << "\n";

  if (Density<0.3125 && Cases.size()>4) {
    // It's worth to apply MRST algorithm: we won't pessimize anything
    totalCases = 0;
    totalSteps = 0;
    totalSpace = SI->getNumOperands();

    // Debug output only: an optimization pass must not write to stderr
    // unconditionally.
    DOUT << "***\n***\nTriggered for" << *SI << "\n***\n***\n";
    
    // Create a new, empty default block so that the new hierarchy of
    // if-then statements go to this and the PHI nodes are happy.
    BasicBlock* NewDefault = new BasicBlock("NewDefault");
    F->getBasicBlockList().insert(Default, NewDefault);

    new BranchInst(Default, NewDefault);
    
    // If there is an entry in any PHI nodes for the default edge, make sure
    // to update them as well.
    for (BasicBlock::iterator I = Default->begin(); isa<PHINode>(I); ++I) {
      PHINode *PN = cast<PHINode>(I);
      int BlockIdx = PN->getBasicBlockIndex(CurBlock);
      assert(BlockIdx != -1 && "Switch didn't go to this successor??");
      PN->setIncomingBlock((unsigned)BlockIdx, NewDefault);
    }

    BasicBlock* SwitchBlock = MRSTTransformSwitch(Cases, 0, Val,
                                                  CurBlock, NewDefault);
    // Branch to our shiny new switch stuff...
    new BranchInst(SwitchBlock, CurBlock);

    // We are now done with the switch instruction, delete it.
    CurBlock->getInstList().erase(SI);

    DOUT << " MRST Results:" << "\n";
    DOUT << " -- Cases: " << totalCases << "\n";
    DOUT << " -- Jump tables space: " << totalSpace << "\n";
    DOUT << " -- Total branches: " << totalSteps << "\n";
    DOUT << " -- Avg. branches: " << ((double)totalSteps)/totalCases << "\n";
  }  
}

// EmitBitTests - Try to lower SI into a range check followed by a chain of
// bit tests: "(1 << (Val - lowBound)) & Mask != 0", one test per distinct
// destination.
//
// The lowering is applied only when there are at most MAX_CASE_BIT_TESTS
// distinct destinations, all case values fit into one pointer-sized mask
// and enough compares are saved. Returns true if SI was replaced (and
// deleted), false if the function was left untouched.
bool SwitchStrengthReduce::EmitBitTests(SwitchInst *SI) 
{
  BasicBlock *CurBlock = SI->getParent();
  Function *F = CurBlock->getParent();
  Value *Val = SI->getOperand(0);  // The value we are switching on...
  BasicBlock* Default = SI->getDefaultDest();
  unsigned IntPtrBits = UIntPtrTy->getPrimitiveSizeInBits();
  
  // Prepare cases vector. We need "clustered" cases for the bit tests
  CaseVector Cases;
  unsigned numCmps = Clusterify(Cases, SI);

  // Without explicit cases there is nothing to test (and no front()/back()).
  if (Cases.empty())
    return false;

  DOUT << "Clusterify finished. Total clusters: " << Cases.size()
       << ". Total compares: " << numCmps << "\n";
  
  // Count unique destinations
  SmallSet<BasicBlock*, MAX_CASE_BIT_TESTS+1> Dests;
  for (CaseItr I = Cases.begin(), E = Cases.end(); I!=E; ++I) {
    Dests.insert(I->BB);
    if (Dests.size()>MAX_CASE_BIT_TESTS)
      // Don't bother the code below, if there are too much unique destinations
      return false;
  }
  DOUT << "Total number of unique destinations: " << Dests.size() << "\n";
    
  // Compute span of values.
  Constant* minValue = Cases.front().Low;
  Constant* maxValue = Cases.back().High;
  const int64_t minInt = cast<ConstantInt>(minValue)->getSExtValue();
  const int64_t maxInt = cast<ConstantInt>(maxValue)->getSExtValue();
  // Subtract as unsigned: the span may not fit into int64_t.
  uint64_t range = (uint64_t)maxInt - (uint64_t)minInt;
  DOUT << "Compare range: " << range << "\n"
       << "Low bound: " << minInt << "\n"
       << "High bound: " << maxInt << "\n";

  // Values lowBound .. lowBound+range need range+1 mask bits, so the range
  // itself has to be strictly smaller than the mask width.
  if (range >= IntPtrBits)
    return false;

  const bool Profitable = (Dests.size() == 1 && numCmps >= 3)
                       || (Dests.size() == 2 && numCmps >= 5)
                       || (Dests.size() >= 3 && numCmps >= 6);
  if (!Profitable)
    return false;

  // Debug output only: an optimization pass must not write to stderr
  // unconditionally.
  DOUT << "***\n***\nTriggered for" << *SI << "\n***\n***\n";
  
  // Create a new, empty default block so that the new hierarchy of
  // if-then statements go to this and the PHI nodes are happy.
  BasicBlock* NewDefault = new BasicBlock("NewDefault");
  F->getBasicBlockList().insert(Default, NewDefault);

  new BranchInst(Default, NewDefault);
  
  // If there is an entry in any PHI nodes for the default edge, make sure
  // to update them as well.
  for (BasicBlock::iterator I = Default->begin(); isa<PHINode>(I); ++I) {
    PHINode *PN = cast<PHINode>(I);
    int BlockIdx = PN->getBasicBlockIndex(CurBlock);
    assert(BlockIdx != -1 && "Switch didn't go to this successor??");
    PN->setIncomingBlock((unsigned)BlockIdx, NewDefault);
  }

  // Optimize the case where all the case values fit in a
  // word without having to subtract minValue. In this case,
  // we can optimize away the subtraction. The largest value is used as a
  // bit index, so it has to be strictly smaller than the mask width.
  if (minInt >= 0 && maxInt < (int64_t)IntPtrBits) {
    range = maxInt;
    minValue = NULL;
  }

  CaseBitsVector CasesBits;
  unsigned i, count = 0;

  const int64_t lowBound = minValue ? minInt : 0;

  // Build one mask per destination.
  for (CaseItr I = Cases.begin(), E = Cases.end(); I!=E; ++I) {
    BasicBlock* Dest = I->BB;
    for (i = 0; i < count; ++i)
      if (Dest == CasesBits[i].BB)
        break;
    
    if (i == count) {
      assert((count < MAX_CASE_BIT_TESTS) && "Too much destinations to test!");
      CasesBits.push_back(CaseBits(0, Dest, 0));
      count++;
    }

    uint64_t lo = (uint64_t)cast<ConstantInt>(I->Low)->getSExtValue()
                  - (uint64_t)lowBound;
    uint64_t hi = (uint64_t)cast<ConstantInt>(I->High)->getSExtValue()
                  - (uint64_t)lowBound;

    for (uint64_t j = lo; j <= hi; j++) {
      // The mask is 64 bits wide; shift a 64-bit one, not an int.
      CasesBits[i].Mask |= uint64_t(1) << j;
      CasesBits[i].Bits++;
    }
  }
  // Test the destination with the most case values first.
  std::sort(CasesBits.begin(), CasesBits.end(), CaseBitsCmp());

  BasicBlock* SwitchBlock = CurBlock->splitBasicBlock(SI, "SwitchBlock");
  BasicBlock* RangeCheckBlock = new BasicBlock("RangeCheckBlock");
  F->getBasicBlockList().insert(CurBlock->getNext(), RangeCheckBlock);

  Value *NewVal = Val;
  // Subtract the minimum value
  if (minValue) {
    Constant *NegLo = ConstantExpr::getNeg(minValue);
    NewVal = BinaryOperator::createAdd(Val, NegLo,
                                       Val->getName()+".off",
                                       RangeCheckBlock);
  }
  // Check range
  Instruction* Comp = new ICmpInst(ICmpInst::ICMP_ULE, NewVal,
                                   ConstantInt::get(Val->getType(), range),
                                   "", RangeCheckBlock);

  BasicBlock* CurrentLeaf = new BasicBlock("CaseBitTestLeafBlock");

  new BranchInst(CurrentLeaf, NewDefault, Comp, RangeCheckBlock);

  // If Val's Type differs from UIntPtrTy - insert cast.
  // Note, that everything should be >0 here, so we're inserting zext
  if (NewVal->getType() != UIntPtrTy)
    NewVal = CastInst::createIntegerCast(NewVal, UIntPtrTy, false,
                                         "", CurrentLeaf);
  
  // Make desired shift
  Instruction* ShiftOp =
    BinaryOperator::createShl(ConstantInt::get(UIntPtrTy, 1),
                              NewVal,
                              "",
                              CurrentLeaf);
  
  for (CaseBitsVector::iterator I = CasesBits.begin(), J = I,
         E = CasesBits.end(); I != E; ++I) {
    BasicBlock* NextLeaf;
    if (++J != E)
      NextLeaf = new BasicBlock("CaseBitTestLeafBlock");
    else
      NextLeaf = NewDefault;
    F->getBasicBlockList().insert(RangeCheckBlock->getNext(), CurrentLeaf);

    DOUT << "Mask: " << I->Mask << " bits: " << I->Bits << "\n";
    
    // Mask it!
    BinaryOperator* AndOp =
      BinaryOperator::createAnd(ShiftOp,
                                ConstantInt::get(UIntPtrTy, I->Mask),
                                Val->getName()+".masked",
                                CurrentLeaf);

    ICmpInst* LeafComp = new ICmpInst(ICmpInst::ICMP_NE, AndOp,
                                      ConstantInt::get(UIntPtrTy, 0),
                                      "SwitchLeaf",
                                      CurrentLeaf);
    // Jump if not zero
    new BranchInst(I->BB, NextLeaf, LeafComp, CurrentLeaf);

    // If there were any PHI nodes in this successor, merge & fix entries.
    for (BasicBlock::iterator P = I->BB->begin(); isa<PHINode>(P); ++P) {
      PHINode* PN = cast<PHINode>(P);
      // Remove all but one incoming entries from cluster
      for (unsigned k=0; k<I->Bits-1; ++k) {
        PN->removeIncomingValue(SwitchBlock);
      }

      int BlockIdx = PN->getBasicBlockIndex(SwitchBlock);
      assert(BlockIdx != -1 && "Switch didn't go to this successor??");
      PN->setIncomingBlock((unsigned)BlockIdx, CurrentLeaf);
    }

    CurrentLeaf = NextLeaf;
  }
  
  // Replace the branch created by splitBasicBlock with one to the range
  // check.
  CurBlock->getInstList().pop_back();
  new BranchInst(RangeCheckBlock, CurBlock);

  // We are now done with the switch instruction, delete it.
  SwitchBlock->getInstList().erase(SI);
  SwitchBlock->replaceAllUsesWith(RangeCheckBlock);
  SwitchBlock->eraseFromParent();

  return true;
}


// processSwitchInst - Lower one switch instruction. A switch that has only
// the default destination becomes an unconditional branch; otherwise the
// bit-test lowering is tried first and the MRST lowering is the fallback.
void SwitchStrengthReduce::processSwitchInst(SwitchInst *SI) {
  BasicBlock *Parent = SI->getParent();

  // Two operands means: only the default destination is present.
  const bool OnlyDefault = (SI->getNumOperands() == 2);
  if (OnlyDefault) {
    new BranchInst(SI->getDefaultDest(), Parent);
    Parent->getInstList().erase(SI);
    return;
  }

  if (!EmitBitTests(SI))
    AddMRST(SI);
}
_______________________________________________
llvm-commits mailing list
llvm-commits@cs.uiuc.edu
http://lists.cs.uiuc.edu/mailman/listinfo/llvm-commits

Reply via email to