FST  openfst-1.7.5
OpenFst Library
ngram-fst.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // NgramFst implements an n-gram language model based upon the LOUDS data
5 // structure. Please refer to "Unary Data Structures for Language Models"
6 // http://research.google.com/pubs/archive/37218.pdf
7 
8 #ifndef FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
9 #define FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
10 
11 #include <algorithm>
12 #include <cstddef>
13 #include <cstring>
14 #include <iostream>
15 #include <string>
16 #include <utility>
17 #include <vector>
18 
19 #include <fst/compat.h>
20 #include <fst/types.h>
21 #include <fst/log.h>
23 #include <fstream>
24 #include <fst/fstlib.h>
25 #include <fst/mapped-file.h>
26 
27 namespace fst {
// Forward declarations: both class templates are named (as friends and in
// signatures) before their definitions appear further down this header.
template <class A>
class NGramFst;

// NOTE(review): the `class NGramFstMatcher;` line was missing from the
// reviewed listing, leaving a dangling `template <class A>`; restored here.
template <class A>
class NGramFstMatcher;

34 // Instance data containing mutable state for bookkeeping repeated access to
35 // the same state.
36 template <class A>
37 struct NGramFstInst {
38  typedef typename A::Label Label;
39  typedef typename A::StateId StateId;
40  typedef typename A::Weight Weight;
41  StateId state_;
42  size_t num_futures_;
43  size_t offset_;
44  size_t node_;
45  StateId node_state_;
46  std::vector<Label> context_;
47  StateId context_state_;
49  : state_(kNoStateId),
50  node_state_(kNoStateId),
51  context_state_(kNoStateId) {}
52 };
53 
54 namespace internal {
55 
56 // Implementation class for LOUDS based NgramFst interface.
57 template <class A>
58 class NGramFstImpl : public FstImpl<A> {
61  using FstImpl<A>::SetType;
63 
64  friend class ArcIterator<NGramFst<A>>;
65  friend class NGramFstMatcher<A>;
66 
67  public:
71 
72  typedef A Arc;
73  typedef typename A::Label Label;
74  typedef typename A::StateId StateId;
75  typedef typename A::Weight Weight;
76 
78  SetType("ngram");
79  SetInputSymbols(nullptr);
80  SetOutputSymbols(nullptr);
81  SetProperties(kStaticProperties);
82  }
83 
84  NGramFstImpl(const Fst<A> &fst, std::vector<StateId> *order_out);
85 
86  explicit NGramFstImpl(const Fst<A> &fst) : NGramFstImpl(fst, nullptr) {}
87 
88  NGramFstImpl(const NGramFstImpl &other) {
89  FSTERROR() << "Copying NGramFst Impls is not supported, use safe = false.";
90  SetProperties(kError, kError);
91  }
92 
93  ~NGramFstImpl() override {
94  if (owned_) {
95  delete[] data_;
96  }
97  }
98 
99  static NGramFstImpl<A> *Read(std::istream &strm, // NOLINT
100  const FstReadOptions &opts) {
101  auto impl = fst::make_unique<NGramFstImpl<A>>();
102  FstHeader hdr;
103  if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return nullptr;
104  uint64 num_states, num_futures, num_final;
105  const size_t offset =
106  sizeof(num_states) + sizeof(num_futures) + sizeof(num_final);
107  // Peek at num_states and num_futures to see how much more needs to be read.
108  strm.read(reinterpret_cast<char *>(&num_states), sizeof(num_states));
109  strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures));
110  strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final));
111  size_t size = Storage(num_states, num_futures, num_final);
112  MappedFile *data_region = MappedFile::Allocate(size);
113  char *data = static_cast<char *>(data_region->mutable_data());
114  // Copy num_states, num_futures and num_final back into data.
115  memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states));
116  memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures),
117  sizeof(num_futures));
118  memcpy(data + sizeof(num_states) + sizeof(num_futures),
119  reinterpret_cast<char *>(&num_final), sizeof(num_final));
120  strm.read(data + offset, size - offset);
121  if (strm.fail()) return nullptr;
122  impl->Init(data, false, data_region);
123  return impl.release();
124  }
125 
126  bool Write(std::ostream &strm, // NOLINT
127  const FstWriteOptions &opts) const {
128  FstHeader hdr;
129  hdr.SetStart(Start());
130  hdr.SetNumStates(num_states_);
131  WriteHeader(strm, opts, kFileVersion, &hdr);
132  strm.write(data_, StorageSize());
133  return !strm.fail();
134  }
135 
136  StateId Start() const { return start_; }
137 
138  Weight Final(StateId state) const {
139  if (final_index_.Get(state)) {
140  return final_probs_[final_index_.Rank1(state)];
141  } else {
142  return Weight::Zero();
143  }
144  }
145 
146  size_t NumArcs(StateId state, NGramFstInst<A> *inst = nullptr) const {
147  if (inst == nullptr) {
148  const std::pair<size_t, size_t> zeros =
149  (state == 0) ? select_root_ : future_index_.Select0s(state);
150  return zeros.second - zeros.first - 1;
151  }
152  SetInstFuture(state, inst);
153  return inst->num_futures_ + ((state == 0) ? 0 : 1);
154  }
155 
156  size_t NumInputEpsilons(StateId state) const {
157  // State 0 has no parent, thus no backoff.
158  if (state == 0) return 0;
159  return 1;
160  }
161 
162  size_t NumOutputEpsilons(StateId state) const {
163  return NumInputEpsilons(state);
164  }
165 
166  StateId NumStates() const { return num_states_; }
167 
169  data->base = nullptr;
170  data->nstates = num_states_;
171  }
172 
173  static size_t Storage(uint64 num_states, uint64 num_futures,
174  uint64 num_final) {
175  uint64 b64;
176  Weight weight;
177  Label label;
178  size_t offset =
179  sizeof(num_states) + sizeof(num_futures) + sizeof(num_final);
180  offset +=
181  sizeof(b64) * (BitmapIndex::StorageSize(num_states * 2 + 1) +
182  BitmapIndex::StorageSize(num_futures + num_states + 1) +
183  BitmapIndex::StorageSize(num_states));
184  offset += (num_states + 1) * sizeof(label) + num_futures * sizeof(label);
185  // Pad for alignemnt, see
186  // http://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
187  offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
188  offset += (num_states + 1) * sizeof(weight) + num_final * sizeof(weight) +
189  (num_futures + 1) * sizeof(weight);
190  return offset;
191  }
192 
193  void SetInstFuture(StateId state, NGramFstInst<A> *inst) const {
194  if (inst->state_ != state) {
195  inst->state_ = state;
196  const std::pair<size_t, size_t> zeros = future_index_.Select0s(state);
197  inst->num_futures_ = zeros.second - zeros.first - 1;
198  inst->offset_ = future_index_.Rank1(zeros.first + 1);
199  }
200  }
201 
202  void SetInstNode(NGramFstInst<A> *inst) const {
203  if (inst->node_state_ != inst->state_) {
204  inst->node_state_ = inst->state_;
205  inst->node_ = context_index_.Select1(inst->state_);
206  }
207  }
208 
209  void SetInstContext(NGramFstInst<A> *inst) const {
210  SetInstNode(inst);
211  if (inst->context_state_ != inst->state_) {
212  inst->context_state_ = inst->state_;
213  inst->context_.clear();
214  size_t node = inst->node_;
215  while (node != 0) {
216  inst->context_.push_back(context_words_[context_index_.Rank1(node)]);
217  node = context_index_.Select1(context_index_.Rank0(node) - 1);
218  }
219  }
220  }
221 
222  // Access to the underlying representation
223  const char *GetData(size_t *data_size) const {
224  *data_size = StorageSize();
225  return data_;
226  }
227 
228  void Init(const char *data, bool owned, MappedFile *file = nullptr);
229 
230  const std::vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const {
231  SetInstFuture(s, inst);
232  SetInstContext(inst);
233  return inst->context_;
234  }
235 
236  size_t StorageSize() const {
237  return Storage(num_states_, num_futures_, num_final_);
238  }
239 
240  void GetStates(const std::vector<Label> &context,
241  std::vector<StateId> *states) const;
242 
243  private:
244  StateId Transition(const std::vector<Label> &context, Label future) const;
245 
246  // Properties always true for this Fst class.
247  static const uint64 kStaticProperties =
252  // Current file format version.
253  static const int kFileVersion = 4;
254  // Minimum file format version supported.
255  static const int kMinFileVersion = 4;
256 
257  std::unique_ptr<MappedFile> data_region_;
258  const char *data_ = nullptr;
259  bool owned_ = false; // True if we own data_
260  StateId start_ = fst::kNoStateId;
261  uint64 num_states_ = 0;
262  uint64 num_futures_ = 0;
263  uint64 num_final_ = 0;
264  std::pair<size_t, size_t> select_root_;
265  const Label *root_children_ = nullptr;
266  // borrowed references
267  const uint64 *context_ = nullptr;
268  const uint64 *future_ = nullptr;
269  const uint64 *final_ = nullptr;
270  const Label *context_words_ = nullptr;
271  const Label *future_words_ = nullptr;
272  const Weight *backoff_ = nullptr;
273  const Weight *final_probs_ = nullptr;
274  const Weight *future_probs_ = nullptr;
275  BitmapIndex context_index_;
276  BitmapIndex future_index_;
277  BitmapIndex final_index_;
278 };
279 
280 template <typename A>
282  const std::vector<Label> &context,
283  std::vector<typename A::StateId> *states) const {
284  states->clear();
285  states->push_back(0);
286  typename std::vector<Label>::const_reverse_iterator cit = context.rbegin();
287  const Label *children = root_children_;
288  size_t num_children = select_root_.second - 2;
289  const Label *loc = std::lower_bound(children, children + num_children, *cit);
290  if (loc == children + num_children || *loc != *cit) return;
291  size_t node = 2 + loc - children;
292  states->push_back(context_index_.Rank1(node));
293  if (context.size() == 1) return;
294  size_t node_rank = context_index_.Rank1(node);
295  std::pair<size_t, size_t> zeros =
296  node_rank == 0 ? select_root_ : context_index_.Select0s(node_rank);
297  size_t first_child = zeros.first + 1;
298  ++cit;
299  if (context_index_.Get(first_child) != false) {
300  size_t last_child = zeros.second - 1;
301  while (cit != context.rend()) {
302  children = context_words_ + context_index_.Rank1(first_child);
303  loc = std::lower_bound(children, children + last_child - first_child + 1,
304  *cit);
305  if (loc == children + last_child - first_child + 1 || *loc != *cit) {
306  break;
307  }
308  ++cit;
309  node = first_child + loc - children;
310  states->push_back(context_index_.Rank1(node));
311  node_rank = context_index_.Rank1(node);
312  zeros =
313  node_rank == 0 ? select_root_ : context_index_.Select0s(node_rank);
314  first_child = zeros.first + 1;
315  if (context_index_.Get(first_child) == false) break;
316  last_child = zeros.second - 1;
317  }
318  }
319 }
320 
321 } // namespace internal
322 
323 /*****************************************************************************/
324 template <class A>
325 class NGramFst : public ImplToExpandedFst<internal::NGramFstImpl<A>> {
326  friend class ArcIterator<NGramFst<A>>;
327  friend class NGramFstMatcher<A>;
328 
329  public:
330  typedef A Arc;
331  typedef typename A::StateId StateId;
332  typedef typename A::Label Label;
333  typedef typename A::Weight Weight;
335 
336  explicit NGramFst(const Fst<A> &dst)
337  : ImplToExpandedFst<Impl>(std::make_shared<Impl>(dst, nullptr)) {}
338 
339  NGramFst(const Fst<A> &fst, std::vector<StateId> *order_out)
340  : ImplToExpandedFst<Impl>(std::make_shared<Impl>(fst, order_out)) {}
341 
342  // Because the NGramFstImpl is a const stateless data structure, there
343  // is never a need to do anything beside copy the reference.
344  NGramFst(const NGramFst<A> &fst, bool safe = false)
345  : ImplToExpandedFst<Impl>(fst, false) {}
346 
347  NGramFst() : ImplToExpandedFst<Impl>(std::make_shared<Impl>()) {}
348 
349  // Non-standard constructor to initialize NGramFst directly from data.
350  NGramFst(const char *data, bool owned)
351  : ImplToExpandedFst<Impl>(std::make_shared<Impl>()) {
352  GetMutableImpl()->Init(data, owned, nullptr);
353  }
354 
355  // Get method that gets the data associated with Init().
356  const char *GetData(size_t *data_size) const {
357  return GetImpl()->GetData(data_size);
358  }
359 
360  const std::vector<Label> GetContext(StateId s) const {
361  return GetImpl()->GetContext(s, &inst_);
362  }
363 
364  // Consumes as much as possible of context from right to left, returns the
365  // the states corresponding to the increasingly conditioned input sequence.
366  void GetStates(const std::vector<Label> &context,
367  std::vector<StateId> *state) const {
368  return GetImpl()->GetStates(context, state);
369  }
370 
371  size_t NumArcs(StateId s) const override {
372  return GetImpl()->NumArcs(s, &inst_);
373  }
374 
375  NGramFst<A> *Copy(bool safe = false) const override {
376  return new NGramFst(*this, safe);
377  }
378 
379  static NGramFst<A> *Read(std::istream &strm, const FstReadOptions &opts) {
380  Impl *impl = Impl::Read(strm, opts);
381  return impl ? new NGramFst<A>(std::shared_ptr<Impl>(impl)) : nullptr;
382  }
383 
384  static NGramFst<A> *Read(const std::string &source) {
385  if (!source.empty()) {
386  std::ifstream strm(source,
387  std::ios_base::in | std::ios_base::binary);
388  if (!strm.good()) {
389  LOG(ERROR) << "NGramFst::Read: Can't open file: " << source;
390  return nullptr;
391  }
392  return Read(strm, FstReadOptions(source));
393  } else {
394  return Read(std::cin, FstReadOptions("standard input"));
395  }
396  }
397 
398  bool Write(std::ostream &strm, const FstWriteOptions &opts) const override {
399  return GetImpl()->Write(strm, opts);
400  }
401 
402  bool Write(const std::string &source) const override {
403  return Fst<A>::WriteFile(source);
404  }
405 
406  inline void InitStateIterator(StateIteratorData<A> *data) const override {
407  GetImpl()->InitStateIterator(data);
408  }
409 
410  inline void InitArcIterator(StateId s,
411  ArcIteratorData<A> *data) const override;
412 
413  MatcherBase<A> *InitMatcher(MatchType match_type) const override {
414  return new NGramFstMatcher<A>(this, match_type);
415  }
416 
417  size_t StorageSize() const { return GetImpl()->StorageSize(); }
418 
419  static bool HasRequiredProps(const Fst<A> &fst) {
420  static const auto props =
422  return fst.Properties(props, true) == props;
423  }
424 
425  static bool HasRequiredStructure(const Fst<A> &fst) {
426  if (!HasRequiredProps(fst)) {
427  return false;
428  }
429  typename A::StateId unigram = fst.Start();
430  while (true) { // Follows epsilon arc chain to find unigram state.
431  if (unigram == fst::kNoStateId) return false; // No unigram state.
432  typename fst::ArcIterator<Fst<A>> aiter(fst, unigram);
433  if (aiter.Done() || aiter.Value().ilabel != 0) break;
434  unigram = aiter.Value().nextstate;
435  aiter.Next();
436  }
437  // Other requirement: all states other than unigram an epsilon arc.
438  for (fst::StateIterator<Fst<A>> siter(fst); !siter.Done();
439  siter.Next()) {
440  const typename A::StateId &state = siter.Value();
441  fst::ArcIterator<Fst<A>> aiter(fst, state);
442  if (state != unigram) {
443  if (aiter.Done()) return false;
444  if (aiter.Value().ilabel != 0) return false;
445  aiter.Next();
446  if (!aiter.Done() && aiter.Value().ilabel == 0) return false;
447  }
448  }
449  return true;
450  }
451 
452  private:
454  using ImplToExpandedFst<Impl, ExpandedFst<A>>::GetMutableImpl;
455 
456  explicit NGramFst(std::shared_ptr<Impl> impl)
457  : ImplToExpandedFst<Impl>(impl) {}
458 
459  mutable NGramFstInst<A> inst_;
460 };
461 
462 template <class A>
464  ArcIteratorData<A> *data) const {
465  GetImpl()->SetInstFuture(s, &inst_);
466  GetImpl()->SetInstNode(&inst_);
467  data->base = new ArcIterator<NGramFst<A>>(*this, s);
468 }
469 
470 namespace internal {
471 
472 template <typename A>
474  std::vector<StateId> *order_out) {
475  typedef A Arc;
476  typedef typename Arc::Label Label;
477  typedef typename Arc::Weight Weight;
478  typedef typename Arc::StateId StateId;
479  SetType("ngram");
480  SetInputSymbols(fst.InputSymbols());
481  SetOutputSymbols(fst.OutputSymbols());
482  SetProperties(kStaticProperties);
483 
484  // Check basic requirements for an OpenGrm language model Fst.
485  if (!NGramFst<A>::HasRequiredProps(fst)) {
486  FSTERROR() << "NGramFst only accepts OpenGrm language models as input";
487  SetProperties(kError, kError);
488  return;
489  }
490 
491  int64 num_states = CountStates(fst);
492  std::vector<Label> context(num_states, 0);
493 
494  // Find the unigram state by starting from the start state, following
495  // epsilons.
496  StateId unigram = fst.Start();
497  while (true) {
498  if (unigram == kNoStateId) {
499  FSTERROR() << "Could not identify unigram state";
500  SetProperties(kError, kError);
501  return;
502  }
503  ArcIterator<Fst<A>> aiter(fst, unigram);
504  if (aiter.Done()) {
505  LOG(WARNING) << "Unigram state " << unigram << " has no arcs.";
506  break;
507  }
508  if (aiter.Value().ilabel != 0) break;
509  unigram = aiter.Value().nextstate;
510  }
511 
512  // Each state's context is determined by the subtree it is under from the
513  // unigram state.
514  std::queue<std::pair<StateId, Label>> label_queue;
515  std::vector<bool> visited(num_states);
516  // Force an epsilon link to the start state.
517  label_queue.push(std::make_pair(fst.Start(), 0));
518  for (ArcIterator<Fst<A>> aiter(fst, unigram); !aiter.Done(); aiter.Next()) {
519  label_queue.push(
520  std::make_pair(aiter.Value().nextstate, aiter.Value().ilabel));
521  }
522  // investigate states in breadth first fashion to assign context words.
523  while (!label_queue.empty()) {
524  std::pair<StateId, Label> &now = label_queue.front();
525  if (!visited[now.first]) {
526  context[now.first] = now.second;
527  visited[now.first] = true;
528  for (ArcIterator<Fst<A>> aiter(fst, now.first); !aiter.Done();
529  aiter.Next()) {
530  const Arc &arc = aiter.Value();
531  if (arc.ilabel != 0) {
532  label_queue.push(std::make_pair(arc.nextstate, now.second));
533  }
534  }
535  }
536  label_queue.pop();
537  }
538  visited.clear();
539 
540  // The arc from the start state should be assigned an epsilon to put it
541  // in front of the all other labels (which makes Start state 1 after
542  // unigram which is state 0).
543  context[fst.Start()] = 0;
544 
545  // Build the tree of contexts fst by reversing the epsilon arcs from fst.
546  VectorFst<Arc> context_fst;
547  uint64 num_final = 0;
548  for (int i = 0; i < num_states; ++i) {
549  if (fst.Final(i) != Weight::Zero()) {
550  ++num_final;
551  }
552  context_fst.SetFinal(context_fst.AddState(), fst.Final(i));
553  }
554  context_fst.SetStart(unigram);
555  context_fst.SetInputSymbols(fst.InputSymbols());
556  context_fst.SetOutputSymbols(fst.OutputSymbols());
557  int64 num_context_arcs = 0;
558  int64 num_futures = 0;
559  for (StateIterator<Fst<A>> siter(fst); !siter.Done(); siter.Next()) {
560  const StateId &state = siter.Value();
561  num_futures += fst.NumArcs(state) - fst.NumInputEpsilons(state);
562  ArcIterator<Fst<A>> aiter(fst, state);
563  if (!aiter.Done()) {
564  const Arc &arc = aiter.Value();
565  // this arc goes from state to arc.nextstate, so create an arc from
566  // arc.nextstate to state to reverse it.
567  if (arc.ilabel == 0) {
568  context_fst.AddArc(arc.nextstate, Arc(context[state], context[state],
569  arc.weight, state));
570  num_context_arcs++;
571  }
572  }
573  }
574  if (num_context_arcs != context_fst.NumStates() - 1) {
575  FSTERROR() << "Number of contexts arcs != number of states - 1";
576  SetProperties(kError, kError);
577  return;
578  }
579  if (context_fst.NumStates() != num_states) {
580  FSTERROR() << "Number of contexts != number of states";
581  SetProperties(kError, kError);
582  return;
583  }
584  int64 context_props =
585  context_fst.Properties(kIDeterministic | kILabelSorted, true);
586  if (!(context_props & kIDeterministic)) {
587  FSTERROR() << "Input Fst is not structured properly";
588  SetProperties(kError, kError);
589  return;
590  }
591  if (!(context_props & kILabelSorted)) {
592  ArcSort(&context_fst, ILabelCompare<Arc>());
593  }
594 
595  uint64 b64;
596  Weight weight;
597  Label label = kNoLabel;
598  const size_t storage = Storage(num_states, num_futures, num_final);
599  MappedFile *data_region = MappedFile::Allocate(storage);
600  char *data = static_cast<char *>(data_region->mutable_data());
601  memset(data, 0, storage);
602  size_t offset = 0;
603  memcpy(data + offset, reinterpret_cast<char *>(&num_states),
604  sizeof(num_states));
605  offset += sizeof(num_states);
606  memcpy(data + offset, reinterpret_cast<char *>(&num_futures),
607  sizeof(num_futures));
608  offset += sizeof(num_futures);
609  memcpy(data + offset, reinterpret_cast<char *>(&num_final),
610  sizeof(num_final));
611  offset += sizeof(num_final);
612  uint64 *context_bits = reinterpret_cast<uint64 *>(data + offset);
613  offset += BitmapIndex::StorageSize(num_states * 2 + 1) * sizeof(b64);
614  uint64 *future_bits = reinterpret_cast<uint64 *>(data + offset);
615  offset +=
616  BitmapIndex::StorageSize(num_futures + num_states + 1) * sizeof(b64);
617  uint64 *final_bits = reinterpret_cast<uint64 *>(data + offset);
618  offset += BitmapIndex::StorageSize(num_states) * sizeof(b64);
619  Label *context_words = reinterpret_cast<Label *>(data + offset);
620  offset += (num_states + 1) * sizeof(label);
621  Label *future_words = reinterpret_cast<Label *>(data + offset);
622  offset += num_futures * sizeof(label);
623  offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
624  Weight *backoff = reinterpret_cast<Weight *>(data + offset);
625  offset += (num_states + 1) * sizeof(weight);
626  Weight *final_probs = reinterpret_cast<Weight *>(data + offset);
627  offset += num_final * sizeof(weight);
628  Weight *future_probs = reinterpret_cast<Weight *>(data + offset);
629  int64 context_arc = 0, future_arc = 0, context_bit = 0, future_bit = 0,
630  final_bit = 0;
631 
632  // pseudo-root bits
633  BitmapIndex::Set(context_bits, context_bit++);
634  ++context_bit;
635  context_words[context_arc] = label;
636  backoff[context_arc] = Weight::Zero();
637  context_arc++;
638 
639  ++future_bit;
640  if (order_out) {
641  order_out->clear();
642  order_out->resize(num_states);
643  }
644 
645  std::queue<StateId> context_q;
646  context_q.push(context_fst.Start());
647  StateId state_number = 0;
648  while (!context_q.empty()) {
649  const StateId &state = context_q.front();
650  if (order_out) {
651  (*order_out)[state] = state_number;
652  }
653 
654  const Weight final_weight = context_fst.Final(state);
655  if (final_weight != Weight::Zero()) {
656  BitmapIndex::Set(final_bits, state_number);
657  final_probs[final_bit] = final_weight;
658  ++final_bit;
659  }
660 
661  for (ArcIterator<VectorFst<A>> aiter(context_fst, state); !aiter.Done();
662  aiter.Next()) {
663  const Arc &arc = aiter.Value();
664  context_words[context_arc] = arc.ilabel;
665  backoff[context_arc] = arc.weight;
666  ++context_arc;
667  BitmapIndex::Set(context_bits, context_bit++);
668  context_q.push(arc.nextstate);
669  }
670  ++context_bit;
671 
672  for (ArcIterator<Fst<A>> aiter(fst, state); !aiter.Done(); aiter.Next()) {
673  const Arc &arc = aiter.Value();
674  if (arc.ilabel != 0) {
675  future_words[future_arc] = arc.ilabel;
676  future_probs[future_arc] = arc.weight;
677  ++future_arc;
678  BitmapIndex::Set(future_bits, future_bit++);
679  }
680  }
681  ++future_bit;
682  ++state_number;
683  context_q.pop();
684  }
685 
686  if ((state_number != num_states) || (context_bit != num_states * 2 + 1) ||
687  (context_arc != num_states) || (future_arc != num_futures) ||
688  (future_bit != num_futures + num_states + 1) ||
689  (final_bit != num_final)) {
690  FSTERROR() << "Structure problems detected during construction";
691  SetProperties(kError, kError);
692  return;
693  }
694 
695  Init(data, false, data_region);
696 }
697 
698 template <typename A>
699 inline void NGramFstImpl<A>::Init(const char *data, bool owned,
700  MappedFile *data_region) {
701  if (owned_) {
702  delete[] data_;
703  }
704  data_region_.reset(data_region);
705  owned_ = owned;
706  data_ = data;
707  size_t offset = 0;
708  num_states_ = *(reinterpret_cast<const uint64 *>(data_ + offset));
709  offset += sizeof(num_states_);
710  num_futures_ = *(reinterpret_cast<const uint64 *>(data_ + offset));
711  offset += sizeof(num_futures_);
712  num_final_ = *(reinterpret_cast<const uint64 *>(data_ + offset));
713  offset += sizeof(num_final_);
714  uint64 bits;
715  size_t context_bits = num_states_ * 2 + 1;
716  size_t future_bits = num_futures_ + num_states_ + 1;
717  context_ = reinterpret_cast<const uint64 *>(data_ + offset);
718  offset += BitmapIndex::StorageSize(context_bits) * sizeof(bits);
719  future_ = reinterpret_cast<const uint64 *>(data_ + offset);
720  offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits);
721  final_ = reinterpret_cast<const uint64 *>(data_ + offset);
722  offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits);
723  context_words_ = reinterpret_cast<const Label *>(data_ + offset);
724  offset += (num_states_ + 1) * sizeof(*context_words_);
725  future_words_ = reinterpret_cast<const Label *>(data_ + offset);
726  offset += num_futures_ * sizeof(*future_words_);
727  offset = (offset + sizeof(*backoff_) - 1) & ~(sizeof(*backoff_) - 1);
728  backoff_ = reinterpret_cast<const Weight *>(data_ + offset);
729  offset += (num_states_ + 1) * sizeof(*backoff_);
730  final_probs_ = reinterpret_cast<const Weight *>(data_ + offset);
731  offset += num_final_ * sizeof(*final_probs_);
732  future_probs_ = reinterpret_cast<const Weight *>(data_ + offset);
733 
734  context_index_.BuildIndex(context_, context_bits);
735  future_index_.BuildIndex(future_, future_bits);
736  final_index_.BuildIndex(final_, num_states_);
737 
738  select_root_ = context_index_.Select0s(0);
739  if (context_index_.Rank1(0) != 0 || select_root_.first != 1 ||
740  context_index_.Get(2) == false) {
741  FSTERROR() << "Malformed file";
742  SetProperties(kError, kError);
743  return;
744  }
745  root_children_ = context_words_ + context_index_.Rank1(2);
746  start_ = 1;
747 }
748 
749 template <typename A>
750 inline typename A::StateId NGramFstImpl<A>::Transition(
751  const std::vector<Label> &context, Label future) const {
752  const Label *children = root_children_;
753  size_t num_children = select_root_.second - 2;
754  const Label *loc =
755  std::lower_bound(children, children + num_children, future);
756  if (loc == children + num_children || *loc != future) {
757  return context_index_.Rank1(0);
758  }
759  size_t node = 2 + loc - children;
760  size_t node_rank = context_index_.Rank1(node);
761  std::pair<size_t, size_t> zeros =
762  (node_rank == 0) ? select_root_ : context_index_.Select0s(node_rank);
763  size_t first_child = zeros.first + 1;
764  if (context_index_.Get(first_child) == false) {
765  return context_index_.Rank1(node);
766  }
767  size_t last_child = zeros.second - 1;
768  for (int word = context.size() - 1; word >= 0; --word) {
769  children = context_words_ + context_index_.Rank1(first_child);
770  loc = std::lower_bound(children, children + last_child - first_child + 1,
771  context[word]);
772  if (loc == children + last_child - first_child + 1 ||
773  *loc != context[word]) {
774  break;
775  }
776  node = first_child + loc - children;
777  node_rank = context_index_.Rank1(node);
778  zeros =
779  (node_rank == 0) ? select_root_ : context_index_.Select0s(node_rank);
780  first_child = zeros.first + 1;
781  if (context_index_.Get(first_child) == false) break;
782  last_child = zeros.second - 1;
783  }
784  return context_index_.Rank1(node);
785 }
786 
787 } // namespace internal
788 
789 /*****************************************************************************/
790 template <class A>
791 class NGramFstMatcher : public MatcherBase<A> {
792  public:
793  typedef A Arc;
794  typedef typename A::Label Label;
795  typedef typename A::StateId StateId;
796  typedef typename A::Weight Weight;
797 
798  // This makes a copy of the FST.
800  : owned_fst_(fst.Copy()),
801  fst_(*owned_fst_),
802  inst_(fst_.inst_),
803  match_type_(match_type),
804  current_loop_(false),
805  loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
806  if (match_type_ == MATCH_OUTPUT) {
807  std::swap(loop_.ilabel, loop_.olabel);
808  }
809  }
810 
811  // This doesn't copy the FST.
813  : fst_(*fst),
814  inst_(fst_.inst_),
815  match_type_(match_type),
816  current_loop_(false),
817  loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
818  if (match_type_ == MATCH_OUTPUT) {
819  std::swap(loop_.ilabel, loop_.olabel);
820  }
821  }
822 
823  // This makes a copy of the FST.
824  NGramFstMatcher(const NGramFstMatcher<A> &matcher, bool safe = false)
825  : owned_fst_(matcher.fst_.Copy(safe)),
826  fst_(*owned_fst_),
827  inst_(matcher.inst_),
828  match_type_(matcher.match_type_),
829  current_loop_(false),
830  loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
831  if (match_type_ == MATCH_OUTPUT) {
832  std::swap(loop_.ilabel, loop_.olabel);
833  }
834  }
835 
836  NGramFstMatcher<A> *Copy(bool safe = false) const override {
837  return new NGramFstMatcher<A>(*this, safe);
838  }
839 
840  MatchType Type(bool test) const override { return match_type_; }
841 
842  const Fst<A> &GetFst() const override { return fst_; }
843 
844  uint64 Properties(uint64 props) const override { return props; }
845 
846  void SetState(StateId s) final {
847  fst_.GetImpl()->SetInstFuture(s, &inst_);
848  current_loop_ = false;
849  }
850 
851  bool Find(Label label) final {
852  const Label nolabel = kNoLabel;
853  done_ = true;
854  if (label == 0 || label == nolabel) {
855  if (label == 0) {
856  current_loop_ = true;
857  loop_.nextstate = inst_.state_;
858  }
859  // The unigram state has no epsilon arc.
860  if (inst_.state_ != 0) {
861  arc_.ilabel = arc_.olabel = 0;
862  fst_.GetImpl()->SetInstNode(&inst_);
863  arc_.nextstate = fst_.GetImpl()->context_index_.Rank1(
864  fst_.GetImpl()->context_index_.Select1(
865  fst_.GetImpl()->context_index_.Rank0(inst_.node_) - 1));
866  arc_.weight = fst_.GetImpl()->backoff_[inst_.state_];
867  done_ = false;
868  }
869  } else {
870  current_loop_ = false;
871  const Label *start = fst_.GetImpl()->future_words_ + inst_.offset_;
872  const Label *end = start + inst_.num_futures_;
873  const Label *search = std::lower_bound(start, end, label);
874  if (search != end && *search == label) {
875  size_t state = search - start;
876  arc_.ilabel = arc_.olabel = label;
877  arc_.weight = fst_.GetImpl()->future_probs_[inst_.offset_ + state];
878  fst_.GetImpl()->SetInstContext(&inst_);
879  arc_.nextstate = fst_.GetImpl()->Transition(inst_.context_, label);
880  done_ = false;
881  }
882  }
883  return !Done();
884  }
885 
886  bool Done() const final { return !current_loop_ && done_; }
887 
888  const Arc &Value() const final { return (current_loop_) ? loop_ : arc_; }
889 
890  void Next() final {
891  if (current_loop_) {
892  current_loop_ = false;
893  } else {
894  done_ = true;
895  }
896  }
897 
898  ssize_t Priority(StateId s) final { return fst_.NumArcs(s); }
899 
900  private:
901  std::unique_ptr<NGramFst<A>> owned_fst_;
902  const NGramFst<A> &fst_;
903  NGramFstInst<A> inst_;
904  MatchType match_type_; // Supplied by caller
905  bool done_;
906  Arc arc_;
907  bool current_loop_; // Current arc is the implicit loop
908  Arc loop_;
909 };
910 
911 /*****************************************************************************/
912 // Specialization for NGramFst; see generic version in fst.h
913 // for sample usage (but use the ProdLmFst type!). This version
914 // should inline.
915 template <class A>
916 class StateIterator<NGramFst<A>> : public StateIteratorBase<A> {
917  public:
918  typedef typename A::StateId StateId;
919 
920  explicit StateIterator(const NGramFst<A> &fst)
921  : s_(0), num_states_(fst.NumStates()) {}
922 
923  bool Done() const final { return s_ >= num_states_; }
924 
925  StateId Value() const final { return s_; }
926 
927  void Next() final { ++s_; }
928 
929  void Reset() final { s_ = 0; }
930 
931  private:
932  StateId s_;
933  StateId num_states_;
934 };
935 
936 /*****************************************************************************/
937 template <class A>
938 class ArcIterator<NGramFst<A>> : public ArcIteratorBase<A> {
939  public:
940  typedef A Arc;
941  typedef typename A::Label Label;
942  typedef typename A::StateId StateId;
943  typedef typename A::Weight Weight;
944 
945  ArcIterator(const NGramFst<A> &fst, StateId state)
946  : lazy_(~0), impl_(fst.GetImpl()), i_(0), flags_(kArcValueFlags) {
947  inst_ = fst.inst_;
948  impl_->SetInstFuture(state, &inst_);
949  impl_->SetInstNode(&inst_);
950  }
951 
952  bool Done() const final {
953  return i_ >=
954  ((inst_.node_ == 0) ? inst_.num_futures_ : inst_.num_futures_ + 1);
955  }
956 
957  const Arc &Value() const final {
958  bool eps = (inst_.node_ != 0 && i_ == 0);
959  StateId state = (inst_.node_ == 0) ? i_ : i_ - 1;
960  if (flags_ & lazy_ & (kArcILabelValue | kArcOLabelValue)) {
961  arc_.ilabel = arc_.olabel =
962  eps ? 0 : impl_->future_words_[inst_.offset_ + state];
963  lazy_ &= ~(kArcILabelValue | kArcOLabelValue);
964  }
965  if (flags_ & lazy_ & kArcNextStateValue) {
966  if (eps) {
967  arc_.nextstate =
968  impl_->context_index_.Rank1(impl_->context_index_.Select1(
969  impl_->context_index_.Rank0(inst_.node_) - 1));
970  } else {
971  if (lazy_ & kArcNextStateValue) {
972  impl_->SetInstContext(&inst_); // first time only.
973  }
974  arc_.nextstate = impl_->Transition(
975  inst_.context_, impl_->future_words_[inst_.offset_ + state]);
976  }
977  lazy_ &= ~kArcNextStateValue;
978  }
979  if (flags_ & lazy_ & kArcWeightValue) {
980  arc_.weight = eps ? impl_->backoff_[inst_.state_]
981  : impl_->future_probs_[inst_.offset_ + state];
982  lazy_ &= ~kArcWeightValue;
983  }
984  return arc_;
985  }
986 
987  void Next() final {
988  ++i_;
989  lazy_ = ~0;
990  }
991 
992  size_t Position() const final { return i_; }
993 
994  void Reset() final {
995  i_ = 0;
996  lazy_ = ~0;
997  }
998 
999  void Seek(size_t a) final {
1000  if (i_ != a) {
1001  i_ = a;
1002  lazy_ = ~0;
1003  }
1004  }
1005 
1006  uint8 Flags() const final { return flags_; }
1007 
1008  void SetFlags(uint8 flags, uint8 mask) final {
1009  flags_ &= ~mask;
1010  flags_ |= (flags & kArcValueFlags);
1011  }
1012 
1013  private:
1014  mutable Arc arc_;
1015  mutable uint8 lazy_;
1016  const internal::NGramFstImpl<A> *impl_; // Borrowed reference.
1017  mutable NGramFstInst<A> inst_;
1018 
1019  size_t i_;
1020  uint8 flags_;
1021 };
1022 
1023 } // namespace fst
1024 #endif // FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
internal::NGramFstImpl< A > Impl
Definition: ngram-fst.h:334
void GetStates(const std::vector< Label > &context, std::vector< StateId > *state) const
Definition: ngram-fst.h:366
void AddArc(StateId s, const Arc &arc) override
Definition: mutable-fst.h:312
void * mutable_data() const
Definition: mapped-file.h:34
NGramFst< A > * Copy(bool safe=false) const override
Definition: ngram-fst.h:375
void SetStart(StateId s) override
Definition: mutable-fst.h:284
size_t num_futures_
Definition: ngram-fst.h:42
constexpr uint64 kInitialAcyclic
Definition: properties.h:99
void Next() final
Definition: ngram-fst.h:890
constexpr int kNoLabel
Definition: fst.h:179
size_t NumInputEpsilons(StateId state) const
Definition: ngram-fst.h:156
uint64_t uint64
Definition: types.h:32
const Arc & Value() const final
Definition: ngram-fst.h:888
void SetFinal(StateId s, Weight weight=Weight::One()) override
Definition: mutable-fst.h:289
virtual size_t NumArcs(StateId) const =0
void SetInstContext(NGramFstInst< A > *inst) const
Definition: ngram-fst.h:209
uint64 Properties(uint64 props) const override
Definition: ngram-fst.h:844
MatchType
Definition: fst.h:171
StateId state_
Definition: ngram-fst.h:41
constexpr uint64 kILabelSorted
Definition: properties.h:77
size_t Position() const final
Definition: ngram-fst.h:992
size_t StorageSize() const
Definition: ngram-fst.h:417
static NGramFstImpl< A > * Read(std::istream &strm, const FstReadOptions &opts)
Definition: ngram-fst.h:99
#define LOG(type)
Definition: log.h:46
StateId Start() const override
Definition: fst.h:883
virtual Weight Final(StateId) const =0
SetType
Definition: set-weight.h:38
StateId node_state_
Definition: ngram-fst.h:45
const Arc & Value() const final
Definition: ngram-fst.h:957
constexpr int kNoStateId
Definition: fst.h:180
NGramFstMatcher< A > * Copy(bool safe=false) const override
Definition: ngram-fst.h:836
constexpr uint64 kExpanded
Definition: properties.h:29
const Arc & Value() const
Definition: fst.h:512
NGramFstMatcher(const NGramFstMatcher< A > &matcher, bool safe=false)
Definition: ngram-fst.h:824
static MappedFile * Allocate(size_t size, size_t align=kArchAlignment)
Definition: mapped-file.cc:103
virtual uint64 Properties(uint64 mask, bool test) const =0
void InitArcIterator(StateId s, ArcIteratorData< A > *data) const override
Definition: ngram-fst.h:463
NGramFst(const Fst< A > &fst, std::vector< StateId > *order_out)
Definition: ngram-fst.h:339
const Fst< A > & GetFst() const override
Definition: ngram-fst.h:842
virtual size_t NumInputEpsilons(StateId) const =0
int64_t int64
Definition: types.h:27
constexpr uint64 kEpsilons
Definition: properties.h:62
#define FSTERROR()
Definition: util.h:36
StateId AddState() override
Definition: mutable-fst.h:302
NGramFst(const Fst< A > &dst)
Definition: ngram-fst.h:336
size_t NumArcs(StateId state, NGramFstInst< A > *inst=nullptr) const
Definition: ngram-fst.h:146
const std::vector< Label > & GetContext(StateId s, NGramFstInst< A > *inst) const
Definition: ngram-fst.h:230
void SetOutputSymbols(const SymbolTable *osyms) override
Definition: mutable-fst.h:382
static size_t StorageSize(size_t num_bits)
Definition: bitmap-index.h:36
NGramFst(const NGramFst< A > &fst, bool safe=false)
Definition: ngram-fst.h:344
StateIteratorBase< Arc > * base
Definition: fst.h:359
A::Weight Weight
Definition: ngram-fst.h:40
static void Set(uint64 *bits, size_t index)
Definition: bitmap-index.h:49
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:94
const char * GetData(size_t *data_size) const
Definition: ngram-fst.h:356
A::StateId StateId
Definition: ngram-fst.h:39
void ArcSort(MutableFst< Arc > *fst, Compare comp)
Definition: arcsort.h:89
void SetInstNode(NGramFstInst< A > *inst) const
Definition: ngram-fst.h:202
constexpr uint64 kNotString
Definition: properties.h:122
static NGramFst< A > * Read(std::istream &strm, const FstReadOptions &opts)
Definition: ngram-fst.h:379
uint8_t uint8
Definition: types.h:29
void SetInputSymbols(const SymbolTable *isyms) override
Definition: mutable-fst.h:377
std::vector< Label > context_
Definition: ngram-fst.h:46
MatcherBase< A > * InitMatcher(MatchType match_type) const override
Definition: ngram-fst.h:413
bool Write(const std::string &source) const override
Definition: ngram-fst.h:402
constexpr uint64 kOLabelSorted
Definition: properties.h:82
constexpr uint64 kIEpsilons
Definition: properties.h:67
StateId Value() const final
Definition: ngram-fst.h:925
void SetNumStates(int64 numstates)
Definition: fst.h:148
StateId NumStates() const override
Definition: expanded-fst.h:120
NGramFstMatcher(const NGramFst< A > &fst, MatchType match_type)
Definition: ngram-fst.h:799
StateId nstates
Definition: fst.h:361
constexpr uint64 kIDeterministic
Definition: properties.h:52
const char * GetData(size_t *data_size) const
Definition: ngram-fst.h:223
MatchType Type(bool test) const override
Definition: ngram-fst.h:840
size_t NumArcs(StateId s) const override
Definition: ngram-fst.h:371
void SetInstFuture(StateId state, NGramFstInst< A > *inst) const
Definition: ngram-fst.h:193
virtual StateId Start() const =0
bool Done() const
Definition: fst.h:508
A::StateId StateId
Definition: ngram-fst.h:795
constexpr uint64 kAcceptor
Definition: properties.h:47
NGramFst(const char *data, bool owned)
Definition: ngram-fst.h:350
void SetState(StateId s) final
Definition: ngram-fst.h:846
NGramFstImpl(const Fst< A > &fst)
Definition: ngram-fst.h:86
const std::vector< Label > GetContext(StateId s) const
Definition: ngram-fst.h:360
void Seek(size_t a) final
Definition: ngram-fst.h:999
bool WriteFile(const std::string &source) const
Definition: fst.h:309
StateId NumStates() const
Definition: ngram-fst.h:166
A::Label Label
Definition: ngram-fst.h:332
size_t StorageSize() const
Definition: ngram-fst.h:236
constexpr uint64 kCoAccessible
Definition: properties.h:112
void SetStart(int64 start)
Definition: fst.h:146
bool Write(std::ostream &strm, const FstWriteOptions &opts) const override
Definition: ngram-fst.h:398
void SetFlags(uint8 flags, uint8 mask) final
Definition: ngram-fst.h:1008
StateId context_state_
Definition: ngram-fst.h:47
virtual const SymbolTable * InputSymbols() const =0
Arc::StateId CountStates(const Fst< Arc > &fst)
Definition: expanded-fst.h:155
ssize_t Priority(StateId s) final
Definition: ngram-fst.h:898
ArcIterator(const NGramFst< A > &fst, StateId state)
Definition: ngram-fst.h:945
static size_t Storage(uint64 num_states, uint64 num_futures, uint64 num_final)
Definition: ngram-fst.h:173
Weight Final(StateId s) const override
Definition: fst.h:885
size_t NumOutputEpsilons(StateId state) const
Definition: ngram-fst.h:162
NGramFstImpl(const NGramFstImpl &other)
Definition: ngram-fst.h:88
bool Done() const final
Definition: ngram-fst.h:886
constexpr uint64 kWeighted
Definition: properties.h:87
constexpr uint64 kError
Definition: properties.h:35
StateId Start() const
Definition: ngram-fst.h:136
NGramFstMatcher(const NGramFst< A > *fst, MatchType match_type)
Definition: ngram-fst.h:812
A::Label Label
Definition: ngram-fst.h:38
constexpr uint64 kNotTopSorted
Definition: properties.h:104
ArcIteratorBase< Arc > * base
Definition: fst.h:470
A::Weight Weight
Definition: ngram-fst.h:333
void InitStateIterator(StateIteratorData< A > *data) const override
Definition: ngram-fst.h:406
Weight Final(StateId state) const
Definition: ngram-fst.h:138
void InitStateIterator(StateIteratorData< A > *data) const
Definition: ngram-fst.h:168
bool Write(std::ostream &strm, const FstWriteOptions &opts) const
Definition: ngram-fst.h:126
A::StateId StateId
Definition: ngram-fst.h:331
uint64 Properties(uint64 mask, bool test) const override
Definition: fst.h:897
bool Find(Label label) final
Definition: ngram-fst.h:851
constexpr uint64 kODeterministic
Definition: properties.h:57
static bool HasRequiredStructure(const Fst< A > &fst)
Definition: ngram-fst.h:425
StateIterator(const NGramFst< A > &fst)
Definition: ngram-fst.h:920
constexpr uint64 kOEpsilons
Definition: properties.h:72
constexpr uint64 kAccessible
Definition: properties.h:107
static NGramFst< A > * Read(const std::string &source)
Definition: ngram-fst.h:384
static bool HasRequiredProps(const Fst< A > &fst)
Definition: ngram-fst.h:419
constexpr uint64 kCyclic
Definition: properties.h:92
virtual const SymbolTable * OutputSymbols() const =0
void Next()
Definition: fst.h:516