FST  openfst-1.8.2.post1
OpenFst Library
trie.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 
18 #ifndef FST_EXTENSIONS_LINEAR_TRIE_H_
19 #define FST_EXTENSIONS_LINEAR_TRIE_H_
20 
21 #include <memory>
22 #include <utility>
23 #include <vector>
24 
25 #include <fst/compat.h>
26 #include <fst/util.h>
27 #include <unordered_map>
28 
29 namespace fst {
30 
31 inline constexpr int kNoTrieNodeId = -1;
32 
33 // Forward declarations of all available trie topologies.
34 template <class L, class H>
36 template <class L, class H>
38 
39 // A pair of parent node id and label, part of a trie edge
// (the node the edge leaves from, and the label on that edge). It is the
// key part of the edges exposed by NestedTrieTopology's iterator and the
// hash-map key used by FlatTrieTopology.
40 template <class L>
41 struct ParentLabel {
42  int parent;  // Node id of the node the edge starts from.
43  L label;  // Label carried by the edge.
44 
 // Builds the pair from a parent node id `p` and a label `l` (taken by value).
46  ParentLabel(int p, L l) : parent(p), label(l) {}
47 
 // Two pairs are equal when both the parent id and the label are equal.
48  bool operator==(const ParentLabel &that) const {
49  return parent == that.parent && label == that.label;
50  }
51 
 // Deserializes `parent` and then `label` from `strm` using ReadType.
 // The outcome of the individual reads is not inspected here; callers
 // should test the returned stream.
52  std::istream &Read(std::istream &strm) {
53  ReadType(strm, &parent);
54  ReadType(strm, &label);
55  return strm;
56  }
57 
 // Serializes `parent` and then `label` to `strm` using WriteType, i.e. in
 // the same order Read() consumes them. Returns `strm`.
58  std::ostream &Write(std::ostream &strm) const {
59  WriteType(strm, parent);
60  WriteType(strm, label);
61  return strm;
62  }
63 };
64 
65 template <class L, class H>
67  size_t operator()(const ParentLabel<L> &pl) const {
68  return static_cast<size_t>(pl.parent * 7853 + H()(pl.label));
69  }
70 };
71 
72 // The trie topology in a nested tree of hash maps; allows efficient
73 // iteration over children of a specific node.
//
// Node ids index into `nodes_`; entry i is the map from outgoing label to
// child node id for node i. Node 0 is the root.
74 template <class L, class H>
75 class NestedTrieTopology {
76  public:
77  typedef L Label;
78  typedef H Hash;
 // Per-node map: outgoing label -> child node id.
79  typedef std::unordered_map<L, int, H> NextMap;
80 
 // Forward iterator over all edges. Dereferencing yields a
 // (ParentLabel, child id) pair that is rebuilt on demand in `stub_`.
82  public:
83  typedef std::forward_iterator_tag iterator_category;
84  typedef std::pair<ParentLabel<L>, int> value_type;
85  typedef std::ptrdiff_t difference_type;
86  typedef const value_type *pointer;
87  typedef const value_type &reference;
88 
89  friend class NestedTrieTopology<L, H>;
90 
 // Creates a detached iterator: no topology, node id kNoTrieNodeId.
91  const_iterator() : ptr_(nullptr), cur_node_(kNoTrieNodeId), cur_edge_() {}
92 
 // Both accessors refresh `stub_` from the current position first, so the
 // returned reference/pointer refers to storage inside this iterator.
93  reference operator*() {
94  UpdateStub();
95  return stub_;
96  }
97  pointer operator->() {
98  UpdateStub();
99  return &stub_;
100  }
101 
102  const_iterator &operator++();
 // NOTE(review): post-increment is declared as returning a reference, yet
 // the post-increment body later in this file returns a local copy
 // (`save`). If that body belongs to this declaration the reference
 // dangles; confirm and consider returning by value.
103  const_iterator &operator++(int);
104 
 // Iterators are equal when they refer to the same topology, the same node
 // and the same edge position.
105  bool operator==(const const_iterator &that) const {
106  return ptr_ == that.ptr_ && cur_node_ == that.cur_node_ &&
107  cur_edge_ == that.cur_edge_;
108  }
109  bool operator!=(const const_iterator &that) const {
110  return !(*this == that);
111  }
112 
113  private:
 // Positions the iterator at the first edge slot of `cur_node`; used by
 // begin() and end() of the enclosing topology.
114  const_iterator(const NestedTrieTopology *ptr, int cur_node)
115  : ptr_(ptr), cur_node_(cur_node) {
116  SetProperCurEdge();
117  }
118 
 // Points `cur_edge_` at the first edge of the current node. Once
 // `cur_node_` is past the last node, `cur_edge_` is pinned to the root's
 // begin() so that every past-the-end iterator of one topology compares
 // equal.
119  void SetProperCurEdge() {
120  if (cur_node_ < ptr_->NumNodes())
121  cur_edge_ = ptr_->nodes_[cur_node_]->begin();
122  else
123  cur_edge_ = ptr_->nodes_[0]->begin();
124  }
125 
 // Rebuilds the cached edge: key is (current node, edge label), value is
 // the child node id.
126  void UpdateStub() {
127  stub_.first = ParentLabel<L>(cur_node_, cur_edge_->first);
128  stub_.second = cur_edge_->second;
129  }
130 
131  const NestedTrieTopology *ptr_;  // Topology being iterated; not owned.
132  int cur_node_;  // Node whose outgoing edges are being visited.
133  typename NextMap::const_iterator cur_edge_;  // Position within that node.
134  value_type stub_;  // Storage backing operator* and operator->.
135  };
136 
139  ~NestedTrieTopology() = default;
140  void swap(NestedTrieTopology &that);
141  NestedTrieTopology &operator=(const NestedTrieTopology &that);
142  bool operator==(const NestedTrieTopology &that) const;
143  bool operator!=(const NestedTrieTopology &that) const;
144 
 // The root always has node id 0.
145  int Root() const { return 0; }
146  size_t NumNodes() const { return nodes_.size(); }
147  int Insert(int parent, const L &label);
148  int Find(int parent, const L &label) const;
 // Returns the label -> child map of `parent`. `parent` is used as an
 // index without a range check.
149  const NextMap &ChildrenOf(int parent) const { return *nodes_[parent]; }
150 
151  std::istream &Read(std::istream &strm);
152  std::ostream &Write(std::ostream &strm) const;
153 
 // NOTE(review): begin() starts at node 0 without skipping childless
 // nodes. When the root has no children, begin() has cur_node_ == 0 while
 // end() has cur_node_ == NumNodes(), so begin() != end() although there
 // is no edge to dereference; confirm callers never iterate such a trie.
154  const_iterator begin() const { return const_iterator(this, 0); }
155  const_iterator end() const { return const_iterator(this, NumNodes()); }
156 
157  private:
158  // Use pointers to avoid copying the maps when the vector grows
159  std::vector<std::unique_ptr<NextMap>> nodes_;
160 };
161 
162 template <class L, class H>
164  nodes_.push_back(std::make_unique<NextMap>());
165 }
166 
167 template <class L, class H>
169  nodes_.reserve(that.nodes_.size());
170  for (const auto &node : that.nodes_) {
171  nodes_.push_back(std::make_unique<NextMap>(*node));
172  }
173 }
174 
175 // TODO(wuke): std::swap compatibility
176 template <class L, class H>
178  nodes_.swap(that.nodes_);
179 }
180 
181 template <class L, class H>
183  const NestedTrieTopology &that) {
184  NestedTrieTopology copy(that);
185  swap(copy);
186  return *this;
187 }
188 
189 template <class L, class H>
191  const NestedTrieTopology &that) const {
192  if (NumNodes() != that.NumNodes()) return false;
193  for (int i = 0; i < NumNodes(); ++i)
194  if (ChildrenOf(i) != that.ChildrenOf(i)) return false;
195  return true;
196 }
197 
198 template <class L, class H>
200  const NestedTrieTopology &that) const {
201  return !(*this == that);
202 }
203 
204 template <class L, class H>
205 inline int NestedTrieTopology<L, H>::Insert(int parent, const L &label) {
206  int ret = Find(parent, label);
207  if (ret == kNoTrieNodeId) {
208  ret = NumNodes();
209  (*nodes_[parent])[label] = ret;
210  nodes_.push_back(std::make_unique<NextMap>());
211  }
212  return ret;
213 }
214 
215 template <class L, class H>
216 inline int NestedTrieTopology<L, H>::Find(int parent, const L &label) const {
217  typename NextMap::const_iterator it = nodes_[parent]->find(label);
218  return it == nodes_[parent]->end() ? kNoTrieNodeId : it->second;
219 }
220 
221 template <class L, class H>
222 inline std::istream &NestedTrieTopology<L, H>::Read(std::istream &strm) {
223  NestedTrieTopology new_trie;
224  size_t num_nodes;
225  if (!ReadType(strm, &num_nodes)) return strm;
226  for (size_t i = 1; i < num_nodes; ++i) {
227  new_trie.nodes_.push_back(std::make_unique<NextMap>());
228  }
229  for (size_t i = 0; i < num_nodes; ++i) {
230  ReadType(strm, new_trie.nodes_[i].get());
231  }
232  if (strm) swap(new_trie);
233  return strm;
234 }
235 
236 template <class L, class H>
237 inline std::ostream &NestedTrieTopology<L, H>::Write(std::ostream &strm) const {
238  WriteType(strm, NumNodes());
239  for (size_t i = 0; i < NumNodes(); ++i) WriteType(strm, *nodes_[i]);
240  return strm;
241 }
242 
243 template <class L, class H>
246  ++cur_edge_;
247  if (cur_edge_ == ptr_->nodes_[cur_node_]->end()) {
248  ++cur_node_;
249  while (cur_node_ < ptr_->NumNodes() && ptr_->nodes_[cur_node_]->empty())
250  ++cur_node_;
251  SetProperCurEdge();
252  }
253  return *this;
254 }
255 
256 template <class L, class H>
259  const_iterator save(*this);
260  ++(*this);
261  return save;
262 }
263 
264 // The trie topology in a single hash map; only allows iteration over
265 // all the edges in arbitrary order.
//
// Every edge is stored as (parent id, label) -> child id in `next_`. The
// root is implicit: it has id 0 and no entry of its own.
266 template <class L, class H>
267 class FlatTrieTopology {
268  private:
269  typedef std::unordered_map<ParentLabel<L>, int, ParentLabelHash<L, H>>
270  NextMap;
271 
272  public:
273  // Iterator over edges as std::pair<ParentLabel<L>, int>
274  typedef typename NextMap::const_iterator const_iterator;
275  typedef L Label;
276  typedef H Hash;
277 
 // Copies all edges of `that`.
279  FlatTrieTopology(const FlatTrieTopology &that) : next_(that.next_) {}
 // Conversion from another topology type `T`; defined out of line.
280  template <class T>
281  explicit FlatTrieTopology(const T &that);
282 
283  // TODO(wuke): std::swap compatibility
284  void swap(FlatTrieTopology &that) { next_.swap(that.next_); }
285 
 // Topologies are equal when they hold the same set of edges.
286  bool operator==(const FlatTrieTopology &that) const {
287  return next_ == that.next_;
288  }
289  bool operator!=(const FlatTrieTopology &that) const {
290  return !(*this == that);
291  }
292 
 // The root always has node id 0.
293  int Root() const { return 0; }
 // One node per stored edge (its child), plus the root.
294  size_t NumNodes() const { return next_.size() + 1; }
295  int Insert(int parent, const L &label);
296  int Find(int parent, const L &label) const;
297 
 // (De)serialization is delegated to ReadType/WriteType on the edge map.
298  std::istream &Read(std::istream &strm) { return ReadType(strm, &next_); }
299  std::ostream &Write(std::ostream &strm) const {
300  return WriteType(strm, next_);
301  }
302 
 // Iteration visits the entries of the underlying hash map directly.
303  const_iterator begin() const { return next_.begin(); }
304  const_iterator end() const { return next_.end(); }
305 
306  private:
307  NextMap next_;  // (parent id, label) -> child id, one entry per edge.
308 };
309 
310 template <class L, class H>
311 template <class T>
313  : next_(that.begin(), that.end()) {}
314 
315 template <class L, class H>
316 inline int FlatTrieTopology<L, H>::Insert(int parent, const L &label) {
317  int ret = Find(parent, label);
318  if (ret == kNoTrieNodeId) {
319  ret = NumNodes();
320  next_[ParentLabel<L>(parent, label)] = ret;
321  }
322  return ret;
323 }
324 
325 template <class L, class H>
326 inline int FlatTrieTopology<L, H>::Find(int parent, const L &label) const {
327  typename NextMap::const_iterator it =
328  next_.find(ParentLabel<L>(parent, label));
329  return it == next_.end() ? kNoTrieNodeId : it->second;
330 }
331 
332 // A collection of implementations of the trie data structure. The key
333 // is a sequence of type `L` which must be hashable. The value is of
334 // `V` which must be default constructible and copyable. In addition,
335 // a value object is stored for each node in the trie therefore
336 // copying `V` should be cheap.
337 //
338 // One can access the stored values with an integer node id, using the
339 // [] operator. A valid node id can be obtained by the following ways:
340 //
341 // 1. Using the `Root()` method to get the node id of the root.
342 //
343 // 2. Iterating through 0 to `NumNodes() - 1`. The node ids are dense
344 // so every integer in this range is a valid node id.
345 //
346 // 3. Using the node id returned from a successful `Insert()` or
347 // `Find()` call.
348 //
349 // 4. Iterating over the trie edges with an `EdgeIterator` and using
350 // the node ids returned from its `Parent()` and `Child()` methods.
351 //
352 // Below is an example of inserting keys into the trie:
353 //
354 // const string words[] = {"hello", "health", "jello"};
355 // Trie<char, bool> dict;
356 // for (auto word : words) {
357 // int cur = dict.Root();
358 // for (char c : word) {
359 // cur = dict.Insert(cur, c);
360 // }
361 // dict[cur] = true;
362 // }
363 //
364 // And the following is an example of looking up the longest prefix of
365 // a string using the trie constructed above:
366 //
367 // string query = "healed";
368 // size_t prefix_length = 0;
369 // int cur = dict.Find(dict.Root(), query[prefix_length]);
370 // while (prefix_length < query.size() &&
371 // cur != kNoTrieNodeId) {
372 // ++prefix_length;
373 // cur = dict.Find(cur, query[prefix_length]);
374 // }
// A trie that pairs a topology `T` (which assigns node ids and stores the
// edges) with one value of type `V` per node, kept in `values_` and indexed
// by node id.
375 template <class L, class V, class T>
376 class MutableTrie {
377  public:
 // Every MutableTrie instantiation is a friend, so the converting
 // constructor below can read the private members of a trie that uses a
 // different topology.
378  template <class LL, class VV, class TT>
379  friend class MutableTrie;
380 
381  typedef L Label;
382  typedef V Value;
383  typedef T Topology;
384 
385  // Constructs a trie with only the root node.
387 
388  // Conversion from another trie of a possibly different
389  // topology. The underlying topology must support conversion.
390  template <class S>
391  explicit MutableTrie(const MutableTrie<L, V, S> &that)
392  : topology_(that.topology_), values_(that.values_) {}
393 
394  // TODO(wuke): std::swap compatibility
395  void swap(MutableTrie &that) {
396  topology_.swap(that.topology_);
397  values_.swap(that.values_);
398  }
399 
 // Both are forwarded to the topology.
400  int Root() const { return topology_.Root(); }
401  size_t NumNodes() const { return topology_.NumNodes(); }
402 
403  // Inserts an edge with given `label` at node `parent`. Returns the
404  // child node id. If the node already exists, returns the node id
405  // right away.
 // `values_` is resized to the node count after every call, so a newly
 // created node gets a value-initialized `V` and existing values are kept.
406  int Insert(int parent, const L &label) {
407  int ret = topology_.Insert(parent, label);
408  values_.resize(NumNodes());
409  return ret;
410  }
411 
412  // Finds the node id of the node from `parent` via `label`. Returns
413  // `kNoTrieNodeId` when such a node does not exist.
414  int Find(int parent, const L &label) const {
415  return topology_.Find(parent, label);
416  }
417 
 // Read-only access to the underlying topology.
418  const T &TrieTopology() const { return topology_; }
419 
420  // Accesses the value stored for the given node.
 // `node_id` indexes `values_` directly; no range check is performed.
421  V &operator[](int node_id) { return values_[node_id]; }
422  const V &operator[](int node_id) const { return values_[node_id]; }
423 
424  // Comparison by content
425  bool operator==(const MutableTrie &that) const {
426  return topology_ == that.topology_ && values_ == that.values_;
427  }
428 
429  bool operator!=(const MutableTrie &that) const { return !(*this == that); }
430 
 // Reads the topology and then the values, straight into the members. The
 // individual reads are not checked here; callers should test the returned
 // stream.
431  std::istream &Read(std::istream &strm) {
432  ReadType(strm, &topology_);
433  ReadType(strm, &values_);
434  return strm;
435  }
 // Writes the topology and then the values, matching the order Read() uses.
436  std::ostream &Write(std::ostream &strm) const {
437  WriteType(strm, topology_);
438  WriteType(strm, values_);
439  return strm;
440  }
441 
442  private:
443  T topology_;  // Edge structure and node id assignment.
444  std::vector<V> values_;  // Value of node i is values_[i].
445 };
446 
447 } // namespace fst
448 
449 #endif // FST_EXTENSIONS_LINEAR_TRIE_H_
int Insert(int parent, const L &label)
Definition: trie.h:316
const_iterator begin() const
Definition: trie.h:303
V & operator[](int node_id)
Definition: trie.h:421
int Root() const
Definition: trie.h:293
MutableTrie(const MutableTrie< L, V, S > &that)
Definition: trie.h:391
bool operator!=(const NestedTrieTopology &that) const
Definition: trie.h:199
const V & operator[](int node_id) const
Definition: trie.h:422
std::istream & Read(std::istream &strm)
Definition: trie.h:52
NestedTrieTopology & operator=(const NestedTrieTopology &that)
Definition: trie.h:182
const_iterator & operator++()
Definition: trie.h:245
NextMap::const_iterator const_iterator
Definition: trie.h:274
FlatTrieTopology(const FlatTrieTopology &that)
Definition: trie.h:279
bool operator!=(const const_iterator &that) const
Definition: trie.h:109
const_iterator begin() const
Definition: trie.h:154
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:214
std::ostream & Write(std::ostream &strm) const
Definition: trie.h:436
bool operator==(const NestedTrieTopology &that) const
Definition: trie.h:190
const value_type & reference
Definition: trie.h:87
bool operator!=(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:53
size_t operator()(const ParentLabel< L > &pl) const
Definition: trie.h:67
constexpr int kNoTrieNodeId
Definition: trie.h:31
std::forward_iterator_tag iterator_category
Definition: trie.h:83
int Root() const
Definition: trie.h:400
std::ostream & Write(std::ostream &strm) const
Definition: trie.h:237
std::ostream & Write(std::ostream &strm) const
Definition: trie.h:58
std::istream & Read(std::istream &strm)
Definition: trie.h:298
const value_type * pointer
Definition: trie.h:86
void swap(NestedTrieTopology &that)
Definition: trie.h:177
const_iterator end() const
Definition: trie.h:304
size_t NumNodes() const
Definition: trie.h:294
std::ostream & Write(std::ostream &strm) const
Definition: trie.h:299
int Root() const
Definition: trie.h:145
int Find(int parent, const L &label) const
Definition: trie.h:216
size_t NumNodes() const
Definition: trie.h:146
std::istream & Read(std::istream &strm)
Definition: trie.h:222
int Insert(int parent, const L &label)
Definition: trie.h:205
ParentLabel(int p, L l)
Definition: trie.h:46
const NextMap & ChildrenOf(int parent) const
Definition: trie.h:149
bool operator==(const const_iterator &that) const
Definition: trie.h:105
const_iterator end() const
Definition: trie.h:155
std::unordered_map< L, int, H > NextMap
Definition: trie.h:79
std::pair< ParentLabel< L >, int > value_type
Definition: trie.h:84
bool operator!=(const FlatTrieTopology &that) const
Definition: trie.h:289
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:68
int parent
Definition: trie.h:42
size_t NumNodes() const
Definition: trie.h:401
const T & TrieTopology() const
Definition: trie.h:418
int Find(int parent, const L &label) const
Definition: trie.h:414
std::istream & Read(std::istream &strm)
Definition: trie.h:431
int Insert(int parent, const L &label)
Definition: trie.h:406
void swap(FlatTrieTopology &that)
Definition: trie.h:284
bool operator==(const ParentLabel &that) const
Definition: trie.h:48
bool operator==(const FlatTrieTopology &that) const
Definition: trie.h:286
void swap(MutableTrie &that)
Definition: trie.h:395
bool operator==(const MutableTrie &that) const
Definition: trie.h:425
bool operator!=(const MutableTrie &that) const
Definition: trie.h:429
int Find(int parent, const L &label) const
Definition: trie.h:326