FST  openfst-1.7.1
OpenFst Library
ngram-fst.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // NGramFst implements an n-gram language model based upon the LOUDS data
5 // structure. Please refer to "Unary Data Structures for Language Models":
6 // http://research.google.com/pubs/archive/37218.pdf
7 
8 #ifndef FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
9 #define FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
10 
11 #include <stddef.h>
12 #include <string.h>
13 #include <algorithm>
14 #include <iostream>
15 #include <string>
16 #include <utility>
17 #include <vector>
18 
19 #include <fst/compat.h>
20 #include <fst/log.h>
21 #include <fstream>
23 #include <fst/fstlib.h>
24 #include <fst/mapped-file.h>
25 
26 namespace fst {
27 template <class A>
28 class NGramFst;
29 template <class A>
31 
32 // Instance data containing mutable state for bookkeeping repeated access to
33 // the same state.
34 template <class A>
35 struct NGramFstInst {
36  typedef typename A::Label Label;
37  typedef typename A::StateId StateId;
38  typedef typename A::Weight Weight;
39  StateId state_;
40  size_t num_futures_;
41  size_t offset_;
42  size_t node_;
43  StateId node_state_;
44  std::vector<Label> context_;
45  StateId context_state_;
47  : state_(kNoStateId),
48  node_state_(kNoStateId),
49  context_state_(kNoStateId) {}
50 };
51 
52 namespace internal {
53 
54 // Implementation class for LOUDS based NgramFst interface.
55 template <class A>
56 class NGramFstImpl : public FstImpl<A> {
59  using FstImpl<A>::SetType;
61 
62  friend class ArcIterator<NGramFst<A>>;
63  friend class NGramFstMatcher<A>;
64 
65  public:
69 
70  typedef A Arc;
71  typedef typename A::Label Label;
72  typedef typename A::StateId StateId;
73  typedef typename A::Weight Weight;
74 
76  SetType("ngram");
77  SetInputSymbols(nullptr);
78  SetOutputSymbols(nullptr);
79  SetProperties(kStaticProperties);
80  }
81 
82  NGramFstImpl(const Fst<A> &fst, std::vector<StateId> *order_out);
83 
84  explicit NGramFstImpl(const Fst<A> &fst) : NGramFstImpl(fst, nullptr) {}
85 
86  NGramFstImpl(const NGramFstImpl &other) {
87  FSTERROR() << "Copying NGramFst Impls is not supported, use safe = false.";
88  SetProperties(kError, kError);
89  }
90 
91  ~NGramFstImpl() override {
92  if (owned_) {
93  delete[] data_;
94  }
95  }
96 
97  static NGramFstImpl<A> *Read(std::istream &strm, // NOLINT
98  const FstReadOptions &opts) {
99  NGramFstImpl<A> *impl = new NGramFstImpl();
100  FstHeader hdr;
101  if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return 0;
102  uint64 num_states, num_futures, num_final;
103  const size_t offset =
104  sizeof(num_states) + sizeof(num_futures) + sizeof(num_final);
105  // Peek at num_states and num_futures to see how much more needs to be read.
106  strm.read(reinterpret_cast<char *>(&num_states), sizeof(num_states));
107  strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures));
108  strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final));
109  size_t size = Storage(num_states, num_futures, num_final);
110  MappedFile *data_region = MappedFile::Allocate(size);
111  char *data = reinterpret_cast<char *>(data_region->mutable_data());
112  // Copy num_states, num_futures and num_final back into data.
113  memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states));
114  memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures),
115  sizeof(num_futures));
116  memcpy(data + sizeof(num_states) + sizeof(num_futures),
117  reinterpret_cast<char *>(&num_final), sizeof(num_final));
118  strm.read(data + offset, size - offset);
119  if (strm.fail()) {
120  delete impl;
121  return nullptr;
122  }
123  impl->Init(data, false, data_region);
124  return impl;
125  }
126 
127  bool Write(std::ostream &strm, // NOLINT
128  const FstWriteOptions &opts) const {
129  FstHeader hdr;
130  hdr.SetStart(Start());
131  hdr.SetNumStates(num_states_);
132  WriteHeader(strm, opts, kFileVersion, &hdr);
133  strm.write(data_, StorageSize());
134  return !strm.fail();
135  }
136 
137  StateId Start() const { return start_; }
138 
139  Weight Final(StateId state) const {
140  if (final_index_.Get(state)) {
141  return final_probs_[final_index_.Rank1(state)];
142  } else {
143  return Weight::Zero();
144  }
145  }
146 
147  size_t NumArcs(StateId state, NGramFstInst<A> *inst = nullptr) const {
148  if (inst == nullptr) {
149  const std::pair<size_t, size_t> zeros =
150  (state == 0) ? select_root_ : future_index_.Select0s(state);
151  return zeros.second - zeros.first - 1;
152  }
153  SetInstFuture(state, inst);
154  return inst->num_futures_ + ((state == 0) ? 0 : 1);
155  }
156 
157  size_t NumInputEpsilons(StateId state) const {
158  // State 0 has no parent, thus no backoff.
159  if (state == 0) return 0;
160  return 1;
161  }
162 
163  size_t NumOutputEpsilons(StateId state) const {
164  return NumInputEpsilons(state);
165  }
166 
167  StateId NumStates() const { return num_states_; }
168 
170  data->base = 0;
171  data->nstates = num_states_;
172  }
173 
174  static size_t Storage(uint64 num_states, uint64 num_futures,
175  uint64 num_final) {
176  uint64 b64;
177  Weight weight;
178  Label label;
179  size_t offset =
180  sizeof(num_states) + sizeof(num_futures) + sizeof(num_final);
181  offset +=
182  sizeof(b64) * (BitmapIndex::StorageSize(num_states * 2 + 1) +
183  BitmapIndex::StorageSize(num_futures + num_states + 1) +
184  BitmapIndex::StorageSize(num_states));
185  offset += (num_states + 1) * sizeof(label) + num_futures * sizeof(label);
186  // Pad for alignemnt, see
187  // http://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
188  offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
189  offset += (num_states + 1) * sizeof(weight) + num_final * sizeof(weight) +
190  (num_futures + 1) * sizeof(weight);
191  return offset;
192  }
193 
194  void SetInstFuture(StateId state, NGramFstInst<A> *inst) const {
195  if (inst->state_ != state) {
196  inst->state_ = state;
197  const std::pair<size_t, size_t> zeros = future_index_.Select0s(state);
198  inst->num_futures_ = zeros.second - zeros.first - 1;
199  inst->offset_ = future_index_.Rank1(zeros.first + 1);
200  }
201  }
202 
203  void SetInstNode(NGramFstInst<A> *inst) const {
204  if (inst->node_state_ != inst->state_) {
205  inst->node_state_ = inst->state_;
206  inst->node_ = context_index_.Select1(inst->state_);
207  }
208  }
209 
210  void SetInstContext(NGramFstInst<A> *inst) const {
211  SetInstNode(inst);
212  if (inst->context_state_ != inst->state_) {
213  inst->context_state_ = inst->state_;
214  inst->context_.clear();
215  size_t node = inst->node_;
216  while (node != 0) {
217  inst->context_.push_back(context_words_[context_index_.Rank1(node)]);
218  node = context_index_.Select1(context_index_.Rank0(node) - 1);
219  }
220  }
221  }
222 
223  // Access to the underlying representation
224  const char *GetData(size_t *data_size) const {
225  *data_size = StorageSize();
226  return data_;
227  }
228 
229  void Init(const char *data, bool owned, MappedFile *file = nullptr);
230 
231  const std::vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const {
232  SetInstFuture(s, inst);
233  SetInstContext(inst);
234  return inst->context_;
235  }
236 
237  size_t StorageSize() const {
238  return Storage(num_states_, num_futures_, num_final_);
239  }
240 
241  void GetStates(const std::vector<Label> &context,
242  std::vector<StateId> *states) const;
243 
244  private:
245  StateId Transition(const std::vector<Label> &context, Label future) const;
246 
247  // Properties always true for this Fst class.
248  static const uint64 kStaticProperties =
253  // Current file format version.
254  static const int kFileVersion = 4;
255  // Minimum file format version supported.
256  static const int kMinFileVersion = 4;
257 
258  std::unique_ptr<MappedFile> data_region_;
259  const char *data_ = nullptr;
260  bool owned_ = false; // True if we own data_
261  StateId start_ = fst::kNoStateId;
262  uint64 num_states_ = 0;
263  uint64 num_futures_ = 0;
264  uint64 num_final_ = 0;
265  std::pair<size_t, size_t> select_root_;
266  const Label *root_children_ = nullptr;
267  // borrowed references
268  const uint64 *context_ = nullptr;
269  const uint64 *future_ = nullptr;
270  const uint64 *final_ = nullptr;
271  const Label *context_words_ = nullptr;
272  const Label *future_words_ = nullptr;
273  const Weight *backoff_ = nullptr;
274  const Weight *final_probs_ = nullptr;
275  const Weight *future_probs_ = nullptr;
276  BitmapIndex context_index_;
277  BitmapIndex future_index_;
278  BitmapIndex final_index_;
279 };
280 
281 template <typename A>
283  const std::vector<Label> &context,
284  std::vector<typename A::StateId> *states) const {
285  states->clear();
286  states->push_back(0);
287  typename std::vector<Label>::const_reverse_iterator cit = context.rbegin();
288  const Label *children = root_children_;
289  size_t num_children = select_root_.second - 2;
290  const Label *loc = std::lower_bound(children, children + num_children, *cit);
291  if (loc == children + num_children || *loc != *cit) return;
292  size_t node = 2 + loc - children;
293  states->push_back(context_index_.Rank1(node));
294  if (context.size() == 1) return;
295  size_t node_rank = context_index_.Rank1(node);
296  std::pair<size_t, size_t> zeros =
297  node_rank == 0 ? select_root_ : context_index_.Select0s(node_rank);
298  size_t first_child = zeros.first + 1;
299  ++cit;
300  if (context_index_.Get(first_child) != false) {
301  size_t last_child = zeros.second - 1;
302  while (cit != context.rend()) {
303  children = context_words_ + context_index_.Rank1(first_child);
304  loc = std::lower_bound(children, children + last_child - first_child + 1,
305  *cit);
306  if (loc == children + last_child - first_child + 1 || *loc != *cit) {
307  break;
308  }
309  ++cit;
310  node = first_child + loc - children;
311  states->push_back(context_index_.Rank1(node));
312  node_rank = context_index_.Rank1(node);
313  zeros =
314  node_rank == 0 ? select_root_ : context_index_.Select0s(node_rank);
315  first_child = zeros.first + 1;
316  if (context_index_.Get(first_child) == false) break;
317  last_child = zeros.second - 1;
318  }
319  }
320 }
321 
322 } // namespace internal
323 
324 /*****************************************************************************/
325 template <class A>
326 class NGramFst : public ImplToExpandedFst<internal::NGramFstImpl<A>> {
327  friend class ArcIterator<NGramFst<A>>;
328  friend class NGramFstMatcher<A>;
329 
330  public:
331  typedef A Arc;
332  typedef typename A::StateId StateId;
333  typedef typename A::Label Label;
334  typedef typename A::Weight Weight;
336 
337  explicit NGramFst(const Fst<A> &dst)
338  : ImplToExpandedFst<Impl>(std::make_shared<Impl>(dst, nullptr)) {}
339 
340  NGramFst(const Fst<A> &fst, std::vector<StateId> *order_out)
341  : ImplToExpandedFst<Impl>(std::make_shared<Impl>(fst, order_out)) {}
342 
343  // Because the NGramFstImpl is a const stateless data structure, there
344  // is never a need to do anything beside copy the reference.
345  NGramFst(const NGramFst<A> &fst, bool safe = false)
346  : ImplToExpandedFst<Impl>(fst, false) {}
347 
348  NGramFst() : ImplToExpandedFst<Impl>(std::make_shared<Impl>()) {}
349 
350  // Non-standard constructor to initialize NGramFst directly from data.
351  NGramFst(const char *data, bool owned)
352  : ImplToExpandedFst<Impl>(std::make_shared<Impl>()) {
353  GetMutableImpl()->Init(data, owned, nullptr);
354  }
355 
356  // Get method that gets the data associated with Init().
357  const char *GetData(size_t *data_size) const {
358  return GetImpl()->GetData(data_size);
359  }
360 
361  const std::vector<Label> GetContext(StateId s) const {
362  return GetImpl()->GetContext(s, &inst_);
363  }
364 
365  // Consumes as much as possible of context from right to left, returns the
366  // the states corresponding to the increasingly conditioned input sequence.
367  void GetStates(const std::vector<Label> &context,
368  std::vector<StateId> *state) const {
369  return GetImpl()->GetStates(context, state);
370  }
371 
372  size_t NumArcs(StateId s) const override {
373  return GetImpl()->NumArcs(s, &inst_);
374  }
375 
376  NGramFst<A> *Copy(bool safe = false) const override {
377  return new NGramFst(*this, safe);
378  }
379 
380  static NGramFst<A> *Read(std::istream &strm, const FstReadOptions &opts) {
381  Impl *impl = Impl::Read(strm, opts);
382  return impl ? new NGramFst<A>(std::shared_ptr<Impl>(impl)) : nullptr;
383  }
384 
385  static NGramFst<A> *Read(const string &filename) {
386  if (!filename.empty()) {
387  std::ifstream strm(filename,
388  std::ios_base::in | std::ios_base::binary);
389  if (!strm.good()) {
390  LOG(ERROR) << "NGramFst::Read: Can't open file: " << filename;
391  return nullptr;
392  }
393  return Read(strm, FstReadOptions(filename));
394  } else {
395  return Read(std::cin, FstReadOptions("standard input"));
396  }
397  }
398 
399  bool Write(std::ostream &strm, const FstWriteOptions &opts) const override {
400  return GetImpl()->Write(strm, opts);
401  }
402 
403  bool Write(const string &filename) const override {
404  return Fst<A>::WriteFile(filename);
405  }
406 
407  inline void InitStateIterator(StateIteratorData<A> *data) const override {
408  GetImpl()->InitStateIterator(data);
409  }
410 
411  inline void InitArcIterator(StateId s,
412  ArcIteratorData<A> *data) const override;
413 
414  MatcherBase<A> *InitMatcher(MatchType match_type) const override {
415  return new NGramFstMatcher<A>(this, match_type);
416  }
417 
418  size_t StorageSize() const { return GetImpl()->StorageSize(); }
419 
420  static bool HasRequiredProps(const Fst<A> &fst) {
421  static const auto props =
423  return fst.Properties(props, true) == props;
424  }
425 
426  static bool HasRequiredStructure(const Fst<A> &fst) {
427  if (!HasRequiredProps(fst)) {
428  return false;
429  }
430  typename A::StateId unigram = fst.Start();
431  while (true) { // Follows epsilon arc chain to find unigram state.
432  if (unigram == fst::kNoStateId) return false; // No unigram state.
433  typename fst::ArcIterator<Fst<A>> aiter(fst, unigram);
434  if (aiter.Done() || aiter.Value().ilabel != 0) break;
435  unigram = aiter.Value().nextstate;
436  aiter.Next();
437  }
438  // Other requirement: all states other than unigram an epsilon arc.
439  for (fst::StateIterator<Fst<A>> siter(fst); !siter.Done();
440  siter.Next()) {
441  const typename A::StateId &state = siter.Value();
442  fst::ArcIterator<Fst<A>> aiter(fst, state);
443  if (state != unigram) {
444  if (aiter.Done()) return false;
445  if (aiter.Value().ilabel != 0) return false;
446  aiter.Next();
447  if (!aiter.Done() && aiter.Value().ilabel == 0) return false;
448  }
449  }
450  return true;
451  }
452 
453  private:
455  using ImplToExpandedFst<Impl, ExpandedFst<A>>::GetMutableImpl;
456 
457  explicit NGramFst(std::shared_ptr<Impl> impl)
458  : ImplToExpandedFst<Impl>(impl) {}
459 
460  mutable NGramFstInst<A> inst_;
461 };
462 
463 template <class A>
465  ArcIteratorData<A> *data) const {
466  GetImpl()->SetInstFuture(s, &inst_);
467  GetImpl()->SetInstNode(&inst_);
468  data->base = new ArcIterator<NGramFst<A>>(*this, s);
469 }
470 
471 namespace internal {
472 
473 template <typename A>
475  std::vector<StateId> *order_out) {
476  typedef A Arc;
477  typedef typename Arc::Label Label;
478  typedef typename Arc::Weight Weight;
479  typedef typename Arc::StateId StateId;
480  SetType("ngram");
481  SetInputSymbols(fst.InputSymbols());
482  SetOutputSymbols(fst.OutputSymbols());
483  SetProperties(kStaticProperties);
484 
485  // Check basic requirements for an OpenGrm language model Fst.
486  if (!NGramFst<A>::HasRequiredProps(fst)) {
487  FSTERROR() << "NGramFst only accepts OpenGrm language models as input";
488  SetProperties(kError, kError);
489  return;
490  }
491 
492  int64 num_states = CountStates(fst);
493  Label *context = new Label[num_states];
494 
495  // Find the unigram state by starting from the start state, following
496  // epsilons.
497  StateId unigram = fst.Start();
498  while (1) {
499  if (unigram == kNoStateId) {
500  FSTERROR() << "Could not identify unigram state";
501  SetProperties(kError, kError);
502  return;
503  }
504  ArcIterator<Fst<A>> aiter(fst, unigram);
505  if (aiter.Done()) {
506  LOG(WARNING) << "Unigram state " << unigram << " has no arcs.";
507  break;
508  }
509  if (aiter.Value().ilabel != 0) break;
510  unigram = aiter.Value().nextstate;
511  }
512 
513  // Each state's context is determined by the subtree it is under from the
514  // unigram state.
515  std::queue<std::pair<StateId, Label>> label_queue;
516  std::vector<bool> visited(num_states);
517  // Force an epsilon link to the start state.
518  label_queue.push(std::make_pair(fst.Start(), 0));
519  for (ArcIterator<Fst<A>> aiter(fst, unigram); !aiter.Done(); aiter.Next()) {
520  label_queue.push(
521  std::make_pair(aiter.Value().nextstate, aiter.Value().ilabel));
522  }
523  // investigate states in breadth first fashion to assign context words.
524  while (!label_queue.empty()) {
525  std::pair<StateId, Label> &now = label_queue.front();
526  if (!visited[now.first]) {
527  context[now.first] = now.second;
528  visited[now.first] = true;
529  for (ArcIterator<Fst<A>> aiter(fst, now.first); !aiter.Done();
530  aiter.Next()) {
531  const Arc &arc = aiter.Value();
532  if (arc.ilabel != 0) {
533  label_queue.push(std::make_pair(arc.nextstate, now.second));
534  }
535  }
536  }
537  label_queue.pop();
538  }
539  visited.clear();
540 
541  // The arc from the start state should be assigned an epsilon to put it
542  // in front of the all other labels (which makes Start state 1 after
543  // unigram which is state 0).
544  context[fst.Start()] = 0;
545 
546  // Build the tree of contexts fst by reversing the epsilon arcs from fst.
547  VectorFst<Arc> context_fst;
548  uint64 num_final = 0;
549  for (int i = 0; i < num_states; ++i) {
550  if (fst.Final(i) != Weight::Zero()) {
551  ++num_final;
552  }
553  context_fst.SetFinal(context_fst.AddState(), fst.Final(i));
554  }
555  context_fst.SetStart(unigram);
556  context_fst.SetInputSymbols(fst.InputSymbols());
557  context_fst.SetOutputSymbols(fst.OutputSymbols());
558  int64 num_context_arcs = 0;
559  int64 num_futures = 0;
560  for (StateIterator<Fst<A>> siter(fst); !siter.Done(); siter.Next()) {
561  const StateId &state = siter.Value();
562  num_futures += fst.NumArcs(state) - fst.NumInputEpsilons(state);
563  ArcIterator<Fst<A>> aiter(fst, state);
564  if (!aiter.Done()) {
565  const Arc &arc = aiter.Value();
566  // this arc goes from state to arc.nextstate, so create an arc from
567  // arc.nextstate to state to reverse it.
568  if (arc.ilabel == 0) {
569  context_fst.AddArc(arc.nextstate, Arc(context[state], context[state],
570  arc.weight, state));
571  num_context_arcs++;
572  }
573  }
574  }
575  if (num_context_arcs != context_fst.NumStates() - 1) {
576  FSTERROR() << "Number of contexts arcs != number of states - 1";
577  SetProperties(kError, kError);
578  return;
579  }
580  if (context_fst.NumStates() != num_states) {
581  FSTERROR() << "Number of contexts != number of states";
582  SetProperties(kError, kError);
583  return;
584  }
585  int64 context_props =
586  context_fst.Properties(kIDeterministic | kILabelSorted, true);
587  if (!(context_props & kIDeterministic)) {
588  FSTERROR() << "Input Fst is not structured properly";
589  SetProperties(kError, kError);
590  return;
591  }
592  if (!(context_props & kILabelSorted)) {
593  ArcSort(&context_fst, ILabelCompare<Arc>());
594  }
595 
596  delete[] context;
597 
598  uint64 b64;
599  Weight weight;
600  Label label = kNoLabel;
601  const size_t storage = Storage(num_states, num_futures, num_final);
602  MappedFile *data_region = MappedFile::Allocate(storage);
603  char *data = reinterpret_cast<char *>(data_region->mutable_data());
604  memset(data, 0, storage);
605  size_t offset = 0;
606  memcpy(data + offset, reinterpret_cast<char *>(&num_states),
607  sizeof(num_states));
608  offset += sizeof(num_states);
609  memcpy(data + offset, reinterpret_cast<char *>(&num_futures),
610  sizeof(num_futures));
611  offset += sizeof(num_futures);
612  memcpy(data + offset, reinterpret_cast<char *>(&num_final),
613  sizeof(num_final));
614  offset += sizeof(num_final);
615  uint64 *context_bits = reinterpret_cast<uint64 *>(data + offset);
616  offset += BitmapIndex::StorageSize(num_states * 2 + 1) * sizeof(b64);
617  uint64 *future_bits = reinterpret_cast<uint64 *>(data + offset);
618  offset +=
619  BitmapIndex::StorageSize(num_futures + num_states + 1) * sizeof(b64);
620  uint64 *final_bits = reinterpret_cast<uint64 *>(data + offset);
621  offset += BitmapIndex::StorageSize(num_states) * sizeof(b64);
622  Label *context_words = reinterpret_cast<Label *>(data + offset);
623  offset += (num_states + 1) * sizeof(label);
624  Label *future_words = reinterpret_cast<Label *>(data + offset);
625  offset += num_futures * sizeof(label);
626  offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
627  Weight *backoff = reinterpret_cast<Weight *>(data + offset);
628  offset += (num_states + 1) * sizeof(weight);
629  Weight *final_probs = reinterpret_cast<Weight *>(data + offset);
630  offset += num_final * sizeof(weight);
631  Weight *future_probs = reinterpret_cast<Weight *>(data + offset);
632  int64 context_arc = 0, future_arc = 0, context_bit = 0, future_bit = 0,
633  final_bit = 0;
634 
635  // pseudo-root bits
636  BitmapIndex::Set(context_bits, context_bit++);
637  ++context_bit;
638  context_words[context_arc] = label;
639  backoff[context_arc] = Weight::Zero();
640  context_arc++;
641 
642  ++future_bit;
643  if (order_out) {
644  order_out->clear();
645  order_out->resize(num_states);
646  }
647 
648  std::queue<StateId> context_q;
649  context_q.push(context_fst.Start());
650  StateId state_number = 0;
651  while (!context_q.empty()) {
652  const StateId &state = context_q.front();
653  if (order_out) {
654  (*order_out)[state] = state_number;
655  }
656 
657  const Weight final_weight = context_fst.Final(state);
658  if (final_weight != Weight::Zero()) {
659  BitmapIndex::Set(final_bits, state_number);
660  final_probs[final_bit] = final_weight;
661  ++final_bit;
662  }
663 
664  for (ArcIterator<VectorFst<A>> aiter(context_fst, state); !aiter.Done();
665  aiter.Next()) {
666  const Arc &arc = aiter.Value();
667  context_words[context_arc] = arc.ilabel;
668  backoff[context_arc] = arc.weight;
669  ++context_arc;
670  BitmapIndex::Set(context_bits, context_bit++);
671  context_q.push(arc.nextstate);
672  }
673  ++context_bit;
674 
675  for (ArcIterator<Fst<A>> aiter(fst, state); !aiter.Done(); aiter.Next()) {
676  const Arc &arc = aiter.Value();
677  if (arc.ilabel != 0) {
678  future_words[future_arc] = arc.ilabel;
679  future_probs[future_arc] = arc.weight;
680  ++future_arc;
681  BitmapIndex::Set(future_bits, future_bit++);
682  }
683  }
684  ++future_bit;
685  ++state_number;
686  context_q.pop();
687  }
688 
689  if ((state_number != num_states) || (context_bit != num_states * 2 + 1) ||
690  (context_arc != num_states) || (future_arc != num_futures) ||
691  (future_bit != num_futures + num_states + 1) ||
692  (final_bit != num_final)) {
693  FSTERROR() << "Structure problems detected during construction";
694  SetProperties(kError, kError);
695  return;
696  }
697 
698  Init(data, false, data_region);
699 }
700 
701 template <typename A>
702 inline void NGramFstImpl<A>::Init(const char *data, bool owned,
703  MappedFile *data_region) {
704  if (owned_) {
705  delete[] data_;
706  }
707  data_region_.reset(data_region);
708  owned_ = owned;
709  data_ = data;
710  size_t offset = 0;
711  num_states_ = *(reinterpret_cast<const uint64 *>(data_ + offset));
712  offset += sizeof(num_states_);
713  num_futures_ = *(reinterpret_cast<const uint64 *>(data_ + offset));
714  offset += sizeof(num_futures_);
715  num_final_ = *(reinterpret_cast<const uint64 *>(data_ + offset));
716  offset += sizeof(num_final_);
717  uint64 bits;
718  size_t context_bits = num_states_ * 2 + 1;
719  size_t future_bits = num_futures_ + num_states_ + 1;
720  context_ = reinterpret_cast<const uint64 *>(data_ + offset);
721  offset += BitmapIndex::StorageSize(context_bits) * sizeof(bits);
722  future_ = reinterpret_cast<const uint64 *>(data_ + offset);
723  offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits);
724  final_ = reinterpret_cast<const uint64 *>(data_ + offset);
725  offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits);
726  context_words_ = reinterpret_cast<const Label *>(data_ + offset);
727  offset += (num_states_ + 1) * sizeof(*context_words_);
728  future_words_ = reinterpret_cast<const Label *>(data_ + offset);
729  offset += num_futures_ * sizeof(*future_words_);
730  offset = (offset + sizeof(*backoff_) - 1) & ~(sizeof(*backoff_) - 1);
731  backoff_ = reinterpret_cast<const Weight *>(data_ + offset);
732  offset += (num_states_ + 1) * sizeof(*backoff_);
733  final_probs_ = reinterpret_cast<const Weight *>(data_ + offset);
734  offset += num_final_ * sizeof(*final_probs_);
735  future_probs_ = reinterpret_cast<const Weight *>(data_ + offset);
736 
737  context_index_.BuildIndex(context_, context_bits);
738  future_index_.BuildIndex(future_, future_bits);
739  final_index_.BuildIndex(final_, num_states_);
740 
741  select_root_ = context_index_.Select0s(0);
742  if (context_index_.Rank1(0) != 0 || select_root_.first != 1 ||
743  context_index_.Get(2) == false) {
744  FSTERROR() << "Malformed file";
745  SetProperties(kError, kError);
746  return;
747  }
748  root_children_ = context_words_ + context_index_.Rank1(2);
749  start_ = 1;
750 }
751 
752 template <typename A>
753 inline typename A::StateId NGramFstImpl<A>::Transition(
754  const std::vector<Label> &context, Label future) const {
755  const Label *children = root_children_;
756  size_t num_children = select_root_.second - 2;
757  const Label *loc =
758  std::lower_bound(children, children + num_children, future);
759  if (loc == children + num_children || *loc != future) {
760  return context_index_.Rank1(0);
761  }
762  size_t node = 2 + loc - children;
763  size_t node_rank = context_index_.Rank1(node);
764  std::pair<size_t, size_t> zeros =
765  (node_rank == 0) ? select_root_ : context_index_.Select0s(node_rank);
766  size_t first_child = zeros.first + 1;
767  if (context_index_.Get(first_child) == false) {
768  return context_index_.Rank1(node);
769  }
770  size_t last_child = zeros.second - 1;
771  for (int word = context.size() - 1; word >= 0; --word) {
772  children = context_words_ + context_index_.Rank1(first_child);
773  loc = std::lower_bound(children, children + last_child - first_child + 1,
774  context[word]);
775  if (loc == children + last_child - first_child + 1 ||
776  *loc != context[word]) {
777  break;
778  }
779  node = first_child + loc - children;
780  node_rank = context_index_.Rank1(node);
781  zeros =
782  (node_rank == 0) ? select_root_ : context_index_.Select0s(node_rank);
783  first_child = zeros.first + 1;
784  if (context_index_.Get(first_child) == false) break;
785  last_child = zeros.second - 1;
786  }
787  return context_index_.Rank1(node);
788 }
789 
790 } // namespace internal
791 
792 /*****************************************************************************/
793 template <class A>
794 class NGramFstMatcher : public MatcherBase<A> {
795  public:
796  typedef A Arc;
797  typedef typename A::Label Label;
798  typedef typename A::StateId StateId;
799  typedef typename A::Weight Weight;
800 
801  // This makes a copy of the FST.
803  : owned_fst_(fst.Copy()),
804  fst_(*owned_fst_),
805  inst_(fst_.inst_),
806  match_type_(match_type),
807  current_loop_(false),
808  loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
809  if (match_type_ == MATCH_OUTPUT) {
810  std::swap(loop_.ilabel, loop_.olabel);
811  }
812  }
813 
814  // This doesn't copy the FST.
816  : fst_(*fst),
817  inst_(fst_.inst_),
818  match_type_(match_type),
819  current_loop_(false),
820  loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
821  if (match_type_ == MATCH_OUTPUT) {
822  std::swap(loop_.ilabel, loop_.olabel);
823  }
824  }
825 
826  // This makes a copy of the FST.
827  NGramFstMatcher(const NGramFstMatcher<A> &matcher, bool safe = false)
828  : owned_fst_(matcher.fst_.Copy(safe)),
829  fst_(*owned_fst_),
830  inst_(matcher.inst_),
831  match_type_(matcher.match_type_),
832  current_loop_(false),
833  loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
834  if (match_type_ == MATCH_OUTPUT) {
835  std::swap(loop_.ilabel, loop_.olabel);
836  }
837  }
838 
839  NGramFstMatcher<A> *Copy(bool safe = false) const override {
840  return new NGramFstMatcher<A>(*this, safe);
841  }
842 
843  MatchType Type(bool test) const override { return match_type_; }
844 
845  const Fst<A> &GetFst() const override { return fst_; }
846 
847  uint64 Properties(uint64 props) const override { return props; }
848 
849  void SetState(StateId s) final {
850  fst_.GetImpl()->SetInstFuture(s, &inst_);
851  current_loop_ = false;
852  }
853 
854  bool Find(Label label) final {
855  const Label nolabel = kNoLabel;
856  done_ = true;
857  if (label == 0 || label == nolabel) {
858  if (label == 0) {
859  current_loop_ = true;
860  loop_.nextstate = inst_.state_;
861  }
862  // The unigram state has no epsilon arc.
863  if (inst_.state_ != 0) {
864  arc_.ilabel = arc_.olabel = 0;
865  fst_.GetImpl()->SetInstNode(&inst_);
866  arc_.nextstate = fst_.GetImpl()->context_index_.Rank1(
867  fst_.GetImpl()->context_index_.Select1(
868  fst_.GetImpl()->context_index_.Rank0(inst_.node_) - 1));
869  arc_.weight = fst_.GetImpl()->backoff_[inst_.state_];
870  done_ = false;
871  }
872  } else {
873  current_loop_ = false;
874  const Label *start = fst_.GetImpl()->future_words_ + inst_.offset_;
875  const Label *end = start + inst_.num_futures_;
876  const Label *search = std::lower_bound(start, end, label);
877  if (search != end && *search == label) {
878  size_t state = search - start;
879  arc_.ilabel = arc_.olabel = label;
880  arc_.weight = fst_.GetImpl()->future_probs_[inst_.offset_ + state];
881  fst_.GetImpl()->SetInstContext(&inst_);
882  arc_.nextstate = fst_.GetImpl()->Transition(inst_.context_, label);
883  done_ = false;
884  }
885  }
886  return !Done();
887  }
888 
889  bool Done() const final { return !current_loop_ && done_; }
890 
891  const Arc &Value() const final { return (current_loop_) ? loop_ : arc_; }
892 
893  void Next() final {
894  if (current_loop_) {
895  current_loop_ = false;
896  } else {
897  done_ = true;
898  }
899  }
900 
901  ssize_t Priority(StateId s) final { return fst_.NumArcs(s); }
902 
903  private:
904  std::unique_ptr<NGramFst<A>> owned_fst_;
905  const NGramFst<A> &fst_;
906  NGramFstInst<A> inst_;
907  MatchType match_type_; // Supplied by caller
908  bool done_;
909  Arc arc_;
910  bool current_loop_; // Current arc is the implicit loop
911  Arc loop_;
912 };
913 
914 /*****************************************************************************/
915 // Specialization for NGramFst; see generic version in fst.h
916 // for sample usage (but use the ProdLmFst type!). This version
917 // should inline.
918 template <class A>
919 class StateIterator<NGramFst<A>> : public StateIteratorBase<A> {
920  public:
921  typedef typename A::StateId StateId;
922 
923  explicit StateIterator(const NGramFst<A> &fst)
924  : s_(0), num_states_(fst.NumStates()) {}
925 
926  bool Done() const final { return s_ >= num_states_; }
927 
928  StateId Value() const final { return s_; }
929 
930  void Next() final { ++s_; }
931 
932  void Reset() final { s_ = 0; }
933 
934  private:
935  StateId s_;
936  StateId num_states_;
937 };
938 
939 /*****************************************************************************/
940 template <class A>
941 class ArcIterator<NGramFst<A>> : public ArcIteratorBase<A> {
942  public:
943  typedef A Arc;
944  typedef typename A::Label Label;
945  typedef typename A::StateId StateId;
946  typedef typename A::Weight Weight;
947 
948  ArcIterator(const NGramFst<A> &fst, StateId state)
949  : lazy_(~0), impl_(fst.GetImpl()), i_(0), flags_(kArcValueFlags) {
950  inst_ = fst.inst_;
951  impl_->SetInstFuture(state, &inst_);
952  impl_->SetInstNode(&inst_);
953  }
954 
955  bool Done() const final {
956  return i_ >=
957  ((inst_.node_ == 0) ? inst_.num_futures_ : inst_.num_futures_ + 1);
958  }
959 
960  const Arc &Value() const final {
961  bool eps = (inst_.node_ != 0 && i_ == 0);
962  StateId state = (inst_.node_ == 0) ? i_ : i_ - 1;
963  if (flags_ & lazy_ & (kArcILabelValue | kArcOLabelValue)) {
964  arc_.ilabel = arc_.olabel =
965  eps ? 0 : impl_->future_words_[inst_.offset_ + state];
966  lazy_ &= ~(kArcILabelValue | kArcOLabelValue);
967  }
968  if (flags_ & lazy_ & kArcNextStateValue) {
969  if (eps) {
970  arc_.nextstate =
971  impl_->context_index_.Rank1(impl_->context_index_.Select1(
972  impl_->context_index_.Rank0(inst_.node_) - 1));
973  } else {
974  if (lazy_ & kArcNextStateValue) {
975  impl_->SetInstContext(&inst_); // first time only.
976  }
977  arc_.nextstate = impl_->Transition(
978  inst_.context_, impl_->future_words_[inst_.offset_ + state]);
979  }
980  lazy_ &= ~kArcNextStateValue;
981  }
982  if (flags_ & lazy_ & kArcWeightValue) {
983  arc_.weight = eps ? impl_->backoff_[inst_.state_]
984  : impl_->future_probs_[inst_.offset_ + state];
985  lazy_ &= ~kArcWeightValue;
986  }
987  return arc_;
988  }
989 
990  void Next() final {
991  ++i_;
992  lazy_ = ~0;
993  }
994 
995  size_t Position() const final { return i_; }
996 
997  void Reset() final {
998  i_ = 0;
999  lazy_ = ~0;
1000  }
1001 
1002  void Seek(size_t a) final {
1003  if (i_ != a) {
1004  i_ = a;
1005  lazy_ = ~0;
1006  }
1007  }
1008 
1009  uint32 Flags() const final { return flags_; }
1010 
1011  void SetFlags(uint32 flags, uint32 mask) final {
1012  flags_ &= ~mask;
1013  flags_ |= (flags & kArcValueFlags);
1014  }
1015 
1016  private:
1017  mutable Arc arc_;
1018  mutable uint32 lazy_;
1019  const internal::NGramFstImpl<A> *impl_; // Borrowed reference.
1020  mutable NGramFstInst<A> inst_;
1021 
1022  size_t i_;
1023  uint32 flags_;
1024 };
1025 
1026 } // namespace fst
1027 #endif // FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
internal::NGramFstImpl< A > Impl
Definition: ngram-fst.h:335
void GetStates(const std::vector< Label > &context, std::vector< StateId > *state) const
Definition: ngram-fst.h:367
void AddArc(StateId s, const Arc &arc) override
Definition: mutable-fst.h:303
void * mutable_data() const
Definition: mapped-file.h:34
NGramFst< A > * Copy(bool safe=false) const override
Definition: ngram-fst.h:376
void SetStart(StateId s) override
Definition: mutable-fst.h:280
size_t num_futures_
Definition: ngram-fst.h:40
constexpr uint64 kInitialAcyclic
Definition: properties.h:97
void Next() final
Definition: ngram-fst.h:893
constexpr int kNoLabel
Definition: fst.h:179
size_t NumInputEpsilons(StateId state) const
Definition: ngram-fst.h:157
uint64_t uint64
Definition: types.h:32
const Arc & Value() const final
Definition: ngram-fst.h:891
void SetFinal(StateId s, Weight weight) override
Definition: mutable-fst.h:285
virtual size_t NumArcs(StateId) const =0
void SetInstContext(NGramFstInst< A > *inst) const
Definition: ngram-fst.h:210
uint64 Properties(uint64 props) const override
Definition: ngram-fst.h:847
MatchType
Definition: fst.h:171
StateId state_
Definition: ngram-fst.h:39
constexpr uint64 kILabelSorted
Definition: properties.h:75
size_t Position() const final
Definition: ngram-fst.h:995
size_t StorageSize() const
Definition: ngram-fst.h:418
static NGramFstImpl< A > * Read(std::istream &strm, const FstReadOptions &opts)
Definition: ngram-fst.h:97
#define LOG(type)
Definition: log.h:48
StateId Start() const override
Definition: fst.h:875
virtual Weight Final(StateId) const =0
SetType
Definition: set-weight.h:37
StateId node_state_
Definition: ngram-fst.h:43
const Arc & Value() const final
Definition: ngram-fst.h:960
constexpr int kNoStateId
Definition: fst.h:180
NGramFstMatcher< A > * Copy(bool safe=false) const override
Definition: ngram-fst.h:839
constexpr uint64 kExpanded
Definition: properties.h:27
const Arc & Value() const
Definition: fst.h:503
NGramFstMatcher(const NGramFstMatcher< A > &matcher, bool safe=false)
Definition: ngram-fst.h:827
virtual uint64 Properties(uint64 mask, bool test) const =0
void InitArcIterator(StateId s, ArcIteratorData< A > *data) const override
Definition: ngram-fst.h:464
NGramFst(const Fst< A > &fst, std::vector< StateId > *order_out)
Definition: ngram-fst.h:340
const Fst< A > & GetFst() const override
Definition: ngram-fst.h:845
virtual size_t NumInputEpsilons(StateId) const =0
int64_t int64
Definition: types.h:27
constexpr uint64 kEpsilons
Definition: properties.h:60
#define FSTERROR()
Definition: util.h:35
StateId AddState() override
Definition: mutable-fst.h:298
bool WriteFile(const string &filename) const
Definition: fst.h:303
NGramFst(const Fst< A > &dst)
Definition: ngram-fst.h:337
size_t NumArcs(StateId state, NGramFstInst< A > *inst=nullptr) const
Definition: ngram-fst.h:147
const std::vector< Label > & GetContext(StateId s, NGramFstInst< A > *inst) const
Definition: ngram-fst.h:231
void SetOutputSymbols(const SymbolTable *osyms) override
Definition: mutable-fst.h:373
NGramFst(const NGramFst< A > &fst, bool safe=false)
Definition: ngram-fst.h:345
StateIteratorBase< Arc > * base
Definition: fst.h:351
A::Weight Weight
Definition: ngram-fst.h:38
static void Set(uint64 *bits, size_t index)
Definition: bitmap-index.h:41
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:93
const char * GetData(size_t *data_size) const
Definition: ngram-fst.h:357
A::StateId StateId
Definition: ngram-fst.h:37
void ArcSort(MutableFst< Arc > *fst, Compare comp)
Definition: arcsort.h:87
void SetInstNode(NGramFstInst< A > *inst) const
Definition: ngram-fst.h:203
constexpr uint64 kNotString
Definition: properties.h:120
static NGramFst< A > * Read(std::istream &strm, const FstReadOptions &opts)
Definition: ngram-fst.h:380
void SetInputSymbols(const SymbolTable *isyms) override
Definition: mutable-fst.h:368
std::vector< Label > context_
Definition: ngram-fst.h:44
MatcherBase< A > * InitMatcher(MatchType match_type) const override
Definition: ngram-fst.h:414
constexpr uint64 kOLabelSorted
Definition: properties.h:80
constexpr uint64 kIEpsilons
Definition: properties.h:65
StateId Value() const final
Definition: ngram-fst.h:928
void SetNumStates(int64 numstates)
Definition: fst.h:147
StateId NumStates() const override
Definition: expanded-fst.h:119
NGramFstMatcher(const NGramFst< A > &fst, MatchType match_type)
Definition: ngram-fst.h:802
StateId nstates
Definition: fst.h:353
constexpr uint64 kIDeterministic
Definition: properties.h:50
const char * GetData(size_t *data_size) const
Definition: ngram-fst.h:224
MatchType Type(bool test) const override
Definition: ngram-fst.h:843
size_t NumArcs(StateId s) const override
Definition: ngram-fst.h:372
void SetInstFuture(StateId state, NGramFstInst< A > *inst) const
Definition: ngram-fst.h:194
virtual StateId Start() const =0
bool Done() const
Definition: fst.h:499
static size_t StorageSize(size_t size)
Definition: bitmap-index.h:30
A::StateId StateId
Definition: ngram-fst.h:798
constexpr uint64 kAcceptor
Definition: properties.h:45
NGramFst(const char *data, bool owned)
Definition: ngram-fst.h:351
void SetState(StateId s) final
Definition: ngram-fst.h:849
uint32 Flags() const final
Definition: ngram-fst.h:1009
NGramFstImpl(const Fst< A > &fst)
Definition: ngram-fst.h:84
const std::vector< Label > GetContext(StateId s) const
Definition: ngram-fst.h:361
void Seek(size_t a) final
Definition: ngram-fst.h:1002
StateId NumStates() const
Definition: ngram-fst.h:167
A::Label Label
Definition: ngram-fst.h:333
size_t StorageSize() const
Definition: ngram-fst.h:237
uint32_t uint32
Definition: types.h:31
constexpr uint64 kCoAccessible
Definition: properties.h:110
void SetStart(int64 start)
Definition: fst.h:145
bool Write(std::ostream &strm, const FstWriteOptions &opts) const override
Definition: ngram-fst.h:399
StateId context_state_
Definition: ngram-fst.h:45
virtual const SymbolTable * InputSymbols() const =0
Arc::StateId CountStates(const Fst< Arc > &fst)
Definition: expanded-fst.h:154
ssize_t Priority(StateId s) final
Definition: ngram-fst.h:901
ArcIterator(const NGramFst< A > &fst, StateId state)
Definition: ngram-fst.h:948
static size_t Storage(uint64 num_states, uint64 num_futures, uint64 num_final)
Definition: ngram-fst.h:174
Weight Final(StateId s) const override
Definition: fst.h:877
size_t NumOutputEpsilons(StateId state) const
Definition: ngram-fst.h:163
NGramFstImpl(const NGramFstImpl &other)
Definition: ngram-fst.h:86
bool Done() const final
Definition: ngram-fst.h:889
constexpr uint64 kWeighted
Definition: properties.h:85
constexpr uint64 kError
Definition: properties.h:33
static MappedFile * Allocate(size_t size, int align=kArchAlignment)
Definition: mapped-file.cc:105
StateId Start() const
Definition: ngram-fst.h:137
NGramFstMatcher(const NGramFst< A > *fst, MatchType match_type)
Definition: ngram-fst.h:815
A::Label Label
Definition: ngram-fst.h:36
constexpr uint64 kNotTopSorted
Definition: properties.h:102
ArcIteratorBase< Arc > * base
Definition: fst.h:461
void SetFlags(uint32 flags, uint32 mask) final
Definition: ngram-fst.h:1011
A::Weight Weight
Definition: ngram-fst.h:334
void InitStateIterator(StateIteratorData< A > *data) const override
Definition: ngram-fst.h:407
Weight Final(StateId state) const
Definition: ngram-fst.h:139
void InitStateIterator(StateIteratorData< A > *data) const
Definition: ngram-fst.h:169
bool Write(std::ostream &strm, const FstWriteOptions &opts) const
Definition: ngram-fst.h:127
A::StateId StateId
Definition: ngram-fst.h:332
uint64 Properties(uint64 mask, bool test) const override
Definition: fst.h:889
bool Find(Label label) final
Definition: ngram-fst.h:854
constexpr uint64 kODeterministic
Definition: properties.h:55
static bool HasRequiredStructure(const Fst< A > &fst)
Definition: ngram-fst.h:426
StateIterator(const NGramFst< A > &fst)
Definition: ngram-fst.h:923
bool ReadHeader(std::istream &strm, const FstReadOptions &opts, int min_version, FstHeader *hdr)
constexpr uint64 kOEpsilons
Definition: properties.h:70
constexpr uint64 kAccessible
Definition: properties.h:105
void Init(const char *data, bool owned, MappedFile *file=nullptr)
Definition: ngram-fst.h:702
static bool HasRequiredProps(const Fst< A > &fst)
Definition: ngram-fst.h:420
bool Write(const string &filename) const override
Definition: ngram-fst.h:403
static NGramFst< A > * Read(const string &filename)
Definition: ngram-fst.h:385
constexpr uint64 kCyclic
Definition: properties.h:90
virtual const SymbolTable * OutputSymbols() const =0
void Next()
Definition: fst.h:507