FST  openfst-1.7.1
OpenFst Library
trie.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 
4 #ifndef FST_EXTENSIONS_LINEAR_TRIE_H_
5 #define FST_EXTENSIONS_LINEAR_TRIE_H_
6 
7 #include <unordered_map>
8 #include <utility>
9 #include <vector>
10 
11 #include <fst/compat.h>
12 #include <fst/util.h>
13 
14 namespace fst {
15 
16 const int kNoTrieNodeId = -1;
17 
18 // Forward declarations of all available trie topologies.
19 template <class L, class H>
21 template <class L, class H>
23 
24 // A pair of parent node id and label, part of a trie edge
25 template <class L>
26 struct ParentLabel {
27  int parent;
28  L label;
29 
31  ParentLabel(int p, L l) : parent(p), label(l) {}
32 
33  bool operator==(const ParentLabel &that) const {
34  return parent == that.parent && label == that.label;
35  }
36 
37  std::istream &Read(std::istream &strm) { // NOLINT
38  ReadType(strm, &parent);
39  ReadType(strm, &label);
40  return strm;
41  }
42 
43  std::ostream &Write(std::ostream &strm) const { // NOLINT
44  WriteType(strm, parent);
45  WriteType(strm, label);
46  return strm;
47  }
48 };
49 
50 template <class L, class H>
52  size_t operator()(const ParentLabel<L> &pl) const {
53  return static_cast<size_t>(pl.parent * 7853 + H()(pl.label));
54  }
55 };
56 
57 // The trie topology in a nested tree of hash maps; allows efficient
58 // iteration over children of a specific node.
59 template <class L, class H>
60 class NestedTrieTopology {
61  public:
62  typedef L Label;
63  typedef H Hash;
64  typedef std::unordered_map<L, int, H> NextMap;
65 
67  public:
68  typedef std::forward_iterator_tag iterator_category;
69  typedef std::pair<ParentLabel<L>, int> value_type;
70  typedef std::ptrdiff_t difference_type;
71  typedef const value_type *pointer;
72  typedef const value_type &reference;
73 
74  friend class NestedTrieTopology<L, H>;
75 
76  const_iterator() : ptr_(nullptr), cur_node_(kNoTrieNodeId), cur_edge_() {}
77 
78  reference operator*() {
79  UpdateStub();
80  return stub_;
81  }
82  pointer operator->() {
83  UpdateStub();
84  return &stub_;
85  }
86 
87  const_iterator &operator++();
88  const_iterator &operator++(int); // NOLINT
89 
90  bool operator==(const const_iterator &that) const {
91  return ptr_ == that.ptr_ && cur_node_ == that.cur_node_ &&
92  cur_edge_ == that.cur_edge_;
93  }
94  bool operator!=(const const_iterator &that) const {
95  return !(*this == that);
96  }
97 
98  private:
99  const_iterator(const NestedTrieTopology *ptr, int cur_node)
100  : ptr_(ptr), cur_node_(cur_node) {
101  SetProperCurEdge();
102  }
103 
104  void SetProperCurEdge() {
105  if (cur_node_ < ptr_->NumNodes())
106  cur_edge_ = ptr_->nodes_[cur_node_]->begin();
107  else
108  cur_edge_ = ptr_->nodes_[0]->begin();
109  }
110 
111  void UpdateStub() {
112  stub_.first = ParentLabel<L>(cur_node_, cur_edge_->first);
113  stub_.second = cur_edge_->second;
114  }
115 
116  const NestedTrieTopology *ptr_;
117  int cur_node_;
118  typename NextMap::const_iterator cur_edge_;
119  value_type stub_;
120  };
121 
125  void swap(NestedTrieTopology &that);
126  NestedTrieTopology &operator=(const NestedTrieTopology &that);
127  bool operator==(const NestedTrieTopology &that) const;
128  bool operator!=(const NestedTrieTopology &that) const;
129 
130  int Root() const { return 0; }
131  size_t NumNodes() const { return nodes_.size(); }
132  int Insert(int parent, const L &label);
133  int Find(int parent, const L &label) const;
134  const NextMap &ChildrenOf(int parent) const { return *nodes_[parent]; }
135 
136  std::istream &Read(std::istream &strm); // NOLINT
137  std::ostream &Write(std::ostream &strm) const; // NOLINT
138 
139  const_iterator begin() const { return const_iterator(this, 0); }
140  const_iterator end() const { return const_iterator(this, NumNodes()); }
141 
142  private:
143  std::vector<NextMap *>
144  nodes_; // Use pointers to avoid copying the maps when the
145  // vector grows
146 };
147 
148 template <class L, class H>
150  nodes_.push_back(new NextMap);
151 }
152 
153 template <class L, class H>
155  nodes_.reserve(that.nodes_.size());
156  for (size_t i = 0; i < that.nodes_.size(); ++i) {
157  NextMap *node = that.nodes_[i];
158  nodes_.push_back(new NextMap(*node));
159  }
160 }
161 
162 template <class L, class H>
164  for (size_t i = 0; i < nodes_.size(); ++i) {
165  NextMap *node = nodes_[i];
166  delete node;
167  }
168 }
169 
170 // TODO(wuke): std::swap compatibility
171 template <class L, class H>
173  nodes_.swap(that.nodes_);
174 }
175 
176 template <class L, class H>
178  const NestedTrieTopology &that) {
179  NestedTrieTopology copy(that);
180  swap(copy);
181  return *this;
182 }
183 
184 template <class L, class H>
186  const NestedTrieTopology &that) const {
187  if (NumNodes() != that.NumNodes()) return false;
188  for (int i = 0; i < NumNodes(); ++i)
189  if (ChildrenOf(i) != that.ChildrenOf(i)) return false;
190  return true;
191 }
192 
193 template <class L, class H>
195  const NestedTrieTopology &that) const {
196  return !(*this == that);
197 }
198 
199 template <class L, class H>
200 inline int NestedTrieTopology<L, H>::Insert(int parent, const L &label) {
201  int ret = Find(parent, label);
202  if (ret == kNoTrieNodeId) {
203  ret = NumNodes();
204  (*nodes_[parent])[label] = ret;
205  nodes_.push_back(new NextMap);
206  }
207  return ret;
208 }
209 
210 template <class L, class H>
211 inline int NestedTrieTopology<L, H>::Find(int parent, const L &label) const {
212  typename NextMap::const_iterator it = nodes_[parent]->find(label);
213  return it == nodes_[parent]->end() ? kNoTrieNodeId : it->second;
214 }
215 
216 template <class L, class H>
217 inline std::istream &NestedTrieTopology<L, H>::Read(
218  std::istream &strm) { // NOLINT
219  NestedTrieTopology new_trie;
220  size_t num_nodes;
221  if (!ReadType(strm, &num_nodes)) return strm;
222  for (size_t i = 1; i < num_nodes; ++i) new_trie.nodes_.push_back(new NextMap);
223  for (size_t i = 0; i < num_nodes; ++i) ReadType(strm, new_trie.nodes_[i]);
224  if (strm) swap(new_trie);
225  return strm;
226 }
227 
228 template <class L, class H>
229 inline std::ostream &NestedTrieTopology<L, H>::Write(
230  std::ostream &strm) const { // NOLINT
231  WriteType(strm, NumNodes());
232  for (size_t i = 0; i < NumNodes(); ++i) WriteType(strm, *nodes_[i]);
233  return strm;
234 }
235 
236 template <class L, class H>
239  ++cur_edge_;
240  if (cur_edge_ == ptr_->nodes_[cur_node_]->end()) {
241  ++cur_node_;
242  while (cur_node_ < ptr_->NumNodes() && ptr_->nodes_[cur_node_]->empty())
243  ++cur_node_;
244  SetProperCurEdge();
245  }
246  return *this;
247 }
248 
249 template <class L, class H>
252  const_iterator save(*this);
253  ++(*this);
254  return save;
255 }
256 
257 // The trie topology in a single hash map; only allows iteration over
258 // all the edges in arbitrary order.
259 template <class L, class H>
260 class FlatTrieTopology {
261  private:
262  typedef std::unordered_map<ParentLabel<L>, int, ParentLabelHash<L, H>>
263  NextMap;
264 
265  public:
266  // Iterator over edges as std::pair<ParentLabel<L>, int>
267  typedef typename NextMap::const_iterator const_iterator;
268  typedef L Label;
269  typedef H Hash;
270 
272  FlatTrieTopology(const FlatTrieTopology &that) : next_(that.next_) {}
273  template <class T>
274  explicit FlatTrieTopology(const T &that);
275 
276  // TODO(wuke): std::swap compatibility
277  void swap(FlatTrieTopology &that) { next_.swap(that.next_); }
278 
279  bool operator==(const FlatTrieTopology &that) const {
280  return next_ == that.next_;
281  }
282  bool operator!=(const FlatTrieTopology &that) const {
283  return !(*this == that);
284  }
285 
286  int Root() const { return 0; }
287  size_t NumNodes() const { return next_.size() + 1; }
288  int Insert(int parent, const L &label);
289  int Find(int parent, const L &label) const;
290 
291  std::istream &Read(std::istream &strm) { // NOLINT
292  return ReadType(strm, &next_);
293  }
294  std::ostream &Write(std::ostream &strm) const { // NOLINT
295  return WriteType(strm, next_);
296  }
297 
298  const_iterator begin() const { return next_.begin(); }
299  const_iterator end() const { return next_.end(); }
300 
301  private:
302  NextMap next_;
303 };
304 
305 template <class L, class H>
306 template <class T>
308  : next_(that.begin(), that.end()) {}
309 
310 template <class L, class H>
311 inline int FlatTrieTopology<L, H>::Insert(int parent, const L &label) {
312  int ret = Find(parent, label);
313  if (ret == kNoTrieNodeId) {
314  ret = NumNodes();
315  next_[ParentLabel<L>(parent, label)] = ret;
316  }
317  return ret;
318 }
319 
320 template <class L, class H>
321 inline int FlatTrieTopology<L, H>::Find(int parent, const L &label) const {
322  typename NextMap::const_iterator it =
323  next_.find(ParentLabel<L>(parent, label));
324  return it == next_.end() ? kNoTrieNodeId : it->second;
325 }
326 
327 // A collection of implementations of the trie data structure. The key
328 // is a sequence of type `L` which must be hashable. The value is of type
329 // `V` which must be default constructible and copyable. In addition,
330 // a value object is stored for each node in the trie therefore
331 // copying `V` should be cheap.
332 //
333 // One can access the stored values with an integer node id, using the
334 // [] operator. A valid node id can be obtained by the following ways:
335 //
336 // 1. Using the `Root()` method to get the node id of the root.
337 //
338 // 2. Iterating through 0 to `NumNodes() - 1`. The node ids are dense
339 // so every integer in this range is a valid node id.
340 //
341 // 3. Using the node id returned from a successful `Insert()` or
342 // `Find()` call.
343 //
344 // 4. Iterating over the trie edges with an `EdgeIterator` and using
345 // the node ids returned from its `Parent()` and `Child()` methods.
346 //
347 // Below is an example of inserting keys into the trie:
348 //
349 // const string words[] = {"hello", "health", "jello"};
350 // MutableTrie<char, bool, NestedTrieTopology<char, std::hash<char>>> dict;
351 // for (auto word : words) {
352 // int cur = dict.Root();
353 // for (char c : word) {
354 // cur = dict.Insert(cur, c);
355 // }
356 // dict[cur] = true;
357 // }
358 //
359 // And the following is an example of looking up the longest prefix of
360 // a string using the trie constructed above:
361 //
362 // string query = "healed";
363 // size_t prefix_length = 0;
364 // int cur = dict.Find(dict.Root(), query[prefix_length]);
365 // while (prefix_length < query.size() &&
366 // cur != kNoTrieNodeId) {
367 // ++prefix_length;
368 // cur = dict.Find(cur, query[prefix_length]);
369 // }
370 template <class L, class V, class T>
371 class MutableTrie {
372  public:
373  template <class LL, class VV, class TT>
374  friend class MutableTrie;
375 
376  typedef L Label;
377  typedef V Value;
378  typedef T Topology;
379 
380  // Constructs a trie with only the root node.
382 
383 // Conversion from another trie of a possibly different
384 // topology. The underlying topology must support conversion.
385  template <class S>
386  explicit MutableTrie(const MutableTrie<L, V, S> &that)
387  : topology_(that.topology_), values_(that.values_) {}
388 
389  // TODO(wuke): std::swap compatibility
390  void swap(MutableTrie &that) {
391  topology_.swap(that.topology_);
392  values_.swap(that.values_);
393  }
394 
395  int Root() const { return topology_.Root(); }
396  size_t NumNodes() const { return topology_.NumNodes(); }
397 
398  // Inserts an edge with given `label` at node `parent`. Returns the
399  // child node id. If the node already exists, returns the node id
400  // right away.
401  int Insert(int parent, const L &label) {
402  int ret = topology_.Insert(parent, label);
403  values_.resize(NumNodes());
404  return ret;
405  }
406 
407  // Finds the node id of the node from `parent` via `label`. Returns
408  // `kNoTrieNodeId` when such a node does not exist.
409  int Find(int parent, const L &label) const {
410  return topology_.Find(parent, label);
411  }
412 
413  const T &TrieTopology() const { return topology_; }
414 
415  // Accesses the value stored for the given node.
416  V &operator[](int node_id) { return values_[node_id]; }
417  const V &operator[](int node_id) const { return values_[node_id]; }
418 
419  // Comparison by content
420  bool operator==(const MutableTrie &that) const {
421  return topology_ == that.topology_ && values_ == that.values_;
422  }
423 
424  bool operator!=(const MutableTrie &that) const { return !(*this == that); }
425 
426  std::istream &Read(std::istream &strm) { // NOLINT
427  ReadType(strm, &topology_);
428  ReadType(strm, &values_);
429  return strm;
430  }
431  std::ostream &Write(std::ostream &strm) const { // NOLINT
432  WriteType(strm, topology_);
433  WriteType(strm, values_);
434  return strm;
435  }
436 
437  private:
438  T topology_;
439  std::vector<V> values_;
440 };
441 
442 } // namespace fst
443 
444 #endif // FST_EXTENSIONS_LINEAR_TRIE_H_
int Insert(int parent, const L &label)
Definition: trie.h:311
const_iterator begin() const
Definition: trie.h:298
V & operator[](int node_id)
Definition: trie.h:416
int Root() const
Definition: trie.h:286
MutableTrie(const MutableTrie< L, V, S > &that)
Definition: trie.h:386
bool operator!=(const NestedTrieTopology &that) const
Definition: trie.h:194
const V & operator[](int node_id) const
Definition: trie.h:417
std::istream & Read(std::istream &strm)
Definition: trie.h:37
NestedTrieTopology & operator=(const NestedTrieTopology &that)
Definition: trie.h:177
const_iterator & operator++()
Definition: trie.h:238
NextMap::const_iterator const_iterator
Definition: trie.h:267
FlatTrieTopology(const FlatTrieTopology &that)
Definition: trie.h:272
bool operator!=(const const_iterator &that) const
Definition: trie.h:94
const_iterator begin() const
Definition: trie.h:139
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:155
std::ostream & Write(std::ostream &strm) const
Definition: trie.h:431
bool operator==(const NestedTrieTopology &that) const
Definition: trie.h:185
const value_type & reference
Definition: trie.h:72
size_t operator()(const ParentLabel< L > &pl) const
Definition: trie.h:52
const int kNoTrieNodeId
Definition: trie.h:16
std::forward_iterator_tag iterator_category
Definition: trie.h:68
int Root() const
Definition: trie.h:395
std::ostream & Write(std::ostream &strm) const
Definition: trie.h:229
std::ostream & Write(std::ostream &strm) const
Definition: trie.h:43
std::istream & Read(std::istream &strm)
Definition: trie.h:291
constexpr bool operator!=(const FloatWeightTpl< T > &w1, const FloatWeightTpl< T > &w2)
Definition: float-weight.h:119
const value_type * pointer
Definition: trie.h:71
void swap(NestedTrieTopology &that)
Definition: trie.h:172
const_iterator end() const
Definition: trie.h:299
size_t NumNodes() const
Definition: trie.h:287
std::ostream & Write(std::ostream &strm) const
Definition: trie.h:294
int Root() const
Definition: trie.h:130
int Find(int parent, const L &label) const
Definition: trie.h:211
size_t NumNodes() const
Definition: trie.h:131
std::istream & Read(std::istream &strm)
Definition: trie.h:217
int Insert(int parent, const L &label)
Definition: trie.h:200
ParentLabel(int p, L l)
Definition: trie.h:31
const NextMap & ChildrenOf(int parent) const
Definition: trie.h:134
bool operator==(const const_iterator &that) const
Definition: trie.h:90
const_iterator end() const
Definition: trie.h:140
std::unordered_map< L, int, H > NextMap
Definition: trie.h:64
std::pair< ParentLabel< L >, int > value_type
Definition: trie.h:69
bool operator!=(const FlatTrieTopology &that) const
Definition: trie.h:282
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:47
int parent
Definition: trie.h:27
size_t NumNodes() const
Definition: trie.h:396
const T & TrieTopology() const
Definition: trie.h:413
int Find(int parent, const L &label) const
Definition: trie.h:409
std::istream & Read(std::istream &strm)
Definition: trie.h:426
int Insert(int parent, const L &label)
Definition: trie.h:401
void swap(FlatTrieTopology &that)
Definition: trie.h:277
bool operator==(const ParentLabel &that) const
Definition: trie.h:33
bool operator==(const FlatTrieTopology &that) const
Definition: trie.h:279
void swap(MutableTrie &that)
Definition: trie.h:390
bool operator==(const MutableTrie &that) const
Definition: trie.h:420
bool operator!=(const MutableTrie &that) const
Definition: trie.h:424
int Find(int parent, const L &label) const
Definition: trie.h:321