sammccall created this revision.
sammccall added a reviewer: hokein.
Herald added a subscriber: mgrang.
Herald added a project: All.
sammccall requested review of this revision.
Herald added subscribers: cfe-commits, alextsao1999.
Herald added a project: clang-tools-extra.

For shift and goto, use a hashtable for faster lookups. The hashtable encoding
is ~3x bigger than before, but shift and goto are each only ~15% of the actions.

For reduce, the common pattern is that a (state, reduce rule) pair applies to
lots of possible lookahead tokens. So store this as one object with a bitmap for
the valid lookahead tokens. This is very efficient (~4x smaller than before).

Overall the table is <20% bigger, which seems acceptable.
Before: size of the table (bytes): 401554
After: size of the table (bytes): 470636 (Shift=196608 Reduce=77284 Goto=196608)

This yields a 24% speedup of glrParse on my machine (3.5 -> 4.35 MB/s).


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D128318

Files:
  clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
  clang-tools-extra/pseudo/lib/GLR.cpp
  clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
  clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp

Index: clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp
===================================================================
--- clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp
+++ clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp
@@ -9,7 +9,6 @@
 #include "clang-pseudo/grammar/Grammar.h"
 #include "clang-pseudo/grammar/LRGraph.h"
 #include "clang-pseudo/grammar/LRTable.h"
-#include "clang/Basic/TokenKinds.h"
 #include <cstdint>
 
 namespace llvm {
@@ -45,46 +44,47 @@
 
   bool insert(Entry E) { return Entries.insert(std::move(E)).second; }
   LRTable build(const GrammarTable &GT, unsigned NumStates) && {
-    // E.g. given the following parsing table with 3 states and 3 terminals:
-    //
-    //            a    b     c
-    // +-------+----+-------+-+
-    // |state0 |    | s0,r0 | |
-    // |state1 | acc|       | |
-    // |state2 |    |  r1   | |
-    // +-------+----+-------+-+
-    //
-    // The final LRTable:
-    //  - StateOffset: [s0] = 0, [s1] = 2, [s2] = 3, [sentinel] = 4
-    //  - Symbols:     [ b,   b,   a,  b]
-    //    Actions:     [ s0, r0, acc, r1]
-    //                   ~~~~~~ range for state 0
-    //                           ~~~~ range for state 1
-    //                                ~~ range for state 2
-    // First step, we sort all entries by (State, Symbol, Action).
-    std::vector<Entry> Sorted(Entries.begin(), Entries.end());
-    llvm::sort(Sorted, [](const Entry &L, const Entry &R) {
-      return std::forward_as_tuple(L.State, L.Symbol, L.Act.opaque()) <
-             std::forward_as_tuple(R.State, R.Symbol, R.Act.opaque());
-    });
-
+    llvm::DenseMap<std::pair<StateID, RuleID>, Reduce> Reduces;
     LRTable Table;
-    Table.Actions.reserve(Sorted.size());
-    Table.Symbols.reserve(Sorted.size());
-    // We are good to finalize the States and Actions.
-    for (const auto &E : Sorted) {
-      Table.Actions.push_back(E.Act);
-      Table.Symbols.push_back(E.Symbol);
+    for (const auto &E : Entries) {
+      switch (E.Act.kind()) {
+      case Action::Sentinel:
+        break;
+      case Action::Shift:
+        ++Table.NumShift;
+        Table.Shift.try_emplace(std::make_pair(E.Symbol, E.State),
+                                E.Act.getShiftState());
+        break;
+      case Action::GoTo:
+        ++Table.NumGoto;
+        Table.Goto.try_emplace(std::make_pair(E.Symbol, E.State),
+                               E.Act.getGoToState());
+        break;
+      case Action::Reduce:
+        ++Table.NumReduce;
+        auto &R = Reduces[{E.State, E.Act.getReduceRule()}];
+        R.State = E.State;
+        R.Rule = E.Act.getReduceRule();
+        R.Filter.set(symbolToToken(E.Symbol));
+        break;
+      }
     }
-    // Initialize the terminal and nonterminal offset, all ranges are empty by
-    // default.
-    Table.StateOffset = std::vector<uint32_t>(NumStates + 1, 0);
-    size_t SortedIndex = 0;
-    for (StateID State = 0; State < Table.StateOffset.size(); ++State) {
-      Table.StateOffset[State] = SortedIndex;
-      while (SortedIndex < Sorted.size() && Sorted[SortedIndex].State == State)
-        ++SortedIndex;
+
+    Table.Reduces.reserve(Reduces.size());
+    for (const auto &R : Reduces)
+      Table.Reduces.push_back(R.second);
+    llvm::sort(Table.Reduces, [](auto &L, auto &R) {
+      return std::tie(L.State, L.Rule) < std::tie(R.State, R.Rule);
+    });
+    Table.ReducesLookup.resize(NumStates + 1);
+    unsigned ReducesIndex = 0;
+    for (StateID S = 0; S < NumStates; ++S) {
+      Table.ReducesLookup[S] = ReducesIndex;
+      while (ReducesIndex < Reduces.size() &&
+             Table.Reduces[ReducesIndex].State == S)
+        ReducesIndex++;
     }
+    Table.ReducesLookup[NumStates] = ReducesIndex;
     Table.StartStates = std::move(StartStates);
     return Table;
   }
Index: clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
===================================================================
--- clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
+++ clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
@@ -8,8 +8,8 @@
 
 #include "clang-pseudo/grammar/LRTable.h"
 #include "clang-pseudo/grammar/Grammar.h"
-#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Capacity.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
@@ -35,10 +35,14 @@
   return llvm::formatv(R"(
 Statistics of the LR parsing table:
     number of states: {0}
-    number of actions: {1}
-    size of the table (bytes): {2}
+    number of actions: shift={1} reduce={2} goto={3}
+    size of the table (bytes): {4} (Shift={5} Reduce={6} Goto={7})
 )",
-                       StateOffset.size() - 1, Actions.size(), bytes())
+                       ReducesLookup.size() - 1, NumShift, NumReduce, NumGoto,
+                       bytes(), capacity_in_bytes(Shift),
+                       llvm::capacity_in_bytes(Reduces) +
+                           llvm::capacity_in_bytes(ReducesLookup),
+                       capacity_in_bytes(Goto))
       .str();
 }
 
@@ -46,27 +50,27 @@
   std::string Result;
   llvm::raw_string_ostream OS(Result);
   OS << "LRTable:\n";
-  for (StateID S = 0; S < StateOffset.size() - 1; ++S) {
+  for (StateID S = 0; S < ReducesLookup.size() - 1; ++S) {
     OS << llvm::formatv("State {0}\n", S);
     for (uint16_t Terminal = 0; Terminal < NumTerminals; ++Terminal) {
       SymbolID TokID = tokenSymbol(static_cast<tok::TokenKind>(Terminal));
-      for (auto A : find(S, TokID)) {
-        if (A.kind() == LRTable::Action::Shift)
-          OS.indent(4) << llvm::formatv("'{0}': shift state {1}\n",
-                                        G.symbolName(TokID), A.getShiftState());
-        else if (A.kind() == LRTable::Action::Reduce)
+      if (auto NewState = getShiftState(S, TokID)) {
+        OS.indent(4) << llvm::formatv("'{0}': shift state {1}\n",
+                                      G.symbolName(TokID), *NewState);
+      }
+      for (auto &R : getReduces(S)) {
+        if (R.Filter.test(TokID))
           OS.indent(4) << llvm::formatv("'{0}': reduce by rule {1} '{2}'\n",
-                                        G.symbolName(TokID), A.getReduceRule(),
-                                        G.dumpRule(A.getReduceRule()));
+                                        G.symbolName(TokID), R.Rule,
+                                        G.dumpRule(R.Rule));
       }
     }
     for (SymbolID NontermID = 0; NontermID < G.table().Nonterminals.size();
          ++NontermID) {
-      if (find(S, NontermID).empty())
-        continue;
-      OS.indent(4) << llvm::formatv("'{0}': go to state {1}\n",
-                                    G.symbolName(NontermID),
-                                    getGoToState(S, NontermID));
+      auto It = Goto.find({NontermID, S});
+      if (It != Goto.end())
+        OS.indent(4) << llvm::formatv("'{0}': go to state {1}\n",
+                                      G.symbolName(NontermID), It->second);
     }
   }
   return OS.str();
@@ -75,44 +79,18 @@
 llvm::Optional<LRTable::StateID>
 LRTable::getShiftState(StateID State, SymbolID Terminal) const {
   assert(pseudo::isToken(Terminal) && "expected terminal symbol!");
-  for (const auto &Result : find(State, Terminal))
-    if (Result.kind() == Action::Shift)
-      return Result.getShiftState();
-  return llvm::None;
-}
-
-llvm::ArrayRef<LRTable::Action> LRTable::getActions(StateID State,
-                                                    SymbolID Terminal) const {
-  assert(pseudo::isToken(Terminal) && "expect terminal symbol!");
-  return find(State, Terminal);
+  auto It = Shift.find({Terminal, State});
+  if (It == Shift.end())
+    return llvm::None;
+  return It->second;
 }
 
 LRTable::StateID LRTable::getGoToState(StateID State,
                                        SymbolID Nonterminal) const {
   assert(pseudo::isNonterminal(Nonterminal) && "expected nonterminal symbol!");
-  auto Result = find(State, Nonterminal);
-  assert(Result.size() == 1 && Result.front().kind() == Action::GoTo);
-  return Result.front().getGoToState();
-}
-
-llvm::ArrayRef<LRTable::Action> LRTable::find(StateID Src, SymbolID ID) const {
-  assert(Src + 1u < StateOffset.size());
-  std::pair<size_t, size_t> Range =
-      std::make_pair(StateOffset[Src], StateOffset[Src + 1]);
-  auto SymbolRange = llvm::makeArrayRef(Symbols.data() + Range.first,
-                                        Symbols.data() + Range.second);
-
-  assert(llvm::is_sorted(SymbolRange) &&
-         "subrange of the Symbols should be sorted!");
-  const LRTable::StateID *Start =
-      llvm::partition_point(SymbolRange, [&ID](SymbolID S) { return S < ID; });
-  if (Start == SymbolRange.end())
-    return {};
-  const LRTable::StateID *End = Start;
-  while (End != SymbolRange.end() && *End == ID)
-    ++End;
-  return llvm::makeArrayRef(&Actions[Start - Symbols.data()],
-                            /*length=*/End - Start);
+  auto It = Goto.find({Nonterminal, State});
+  assert(It != Goto.end());
+  return It->second;
 }
 
 LRTable::StateID LRTable::getStartState(SymbolID Target) const {
Index: clang-tools-extra/pseudo/lib/GLR.cpp
===================================================================
--- clang-tools-extra/pseudo/lib/GLR.cpp
+++ clang-tools-extra/pseudo/lib/GLR.cpp
@@ -283,10 +283,10 @@
       if (popAndPushTrivial())
         continue;
       for (const auto &A :
-           Params.Table.getActions((*Heads)[PoppedHeads]->State, Lookahead)) {
-        if (A.kind() != LRTable::Action::Reduce)
+           Params.Table.getReduces((*Heads)[PoppedHeads]->State)) {
+        if (!A.Filter.test(symbolToToken(Lookahead)))
           continue;
-        pop((*Heads)[PoppedHeads], A.getReduceRule());
+        pop((*Heads)[PoppedHeads], A.Rule);
       }
     }
   }
@@ -364,12 +364,12 @@
       return false;
     const GSS::Node *Head = Heads->back();
     llvm::Optional<RuleID> RID;
-    for (auto &A : Params.Table.getActions(Head->State, Lookahead)) {
-      if (A.kind() != LRTable::Action::Reduce)
+    for (auto &A : Params.Table.getReduces(Head->State)) {
+      if (!A.Filter.test(symbolToToken(Lookahead)))
         continue;
       if (RID.hasValue())
         return false;
-      RID = A.getReduceRule();
+      RID = A.Rule;
     }
     if (!RID.hasValue())
       return false;
Index: clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
===================================================================
--- clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
+++ clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
@@ -38,6 +38,8 @@
 
 #include "clang-pseudo/grammar/Grammar.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/Capacity.h"
+#include <bitset>
 #include <cstdint>
 #include <vector>
 
@@ -123,18 +125,24 @@
     uint16_t Value : ValueBits;
   };
 
-  // Returns all available actions for the given state on a terminal.
-  // Expected to be called by LR parsers.
-  llvm::ArrayRef<Action> getActions(StateID State, SymbolID Terminal) const;
+  struct Reduce {
+    StateID State;
+    RuleID Rule;
+    std::bitset<tok::NUM_TOKENS> Filter;
+  };
+
+  llvm::ArrayRef<Reduce> getReduces(StateID State) const {
+    unsigned Begin = ReducesLookup[State], End = Begin;
+    while (Reduces[End].State == State)
+      ++End;
+    return llvm::makeArrayRef(&Reduces[Begin], &Reduces[End]);
+  }
+
   // Returns the state after we reduce a nonterminal.
   // Expected to be called by LR parsers.
   StateID getGoToState(StateID State, SymbolID Nonterminal) const;
   llvm::Optional<StateID> getShiftState(StateID State, SymbolID Terminal) const;
 
-  // Looks up available actions.
-  // Returns empty if no available actions in the table.
-  llvm::ArrayRef<Action> find(StateID State, SymbolID Symbol) const;
-
   // Returns the state from which the LR parser should start to parse the input
   // tokens as the given StartSymbol.
   //
@@ -147,9 +155,9 @@
   StateID getStartState(SymbolID StartSymbol) const;
 
   size_t bytes() const {
-    return sizeof(*this) + Actions.capacity() * sizeof(Action) +
-           Symbols.capacity() * sizeof(SymbolID) +
-           StateOffset.capacity() * sizeof(uint32_t);
+    return sizeof(*this) + llvm::capacity_in_bytes(Shift) +
+           llvm::capacity_in_bytes(Goto) + llvm::capacity_in_bytes(Reduces) +
+           llvm::capacity_in_bytes(ReducesLookup);
   }
 
   std::string dumpStatistics() const;
@@ -169,19 +177,15 @@
   static LRTable buildForTests(const GrammarTable &, llvm::ArrayRef<Entry>);
 
 private:
-  // Conceptually the LR table is a multimap from (State, SymbolID) => Action.
-  // Our physical representation is quite different for compactness.
-
-  // Index is StateID, value is the offset into Symbols/Actions
-  // where the entries for this state begin.
-  // Give a state id, the corresponding half-open range of Symbols/Actions is
-  // [StateOffset[id], StateOffset[id+1]).
-  std::vector<uint32_t> StateOffset;
-  // Parallel to Actions, the value is SymbolID (columns of the matrix).
-  // Grouped by the StateID, and only subranges are sorted.
-  std::vector<SymbolID> Symbols;
-  // A flat list of available actions, sorted by (State, SymbolID).
-  std::vector<Action> Actions;
+  unsigned NumShift = 0;
+  unsigned NumGoto = 0;
+  unsigned NumReduce = 0;
+
+  llvm::DenseMap<std::pair<SymbolID, StateID>, StateID> Shift;
+  llvm::DenseMap<std::pair<SymbolID, StateID>, StateID> Goto;
+  std::vector<Reduce> Reduces;
+  std::vector<unsigned> ReducesLookup;
+
   // A sorted table, storing the start state for each target parsing symbol.
   std::vector<std::pair<SymbolID, StateID>> StartStates;
 };
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to