FST  openfst-1.8.3
OpenFst Library
trie.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 
18 #ifndef FST_EXTENSIONS_LINEAR_TRIE_H_
19 #define FST_EXTENSIONS_LINEAR_TRIE_H_
20 
21 #include <cstddef>
22 #include <istream>
23 #include <iterator>
24 #include <memory>
25 #include <ostream>
26 #include <utility>
27 #include <vector>
28 
29 #include <fst/compat.h>
30 #include <fst/util.h>
31 #include <unordered_map>
32 
33 namespace fst {
34 
// Sentinel node id returned by `Find()` when no such node exists.
inline constexpr int kNoTrieNodeId = -1;

// Forward declarations of all available trie topologies.
template <class L, class H>
class NestedTrieTopology;
template <class L, class H>
class FlatTrieTopology;
42 
43 // A pair of parent node id and label, part of a trie edge
44 template <class L>
45 struct ParentLabel {
46  int parent;
47  L label;
48 
49  ParentLabel() = default;
50  ParentLabel(int p, L l) : parent(p), label(l) {}
51 
52  bool operator==(const ParentLabel &that) const {
53  return parent == that.parent && label == that.label;
54  }
55 
56  std::istream &Read(std::istream &strm) {
57  ReadType(strm, &parent);
58  ReadType(strm, &label);
59  return strm;
60  }
61 
62  std::ostream &Write(std::ostream &strm) const {
63  WriteType(strm, parent);
64  WriteType(strm, label);
65  return strm;
66  }
67 };
68 
69 template <class L, class H>
71  size_t operator()(const ParentLabel<L> &pl) const {
72  return static_cast<size_t>(pl.parent * 7853 + H()(pl.label));
73  }
74 };
75 
76 // The trie topology in a nested tree of hash maps; allows efficient
77 // iteration over children of a specific node.
78 template <class L, class H>
79 class NestedTrieTopology {
80  public:
81  typedef L Label;
82  typedef H Hash;
83  typedef std::unordered_map<L, int, H> NextMap;
84 
86  public:
87  typedef std::forward_iterator_tag iterator_category;
88  typedef std::pair<ParentLabel<L>, int> value_type;
89  typedef std::ptrdiff_t difference_type;
90  typedef const value_type *pointer;
91  typedef const value_type &reference;
92 
93  friend class NestedTrieTopology<L, H>;
94 
95  const_iterator() : ptr_(nullptr), cur_node_(kNoTrieNodeId), cur_edge_() {}
96 
97  reference operator*() {
98  UpdateStub();
99  return stub_;
100  }
101  pointer operator->() {
102  UpdateStub();
103  return &stub_;
104  }
105 
106  const_iterator &operator++();
107  const_iterator &operator++(int);
108 
109  bool operator==(const const_iterator &that) const {
110  return ptr_ == that.ptr_ && cur_node_ == that.cur_node_ &&
111  cur_edge_ == that.cur_edge_;
112  }
113  bool operator!=(const const_iterator &that) const {
114  return !(*this == that);
115  }
116 
117  private:
118  const_iterator(const NestedTrieTopology *ptr, int cur_node)
119  : ptr_(ptr), cur_node_(cur_node) {
120  SetProperCurEdge();
121  }
122 
123  void SetProperCurEdge() {
124  if (cur_node_ < ptr_->NumNodes())
125  cur_edge_ = ptr_->nodes_[cur_node_]->begin();
126  else
127  cur_edge_ = ptr_->nodes_[0]->begin();
128  }
129 
130  void UpdateStub() {
131  stub_.first = ParentLabel<L>(cur_node_, cur_edge_->first);
132  stub_.second = cur_edge_->second;
133  }
134 
135  const NestedTrieTopology *ptr_;
136  int cur_node_;
137  typename NextMap::const_iterator cur_edge_;
138  value_type stub_;
139  };
140 
143  ~NestedTrieTopology() = default;
144  void swap(NestedTrieTopology &that);
145  NestedTrieTopology &operator=(const NestedTrieTopology &that);
146  bool operator==(const NestedTrieTopology &that) const;
147  bool operator!=(const NestedTrieTopology &that) const;
148 
149  int Root() const { return 0; }
150  size_t NumNodes() const { return nodes_.size(); }
151  int Insert(int parent, const L &label);
152  int Find(int parent, const L &label) const;
153  const NextMap &ChildrenOf(int parent) const { return *nodes_[parent]; }
154 
155  std::istream &Read(std::istream &strm);
156  std::ostream &Write(std::ostream &strm) const;
157 
158  const_iterator begin() const { return const_iterator(this, 0); }
159  const_iterator end() const { return const_iterator(this, NumNodes()); }
160 
161  private:
162  // Use pointers to avoid copying the maps when the vector grows
163  std::vector<std::unique_ptr<NextMap>> nodes_;
164 };
165 
166 template <class L, class H>
168  nodes_.push_back(std::make_unique<NextMap>());
169 }
170 
171 template <class L, class H>
173  nodes_.reserve(that.nodes_.size());
174  for (const auto &node : that.nodes_) {
175  nodes_.push_back(std::make_unique<NextMap>(*node));
176  }
177 }
178 
179 // TODO(wuke): std::swap compatibility
180 template <class L, class H>
182  nodes_.swap(that.nodes_);
183 }
184 
185 template <class L, class H>
187  const NestedTrieTopology &that) {
188  NestedTrieTopology copy(that);
189  swap(copy);
190  return *this;
191 }
192 
193 template <class L, class H>
195  const NestedTrieTopology &that) const {
196  if (NumNodes() != that.NumNodes()) return false;
197  for (int i = 0; i < NumNodes(); ++i)
198  if (ChildrenOf(i) != that.ChildrenOf(i)) return false;
199  return true;
200 }
201 
202 template <class L, class H>
204  const NestedTrieTopology &that) const {
205  return !(*this == that);
206 }
207 
208 template <class L, class H>
209 inline int NestedTrieTopology<L, H>::Insert(int parent, const L &label) {
210  int ret = Find(parent, label);
211  if (ret == kNoTrieNodeId) {
212  ret = NumNodes();
213  (*nodes_[parent])[label] = ret;
214  nodes_.push_back(std::make_unique<NextMap>());
215  }
216  return ret;
217 }
218 
219 template <class L, class H>
220 inline int NestedTrieTopology<L, H>::Find(int parent, const L &label) const {
221  typename NextMap::const_iterator it = nodes_[parent]->find(label);
222  return it == nodes_[parent]->end() ? kNoTrieNodeId : it->second;
223 }
224 
225 template <class L, class H>
226 inline std::istream &NestedTrieTopology<L, H>::Read(std::istream &strm) {
227  NestedTrieTopology new_trie;
228  size_t num_nodes;
229  if (!ReadType(strm, &num_nodes)) return strm;
230  for (size_t i = 1; i < num_nodes; ++i) {
231  new_trie.nodes_.push_back(std::make_unique<NextMap>());
232  }
233  for (size_t i = 0; i < num_nodes; ++i) {
234  ReadType(strm, new_trie.nodes_[i].get());
235  }
236  if (strm) swap(new_trie);
237  return strm;
238 }
239 
240 template <class L, class H>
241 inline std::ostream &NestedTrieTopology<L, H>::Write(std::ostream &strm) const {
242  WriteType(strm, NumNodes());
243  for (size_t i = 0; i < NumNodes(); ++i) WriteType(strm, *nodes_[i]);
244  return strm;
245 }
246 
247 template <class L, class H>
250  ++cur_edge_;
251  if (cur_edge_ == ptr_->nodes_[cur_node_]->end()) {
252  ++cur_node_;
253  while (cur_node_ < ptr_->NumNodes() && ptr_->nodes_[cur_node_]->empty())
254  ++cur_node_;
255  SetProperCurEdge();
256  }
257  return *this;
258 }
259 
260 template <class L, class H>
263  const_iterator save(*this);
264  ++(*this);
265  return save;
266 }
267 
268 // The trie topology in a single hash map; only allows iteration over
269 // all the edges in arbitrary order.
270 template <class L, class H>
271 class FlatTrieTopology {
272  private:
273  typedef std::unordered_map<ParentLabel<L>, int, ParentLabelHash<L, H>>
274  NextMap;
275 
276  public:
277  // Iterator over edges as std::pair<ParentLabel<L>, int>
278  typedef typename NextMap::const_iterator const_iterator;
279  typedef L Label;
280  typedef H Hash;
281 
282  FlatTrieTopology() = default;
283  FlatTrieTopology(const FlatTrieTopology &that) : next_(that.next_) {}
284  template <class T>
285  explicit FlatTrieTopology(const T &that);
286 
287  // TODO(wuke): std::swap compatibility
288  void swap(FlatTrieTopology &that) { next_.swap(that.next_); }
289 
290  bool operator==(const FlatTrieTopology &that) const {
291  return next_ == that.next_;
292  }
293  bool operator!=(const FlatTrieTopology &that) const {
294  return !(*this == that);
295  }
296 
297  int Root() const { return 0; }
298  size_t NumNodes() const { return next_.size() + 1; }
299  int Insert(int parent, const L &label);
300  int Find(int parent, const L &label) const;
301 
302  std::istream &Read(std::istream &strm) { return ReadType(strm, &next_); }
303  std::ostream &Write(std::ostream &strm) const {
304  return WriteType(strm, next_);
305  }
306 
307  const_iterator begin() const { return next_.begin(); }
308  const_iterator end() const { return next_.end(); }
309 
310  private:
311  NextMap next_;
312 };
313 
314 template <class L, class H>
315 template <class T>
317  : next_(that.begin(), that.end()) {}
318 
319 template <class L, class H>
320 inline int FlatTrieTopology<L, H>::Insert(int parent, const L &label) {
321  int ret = Find(parent, label);
322  if (ret == kNoTrieNodeId) {
323  ret = NumNodes();
324  next_[ParentLabel<L>(parent, label)] = ret;
325  }
326  return ret;
327 }
328 
329 template <class L, class H>
330 inline int FlatTrieTopology<L, H>::Find(int parent, const L &label) const {
331  typename NextMap::const_iterator it =
332  next_.find(ParentLabel<L>(parent, label));
333  return it == next_.end() ? kNoTrieNodeId : it->second;
334 }
335 
// A collection of implementations of the trie data structure, selected by
// the topology type `T` (e.g. `NestedTrieTopology<L, H>` or
// `FlatTrieTopology<L, H>`). The key is a sequence of type `L` which must be
// hashable. The value is of type `V`, which must be default constructible
// and copyable. A value object is stored for every node in the trie,
// including the root, so copying `V` should be cheap. Because values live in
// a `std::vector<V>` and `operator[]` returns `V &`, `V` cannot be `bool`
// (`std::vector<bool>` does not hand out `bool &`).
//
// One can access the stored values with an integer node id, using the
// [] operator. A valid node id can be obtained in the following ways:
//
// 1. Using the `Root()` method to get the node id of the root.
//
// 2. Iterating through 0 to `NumNodes() - 1`. The node ids are dense
// so every integer in this range is a valid node id.
//
// 3. Using the node id returned from a successful `Insert()` or
// `Find()` call.
//
// 4. Iterating over the edges of `TrieTopology()`. Each edge is a pair whose
// `first.parent` is the parent node id and whose `second` is the child
// node id.
//
// Below is an example of inserting keys into the trie:
//
// typedef MutableTrie<char, int, NestedTrieTopology<char, std::hash<char>>>
//     Dict;
// const std::string words[] = {"hello", "health", "jello"};
// Dict dict;
// for (const std::string &word : words) {
//   int cur = dict.Root();
//   for (char c : word) {
//     cur = dict.Insert(cur, c);
//   }
//   dict[cur] = 1;  // Marks the end of a word.
// }
//
// And the following is an example of looking up the longest prefix of a
// string that is a path in the trie constructed above (here "heal", so
// `prefix_length` ends up as 4):
//
// const std::string query = "healed";
// size_t prefix_length = 0;
// int cur = dict.Root();
// while (prefix_length < query.size()) {
//   const int next = dict.Find(cur, query[prefix_length]);
//   if (next == kNoTrieNodeId) break;
//   cur = next;
//   ++prefix_length;
// }
template <class L, class V, class T>
class MutableTrie {
 public:
  template <class LL, class VV, class TT>
  friend class MutableTrie;

  typedef L Label;
  typedef V Value;
  typedef T Topology;

  // Constructs a trie with only the root node. The root gets its own value
  // slot right away, so `(*this)[Root()]` is valid before any `Insert()`.
  // Relies on `topology_` being declared (hence initialized) before
  // `values_`.
  MutableTrie() : values_(topology_.NumNodes()) {}

  // Conversion from another trie of a possibly different topology. The
  // underlying topology must support conversion from `S`.
  template <class S>
  explicit MutableTrie(const MutableTrie<L, V, S> &that)
      : topology_(that.topology_), values_(that.values_) {}

  // TODO(wuke): std::swap compatibility
  void swap(MutableTrie &that) {
    topology_.swap(that.topology_);
    values_.swap(that.values_);
  }

  int Root() const { return topology_.Root(); }
  size_t NumNodes() const { return topology_.NumNodes(); }

  // Inserts an edge with given `label` at node `parent`. Returns the
  // child node id. If the node already exists, returns the node id
  // right away. Newly created nodes get a value-initialized `V`.
  int Insert(int parent, const L &label) {
    int ret = topology_.Insert(parent, label);
    values_.resize(NumNodes());
    return ret;
  }

  // Finds the node id of the node from `parent` via `label`. Returns
  // `kNoTrieNodeId` when such a node does not exist.
  int Find(int parent, const L &label) const {
    return topology_.Find(parent, label);
  }

  const T &TrieTopology() const { return topology_; }

  // Accesses the value stored for the given node.
  V &operator[](int node_id) { return values_[node_id]; }
  const V &operator[](int node_id) const { return values_[node_id]; }

  // Comparison by content: same topology and same per-node values.
  bool operator==(const MutableTrie &that) const {
    return topology_ == that.topology_ && values_ == that.values_;
  }

  bool operator!=(const MutableTrie &that) const { return !(*this == that); }

  // Reads the topology followed by the per-node values.
  std::istream &Read(std::istream &strm) {
    ReadType(strm, &topology_);
    ReadType(strm, &values_);
    return strm;
  }
  // Writes the topology followed by the per-node values.
  std::ostream &Write(std::ostream &strm) const {
    WriteType(strm, topology_);
    WriteType(strm, values_);
    return strm;
  }

 private:
  T topology_;
  std::vector<V> values_;
};
450 
451 } // namespace fst
452 
453 #endif // FST_EXTENSIONS_LINEAR_TRIE_H_
int Insert(int parent, const L &label)
Definition: trie.h:320
const_iterator begin() const
Definition: trie.h:307
V & operator[](int node_id)
Definition: trie.h:425
int Root() const
Definition: trie.h:297
MutableTrie(const MutableTrie< L, V, S > &that)
Definition: trie.h:395
bool operator!=(const NestedTrieTopology &that) const
Definition: trie.h:203
const V & operator[](int node_id) const
Definition: trie.h:426
std::istream & Read(std::istream &strm)
Definition: trie.h:56
NestedTrieTopology & operator=(const NestedTrieTopology &that)
Definition: trie.h:186
const_iterator & operator++()
Definition: trie.h:249
NextMap::const_iterator const_iterator
Definition: trie.h:278
FlatTrieTopology(const FlatTrieTopology &that)
Definition: trie.h:283
FlatTrieTopology()=default
bool operator!=(const const_iterator &that) const
Definition: trie.h:113
const_iterator begin() const
Definition: trie.h:158
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:228
std::ostream & Write(std::ostream &strm) const
Definition: trie.h:440
bool operator==(const NestedTrieTopology &that) const
Definition: trie.h:194
const value_type & reference
Definition: trie.h:91
bool operator!=(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:54
size_t operator()(const ParentLabel< L > &pl) const
Definition: trie.h:71
constexpr int kNoTrieNodeId
Definition: trie.h:35
std::forward_iterator_tag iterator_category
Definition: trie.h:87
int Root() const
Definition: trie.h:404
std::ostream & Write(std::ostream &strm) const
Definition: trie.h:241
std::ostream & Write(std::ostream &strm) const
Definition: trie.h:62
ParentLabel()=default
std::istream & Read(std::istream &strm)
Definition: trie.h:302
const value_type * pointer
Definition: trie.h:90
void swap(NestedTrieTopology &that)
Definition: trie.h:181
const_iterator end() const
Definition: trie.h:308
size_t NumNodes() const
Definition: trie.h:298
std::ostream & Write(std::ostream &strm) const
Definition: trie.h:303
int Root() const
Definition: trie.h:149
int Find(int parent, const L &label) const
Definition: trie.h:220
size_t NumNodes() const
Definition: trie.h:150
std::istream & Read(std::istream &strm)
Definition: trie.h:226
int Insert(int parent, const L &label)
Definition: trie.h:209
ParentLabel(int p, L l)
Definition: trie.h:50
const NextMap & ChildrenOf(int parent) const
Definition: trie.h:153
bool operator==(const const_iterator &that) const
Definition: trie.h:109
const_iterator end() const
Definition: trie.h:159
std::unordered_map< L, int, H > NextMap
Definition: trie.h:83
std::pair< ParentLabel< L >, int > value_type
Definition: trie.h:88
bool operator!=(const FlatTrieTopology &that) const
Definition: trie.h:293
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:80
int parent
Definition: trie.h:46
size_t NumNodes() const
Definition: trie.h:405
const T & TrieTopology() const
Definition: trie.h:422
int Find(int parent, const L &label) const
Definition: trie.h:418
std::istream & Read(std::istream &strm)
Definition: trie.h:435
int Insert(int parent, const L &label)
Definition: trie.h:410
void swap(FlatTrieTopology &that)
Definition: trie.h:288
bool operator==(const ParentLabel &that) const
Definition: trie.h:52
bool operator==(const FlatTrieTopology &that) const
Definition: trie.h:290
void swap(MutableTrie &that)
Definition: trie.h:399
bool operator==(const MutableTrie &that) const
Definition: trie.h:429
bool operator!=(const MutableTrie &that) const
Definition: trie.h:433
int Find(int parent, const L &label) const
Definition: trie.h:330