FST  openfst-1.7.3
OpenFst Library
ngram-fst.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // NgramFst implements an n-gram language model based upon the LOUDS data
5 // structure. Please refer to "Unary Data Structures for Language Models"
6 // http://research.google.com/pubs/archive/37218.pdf
7 
8 #ifndef FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
9 #define FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
10 
11 #include <algorithm>
12 #include <cstddef>
13 #include <cstring>
14 #include <iostream>
15 #include <string>
16 #include <utility>
17 #include <vector>
18 
19 #include <fst/compat.h>
20 #include <fst/log.h>
22 #include <fstream>
23 #include <fst/fstlib.h>
24 #include <fst/mapped-file.h>
25 
26 namespace fst {
27 template <class A>
28 class NGramFst;
29 
30 template <class A>
32 
33 // Instance data containing mutable state for bookkeeping repeated access to
34 // the same state.
35 template <class A>
36 struct NGramFstInst {
37  typedef typename A::Label Label;
38  typedef typename A::StateId StateId;
39  typedef typename A::Weight Weight;
40  StateId state_;
41  size_t num_futures_;
42  size_t offset_;
43  size_t node_;
44  StateId node_state_;
45  std::vector<Label> context_;
46  StateId context_state_;
48  : state_(kNoStateId),
49  node_state_(kNoStateId),
50  context_state_(kNoStateId) {}
51 };
52 
53 namespace internal {
54 
55 // Implementation class for LOUDS based NgramFst interface.
56 template <class A>
57 class NGramFstImpl : public FstImpl<A> {
60  using FstImpl<A>::SetType;
62 
63  friend class ArcIterator<NGramFst<A>>;
64  friend class NGramFstMatcher<A>;
65 
66  public:
70 
71  typedef A Arc;
72  typedef typename A::Label Label;
73  typedef typename A::StateId StateId;
74  typedef typename A::Weight Weight;
75 
77  SetType("ngram");
78  SetInputSymbols(nullptr);
79  SetOutputSymbols(nullptr);
80  SetProperties(kStaticProperties);
81  }
82 
83  NGramFstImpl(const Fst<A> &fst, std::vector<StateId> *order_out);
84 
85  explicit NGramFstImpl(const Fst<A> &fst) : NGramFstImpl(fst, nullptr) {}
86 
87  NGramFstImpl(const NGramFstImpl &other) {
88  FSTERROR() << "Copying NGramFst Impls is not supported, use safe = false.";
89  SetProperties(kError, kError);
90  }
91 
92  ~NGramFstImpl() override {
93  if (owned_) {
94  delete[] data_;
95  }
96  }
97 
98  static NGramFstImpl<A> *Read(std::istream &strm, // NOLINT
99  const FstReadOptions &opts) {
100  auto impl = std::make_unique<NGramFstImpl<A>>();
101  FstHeader hdr;
102  if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return 0;
103  uint64 num_states, num_futures, num_final;
104  const size_t offset =
105  sizeof(num_states) + sizeof(num_futures) + sizeof(num_final);
106  // Peek at num_states and num_futures to see how much more needs to be read.
107  strm.read(reinterpret_cast<char *>(&num_states), sizeof(num_states));
108  strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures));
109  strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final));
110  size_t size = Storage(num_states, num_futures, num_final);
111  MappedFile *data_region = MappedFile::Allocate(size);
112  char *data = reinterpret_cast<char *>(data_region->mutable_data());
113  // Copy num_states, num_futures and num_final back into data.
114  memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states));
115  memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures),
116  sizeof(num_futures));
117  memcpy(data + sizeof(num_states) + sizeof(num_futures),
118  reinterpret_cast<char *>(&num_final), sizeof(num_final));
119  strm.read(data + offset, size - offset);
120  if (strm.fail()) return nullptr;
121  impl->Init(data, false, data_region);
122  return impl.release();
123  }
124 
125  bool Write(std::ostream &strm, // NOLINT
126  const FstWriteOptions &opts) const {
127  FstHeader hdr;
128  hdr.SetStart(Start());
129  hdr.SetNumStates(num_states_);
130  WriteHeader(strm, opts, kFileVersion, &hdr);
131  strm.write(data_, StorageSize());
132  return !strm.fail();
133  }
134 
135  StateId Start() const { return start_; }
136 
137  Weight Final(StateId state) const {
138  if (final_index_.Get(state)) {
139  return final_probs_[final_index_.Rank1(state)];
140  } else {
141  return Weight::Zero();
142  }
143  }
144 
145  size_t NumArcs(StateId state, NGramFstInst<A> *inst = nullptr) const {
146  if (inst == nullptr) {
147  const std::pair<size_t, size_t> zeros =
148  (state == 0) ? select_root_ : future_index_.Select0s(state);
149  return zeros.second - zeros.first - 1;
150  }
151  SetInstFuture(state, inst);
152  return inst->num_futures_ + ((state == 0) ? 0 : 1);
153  }
154 
155  size_t NumInputEpsilons(StateId state) const {
156  // State 0 has no parent, thus no backoff.
157  if (state == 0) return 0;
158  return 1;
159  }
160 
161  size_t NumOutputEpsilons(StateId state) const {
162  return NumInputEpsilons(state);
163  }
164 
165  StateId NumStates() const { return num_states_; }
166 
168  data->base = 0;
169  data->nstates = num_states_;
170  }
171 
172  static size_t Storage(uint64 num_states, uint64 num_futures,
173  uint64 num_final) {
174  uint64 b64;
175  Weight weight;
176  Label label;
177  size_t offset =
178  sizeof(num_states) + sizeof(num_futures) + sizeof(num_final);
179  offset +=
180  sizeof(b64) * (BitmapIndex::StorageSize(num_states * 2 + 1) +
181  BitmapIndex::StorageSize(num_futures + num_states + 1) +
182  BitmapIndex::StorageSize(num_states));
183  offset += (num_states + 1) * sizeof(label) + num_futures * sizeof(label);
184  // Pad for alignemnt, see
185  // http://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
186  offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
187  offset += (num_states + 1) * sizeof(weight) + num_final * sizeof(weight) +
188  (num_futures + 1) * sizeof(weight);
189  return offset;
190  }
191 
192  void SetInstFuture(StateId state, NGramFstInst<A> *inst) const {
193  if (inst->state_ != state) {
194  inst->state_ = state;
195  const std::pair<size_t, size_t> zeros = future_index_.Select0s(state);
196  inst->num_futures_ = zeros.second - zeros.first - 1;
197  inst->offset_ = future_index_.Rank1(zeros.first + 1);
198  }
199  }
200 
201  void SetInstNode(NGramFstInst<A> *inst) const {
202  if (inst->node_state_ != inst->state_) {
203  inst->node_state_ = inst->state_;
204  inst->node_ = context_index_.Select1(inst->state_);
205  }
206  }
207 
208  void SetInstContext(NGramFstInst<A> *inst) const {
209  SetInstNode(inst);
210  if (inst->context_state_ != inst->state_) {
211  inst->context_state_ = inst->state_;
212  inst->context_.clear();
213  size_t node = inst->node_;
214  while (node != 0) {
215  inst->context_.push_back(context_words_[context_index_.Rank1(node)]);
216  node = context_index_.Select1(context_index_.Rank0(node) - 1);
217  }
218  }
219  }
220 
221  // Access to the underlying representation
222  const char *GetData(size_t *data_size) const {
223  *data_size = StorageSize();
224  return data_;
225  }
226 
227  void Init(const char *data, bool owned, MappedFile *file = nullptr);
228 
229  const std::vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const {
230  SetInstFuture(s, inst);
231  SetInstContext(inst);
232  return inst->context_;
233  }
234 
235  size_t StorageSize() const {
236  return Storage(num_states_, num_futures_, num_final_);
237  }
238 
239  void GetStates(const std::vector<Label> &context,
240  std::vector<StateId> *states) const;
241 
242  private:
243  StateId Transition(const std::vector<Label> &context, Label future) const;
244 
245  // Properties always true for this Fst class.
246  static const uint64 kStaticProperties =
251  // Current file format version.
252  static const int kFileVersion = 4;
253  // Minimum file format version supported.
254  static const int kMinFileVersion = 4;
255 
256  std::unique_ptr<MappedFile> data_region_;
257  const char *data_ = nullptr;
258  bool owned_ = false; // True if we own data_
259  StateId start_ = fst::kNoStateId;
260  uint64 num_states_ = 0;
261  uint64 num_futures_ = 0;
262  uint64 num_final_ = 0;
263  std::pair<size_t, size_t> select_root_;
264  const Label *root_children_ = nullptr;
265  // borrowed references
266  const uint64 *context_ = nullptr;
267  const uint64 *future_ = nullptr;
268  const uint64 *final_ = nullptr;
269  const Label *context_words_ = nullptr;
270  const Label *future_words_ = nullptr;
271  const Weight *backoff_ = nullptr;
272  const Weight *final_probs_ = nullptr;
273  const Weight *future_probs_ = nullptr;
274  BitmapIndex context_index_;
275  BitmapIndex future_index_;
276  BitmapIndex final_index_;
277 };
278 
279 template <typename A>
281  const std::vector<Label> &context,
282  std::vector<typename A::StateId> *states) const {
283  states->clear();
284  states->push_back(0);
285  typename std::vector<Label>::const_reverse_iterator cit = context.rbegin();
286  const Label *children = root_children_;
287  size_t num_children = select_root_.second - 2;
288  const Label *loc = std::lower_bound(children, children + num_children, *cit);
289  if (loc == children + num_children || *loc != *cit) return;
290  size_t node = 2 + loc - children;
291  states->push_back(context_index_.Rank1(node));
292  if (context.size() == 1) return;
293  size_t node_rank = context_index_.Rank1(node);
294  std::pair<size_t, size_t> zeros =
295  node_rank == 0 ? select_root_ : context_index_.Select0s(node_rank);
296  size_t first_child = zeros.first + 1;
297  ++cit;
298  if (context_index_.Get(first_child) != false) {
299  size_t last_child = zeros.second - 1;
300  while (cit != context.rend()) {
301  children = context_words_ + context_index_.Rank1(first_child);
302  loc = std::lower_bound(children, children + last_child - first_child + 1,
303  *cit);
304  if (loc == children + last_child - first_child + 1 || *loc != *cit) {
305  break;
306  }
307  ++cit;
308  node = first_child + loc - children;
309  states->push_back(context_index_.Rank1(node));
310  node_rank = context_index_.Rank1(node);
311  zeros =
312  node_rank == 0 ? select_root_ : context_index_.Select0s(node_rank);
313  first_child = zeros.first + 1;
314  if (context_index_.Get(first_child) == false) break;
315  last_child = zeros.second - 1;
316  }
317  }
318 }
319 
320 } // namespace internal
321 
322 /*****************************************************************************/
323 template <class A>
324 class NGramFst : public ImplToExpandedFst<internal::NGramFstImpl<A>> {
325  friend class ArcIterator<NGramFst<A>>;
326  friend class NGramFstMatcher<A>;
327 
328  public:
329  typedef A Arc;
330  typedef typename A::StateId StateId;
331  typedef typename A::Label Label;
332  typedef typename A::Weight Weight;
334 
335  explicit NGramFst(const Fst<A> &dst)
336  : ImplToExpandedFst<Impl>(std::make_shared<Impl>(dst, nullptr)) {}
337 
338  NGramFst(const Fst<A> &fst, std::vector<StateId> *order_out)
339  : ImplToExpandedFst<Impl>(std::make_shared<Impl>(fst, order_out)) {}
340 
341  // Because the NGramFstImpl is a const stateless data structure, there
342  // is never a need to do anything beside copy the reference.
343  NGramFst(const NGramFst<A> &fst, bool safe = false)
344  : ImplToExpandedFst<Impl>(fst, false) {}
345 
346  NGramFst() : ImplToExpandedFst<Impl>(std::make_shared<Impl>()) {}
347 
348  // Non-standard constructor to initialize NGramFst directly from data.
349  NGramFst(const char *data, bool owned)
350  : ImplToExpandedFst<Impl>(std::make_shared<Impl>()) {
351  GetMutableImpl()->Init(data, owned, nullptr);
352  }
353 
354  // Get method that gets the data associated with Init().
355  const char *GetData(size_t *data_size) const {
356  return GetImpl()->GetData(data_size);
357  }
358 
359  const std::vector<Label> GetContext(StateId s) const {
360  return GetImpl()->GetContext(s, &inst_);
361  }
362 
363  // Consumes as much as possible of context from right to left, returns the
364  // the states corresponding to the increasingly conditioned input sequence.
365  void GetStates(const std::vector<Label> &context,
366  std::vector<StateId> *state) const {
367  return GetImpl()->GetStates(context, state);
368  }
369 
370  size_t NumArcs(StateId s) const override {
371  return GetImpl()->NumArcs(s, &inst_);
372  }
373 
374  NGramFst<A> *Copy(bool safe = false) const override {
375  return new NGramFst(*this, safe);
376  }
377 
378  static NGramFst<A> *Read(std::istream &strm, const FstReadOptions &opts) {
379  Impl *impl = Impl::Read(strm, opts);
380  return impl ? new NGramFst<A>(std::shared_ptr<Impl>(impl)) : nullptr;
381  }
382 
383  static NGramFst<A> *Read(const std::string &filename) {
384  if (!filename.empty()) {
385  std::ifstream strm(filename,
386  std::ios_base::in | std::ios_base::binary);
387  if (!strm.good()) {
388  LOG(ERROR) << "NGramFst::Read: Can't open file: " << filename;
389  return nullptr;
390  }
391  return Read(strm, FstReadOptions(filename));
392  } else {
393  return Read(std::cin, FstReadOptions("standard input"));
394  }
395  }
396 
397  bool Write(std::ostream &strm, const FstWriteOptions &opts) const override {
398  return GetImpl()->Write(strm, opts);
399  }
400 
401  bool Write(const std::string &filename) const override {
402  return Fst<A>::WriteFile(filename);
403  }
404 
405  inline void InitStateIterator(StateIteratorData<A> *data) const override {
406  GetImpl()->InitStateIterator(data);
407  }
408 
409  inline void InitArcIterator(StateId s,
410  ArcIteratorData<A> *data) const override;
411 
412  MatcherBase<A> *InitMatcher(MatchType match_type) const override {
413  return new NGramFstMatcher<A>(this, match_type);
414  }
415 
416  size_t StorageSize() const { return GetImpl()->StorageSize(); }
417 
418  static bool HasRequiredProps(const Fst<A> &fst) {
419  static const auto props =
421  return fst.Properties(props, true) == props;
422  }
423 
424  static bool HasRequiredStructure(const Fst<A> &fst) {
425  if (!HasRequiredProps(fst)) {
426  return false;
427  }
428  typename A::StateId unigram = fst.Start();
429  while (true) { // Follows epsilon arc chain to find unigram state.
430  if (unigram == fst::kNoStateId) return false; // No unigram state.
431  typename fst::ArcIterator<Fst<A>> aiter(fst, unigram);
432  if (aiter.Done() || aiter.Value().ilabel != 0) break;
433  unigram = aiter.Value().nextstate;
434  aiter.Next();
435  }
436  // Other requirement: all states other than unigram an epsilon arc.
437  for (fst::StateIterator<Fst<A>> siter(fst); !siter.Done();
438  siter.Next()) {
439  const typename A::StateId &state = siter.Value();
440  fst::ArcIterator<Fst<A>> aiter(fst, state);
441  if (state != unigram) {
442  if (aiter.Done()) return false;
443  if (aiter.Value().ilabel != 0) return false;
444  aiter.Next();
445  if (!aiter.Done() && aiter.Value().ilabel == 0) return false;
446  }
447  }
448  return true;
449  }
450 
451  private:
453  using ImplToExpandedFst<Impl, ExpandedFst<A>>::GetMutableImpl;
454 
455  explicit NGramFst(std::shared_ptr<Impl> impl)
456  : ImplToExpandedFst<Impl>(impl) {}
457 
458  mutable NGramFstInst<A> inst_;
459 };
460 
461 template <class A>
463  ArcIteratorData<A> *data) const {
464  GetImpl()->SetInstFuture(s, &inst_);
465  GetImpl()->SetInstNode(&inst_);
466  data->base = new ArcIterator<NGramFst<A>>(*this, s);
467 }
468 
469 namespace internal {
470 
471 template <typename A>
473  std::vector<StateId> *order_out) {
474  typedef A Arc;
475  typedef typename Arc::Label Label;
476  typedef typename Arc::Weight Weight;
477  typedef typename Arc::StateId StateId;
478  SetType("ngram");
479  SetInputSymbols(fst.InputSymbols());
480  SetOutputSymbols(fst.OutputSymbols());
481  SetProperties(kStaticProperties);
482 
483  // Check basic requirements for an OpenGrm language model Fst.
484  if (!NGramFst<A>::HasRequiredProps(fst)) {
485  FSTERROR() << "NGramFst only accepts OpenGrm language models as input";
486  SetProperties(kError, kError);
487  return;
488  }
489 
490  int64 num_states = CountStates(fst);
491  Label *context = new Label[num_states];
492 
493  // Find the unigram state by starting from the start state, following
494  // epsilons.
495  StateId unigram = fst.Start();
496  while (1) {
497  if (unigram == kNoStateId) {
498  FSTERROR() << "Could not identify unigram state";
499  SetProperties(kError, kError);
500  return;
501  }
502  ArcIterator<Fst<A>> aiter(fst, unigram);
503  if (aiter.Done()) {
504  LOG(WARNING) << "Unigram state " << unigram << " has no arcs.";
505  break;
506  }
507  if (aiter.Value().ilabel != 0) break;
508  unigram = aiter.Value().nextstate;
509  }
510 
511  // Each state's context is determined by the subtree it is under from the
512  // unigram state.
513  std::queue<std::pair<StateId, Label>> label_queue;
514  std::vector<bool> visited(num_states);
515  // Force an epsilon link to the start state.
516  label_queue.push(std::make_pair(fst.Start(), 0));
517  for (ArcIterator<Fst<A>> aiter(fst, unigram); !aiter.Done(); aiter.Next()) {
518  label_queue.push(
519  std::make_pair(aiter.Value().nextstate, aiter.Value().ilabel));
520  }
521  // investigate states in breadth first fashion to assign context words.
522  while (!label_queue.empty()) {
523  std::pair<StateId, Label> &now = label_queue.front();
524  if (!visited[now.first]) {
525  context[now.first] = now.second;
526  visited[now.first] = true;
527  for (ArcIterator<Fst<A>> aiter(fst, now.first); !aiter.Done();
528  aiter.Next()) {
529  const Arc &arc = aiter.Value();
530  if (arc.ilabel != 0) {
531  label_queue.push(std::make_pair(arc.nextstate, now.second));
532  }
533  }
534  }
535  label_queue.pop();
536  }
537  visited.clear();
538 
539  // The arc from the start state should be assigned an epsilon to put it
540  // in front of the all other labels (which makes Start state 1 after
541  // unigram which is state 0).
542  context[fst.Start()] = 0;
543 
544  // Build the tree of contexts fst by reversing the epsilon arcs from fst.
545  VectorFst<Arc> context_fst;
546  uint64 num_final = 0;
547  for (int i = 0; i < num_states; ++i) {
548  if (fst.Final(i) != Weight::Zero()) {
549  ++num_final;
550  }
551  context_fst.SetFinal(context_fst.AddState(), fst.Final(i));
552  }
553  context_fst.SetStart(unigram);
554  context_fst.SetInputSymbols(fst.InputSymbols());
555  context_fst.SetOutputSymbols(fst.OutputSymbols());
556  int64 num_context_arcs = 0;
557  int64 num_futures = 0;
558  for (StateIterator<Fst<A>> siter(fst); !siter.Done(); siter.Next()) {
559  const StateId &state = siter.Value();
560  num_futures += fst.NumArcs(state) - fst.NumInputEpsilons(state);
561  ArcIterator<Fst<A>> aiter(fst, state);
562  if (!aiter.Done()) {
563  const Arc &arc = aiter.Value();
564  // this arc goes from state to arc.nextstate, so create an arc from
565  // arc.nextstate to state to reverse it.
566  if (arc.ilabel == 0) {
567  context_fst.AddArc(arc.nextstate, Arc(context[state], context[state],
568  arc.weight, state));
569  num_context_arcs++;
570  }
571  }
572  }
573  if (num_context_arcs != context_fst.NumStates() - 1) {
574  FSTERROR() << "Number of contexts arcs != number of states - 1";
575  SetProperties(kError, kError);
576  return;
577  }
578  if (context_fst.NumStates() != num_states) {
579  FSTERROR() << "Number of contexts != number of states";
580  SetProperties(kError, kError);
581  return;
582  }
583  int64 context_props =
584  context_fst.Properties(kIDeterministic | kILabelSorted, true);
585  if (!(context_props & kIDeterministic)) {
586  FSTERROR() << "Input Fst is not structured properly";
587  SetProperties(kError, kError);
588  return;
589  }
590  if (!(context_props & kILabelSorted)) {
591  ArcSort(&context_fst, ILabelCompare<Arc>());
592  }
593 
594  delete[] context;
595 
596  uint64 b64;
597  Weight weight;
598  Label label = kNoLabel;
599  const size_t storage = Storage(num_states, num_futures, num_final);
600  MappedFile *data_region = MappedFile::Allocate(storage);
601  char *data = reinterpret_cast<char *>(data_region->mutable_data());
602  memset(data, 0, storage);
603  size_t offset = 0;
604  memcpy(data + offset, reinterpret_cast<char *>(&num_states),
605  sizeof(num_states));
606  offset += sizeof(num_states);
607  memcpy(data + offset, reinterpret_cast<char *>(&num_futures),
608  sizeof(num_futures));
609  offset += sizeof(num_futures);
610  memcpy(data + offset, reinterpret_cast<char *>(&num_final),
611  sizeof(num_final));
612  offset += sizeof(num_final);
613  uint64 *context_bits = reinterpret_cast<uint64 *>(data + offset);
614  offset += BitmapIndex::StorageSize(num_states * 2 + 1) * sizeof(b64);
615  uint64 *future_bits = reinterpret_cast<uint64 *>(data + offset);
616  offset +=
617  BitmapIndex::StorageSize(num_futures + num_states + 1) * sizeof(b64);
618  uint64 *final_bits = reinterpret_cast<uint64 *>(data + offset);
619  offset += BitmapIndex::StorageSize(num_states) * sizeof(b64);
620  Label *context_words = reinterpret_cast<Label *>(data + offset);
621  offset += (num_states + 1) * sizeof(label);
622  Label *future_words = reinterpret_cast<Label *>(data + offset);
623  offset += num_futures * sizeof(label);
624  offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
625  Weight *backoff = reinterpret_cast<Weight *>(data + offset);
626  offset += (num_states + 1) * sizeof(weight);
627  Weight *final_probs = reinterpret_cast<Weight *>(data + offset);
628  offset += num_final * sizeof(weight);
629  Weight *future_probs = reinterpret_cast<Weight *>(data + offset);
630  int64 context_arc = 0, future_arc = 0, context_bit = 0, future_bit = 0,
631  final_bit = 0;
632 
633  // pseudo-root bits
634  BitmapIndex::Set(context_bits, context_bit++);
635  ++context_bit;
636  context_words[context_arc] = label;
637  backoff[context_arc] = Weight::Zero();
638  context_arc++;
639 
640  ++future_bit;
641  if (order_out) {
642  order_out->clear();
643  order_out->resize(num_states);
644  }
645 
646  std::queue<StateId> context_q;
647  context_q.push(context_fst.Start());
648  StateId state_number = 0;
649  while (!context_q.empty()) {
650  const StateId &state = context_q.front();
651  if (order_out) {
652  (*order_out)[state] = state_number;
653  }
654 
655  const Weight final_weight = context_fst.Final(state);
656  if (final_weight != Weight::Zero()) {
657  BitmapIndex::Set(final_bits, state_number);
658  final_probs[final_bit] = final_weight;
659  ++final_bit;
660  }
661 
662  for (ArcIterator<VectorFst<A>> aiter(context_fst, state); !aiter.Done();
663  aiter.Next()) {
664  const Arc &arc = aiter.Value();
665  context_words[context_arc] = arc.ilabel;
666  backoff[context_arc] = arc.weight;
667  ++context_arc;
668  BitmapIndex::Set(context_bits, context_bit++);
669  context_q.push(arc.nextstate);
670  }
671  ++context_bit;
672 
673  for (ArcIterator<Fst<A>> aiter(fst, state); !aiter.Done(); aiter.Next()) {
674  const Arc &arc = aiter.Value();
675  if (arc.ilabel != 0) {
676  future_words[future_arc] = arc.ilabel;
677  future_probs[future_arc] = arc.weight;
678  ++future_arc;
679  BitmapIndex::Set(future_bits, future_bit++);
680  }
681  }
682  ++future_bit;
683  ++state_number;
684  context_q.pop();
685  }
686 
687  if ((state_number != num_states) || (context_bit != num_states * 2 + 1) ||
688  (context_arc != num_states) || (future_arc != num_futures) ||
689  (future_bit != num_futures + num_states + 1) ||
690  (final_bit != num_final)) {
691  FSTERROR() << "Structure problems detected during construction";
692  SetProperties(kError, kError);
693  return;
694  }
695 
696  Init(data, false, data_region);
697 }
698 
699 template <typename A>
700 inline void NGramFstImpl<A>::Init(const char *data, bool owned,
701  MappedFile *data_region) {
702  if (owned_) {
703  delete[] data_;
704  }
705  data_region_.reset(data_region);
706  owned_ = owned;
707  data_ = data;
708  size_t offset = 0;
709  num_states_ = *(reinterpret_cast<const uint64 *>(data_ + offset));
710  offset += sizeof(num_states_);
711  num_futures_ = *(reinterpret_cast<const uint64 *>(data_ + offset));
712  offset += sizeof(num_futures_);
713  num_final_ = *(reinterpret_cast<const uint64 *>(data_ + offset));
714  offset += sizeof(num_final_);
715  uint64 bits;
716  size_t context_bits = num_states_ * 2 + 1;
717  size_t future_bits = num_futures_ + num_states_ + 1;
718  context_ = reinterpret_cast<const uint64 *>(data_ + offset);
719  offset += BitmapIndex::StorageSize(context_bits) * sizeof(bits);
720  future_ = reinterpret_cast<const uint64 *>(data_ + offset);
721  offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits);
722  final_ = reinterpret_cast<const uint64 *>(data_ + offset);
723  offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits);
724  context_words_ = reinterpret_cast<const Label *>(data_ + offset);
725  offset += (num_states_ + 1) * sizeof(*context_words_);
726  future_words_ = reinterpret_cast<const Label *>(data_ + offset);
727  offset += num_futures_ * sizeof(*future_words_);
728  offset = (offset + sizeof(*backoff_) - 1) & ~(sizeof(*backoff_) - 1);
729  backoff_ = reinterpret_cast<const Weight *>(data_ + offset);
730  offset += (num_states_ + 1) * sizeof(*backoff_);
731  final_probs_ = reinterpret_cast<const Weight *>(data_ + offset);
732  offset += num_final_ * sizeof(*final_probs_);
733  future_probs_ = reinterpret_cast<const Weight *>(data_ + offset);
734 
735  context_index_.BuildIndex(context_, context_bits);
736  future_index_.BuildIndex(future_, future_bits);
737  final_index_.BuildIndex(final_, num_states_);
738 
739  select_root_ = context_index_.Select0s(0);
740  if (context_index_.Rank1(0) != 0 || select_root_.first != 1 ||
741  context_index_.Get(2) == false) {
742  FSTERROR() << "Malformed file";
743  SetProperties(kError, kError);
744  return;
745  }
746  root_children_ = context_words_ + context_index_.Rank1(2);
747  start_ = 1;
748 }
749 
750 template <typename A>
751 inline typename A::StateId NGramFstImpl<A>::Transition(
752  const std::vector<Label> &context, Label future) const {
753  const Label *children = root_children_;
754  size_t num_children = select_root_.second - 2;
755  const Label *loc =
756  std::lower_bound(children, children + num_children, future);
757  if (loc == children + num_children || *loc != future) {
758  return context_index_.Rank1(0);
759  }
760  size_t node = 2 + loc - children;
761  size_t node_rank = context_index_.Rank1(node);
762  std::pair<size_t, size_t> zeros =
763  (node_rank == 0) ? select_root_ : context_index_.Select0s(node_rank);
764  size_t first_child = zeros.first + 1;
765  if (context_index_.Get(first_child) == false) {
766  return context_index_.Rank1(node);
767  }
768  size_t last_child = zeros.second - 1;
769  for (int word = context.size() - 1; word >= 0; --word) {
770  children = context_words_ + context_index_.Rank1(first_child);
771  loc = std::lower_bound(children, children + last_child - first_child + 1,
772  context[word]);
773  if (loc == children + last_child - first_child + 1 ||
774  *loc != context[word]) {
775  break;
776  }
777  node = first_child + loc - children;
778  node_rank = context_index_.Rank1(node);
779  zeros =
780  (node_rank == 0) ? select_root_ : context_index_.Select0s(node_rank);
781  first_child = zeros.first + 1;
782  if (context_index_.Get(first_child) == false) break;
783  last_child = zeros.second - 1;
784  }
785  return context_index_.Rank1(node);
786 }
787 
788 } // namespace internal
789 
790 /*****************************************************************************/
791 template <class A>
792 class NGramFstMatcher : public MatcherBase<A> {
793  public:
794  typedef A Arc;
795  typedef typename A::Label Label;
796  typedef typename A::StateId StateId;
797  typedef typename A::Weight Weight;
798 
799  // This makes a copy of the FST.
801  : owned_fst_(fst.Copy()),
802  fst_(*owned_fst_),
803  inst_(fst_.inst_),
804  match_type_(match_type),
805  current_loop_(false),
806  loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
807  if (match_type_ == MATCH_OUTPUT) {
808  std::swap(loop_.ilabel, loop_.olabel);
809  }
810  }
811 
812  // This doesn't copy the FST.
814  : fst_(*fst),
815  inst_(fst_.inst_),
816  match_type_(match_type),
817  current_loop_(false),
818  loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
819  if (match_type_ == MATCH_OUTPUT) {
820  std::swap(loop_.ilabel, loop_.olabel);
821  }
822  }
823 
824  // This makes a copy of the FST.
825  NGramFstMatcher(const NGramFstMatcher<A> &matcher, bool safe = false)
826  : owned_fst_(matcher.fst_.Copy(safe)),
827  fst_(*owned_fst_),
828  inst_(matcher.inst_),
829  match_type_(matcher.match_type_),
830  current_loop_(false),
831  loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
832  if (match_type_ == MATCH_OUTPUT) {
833  std::swap(loop_.ilabel, loop_.olabel);
834  }
835  }
836 
837  NGramFstMatcher<A> *Copy(bool safe = false) const override {
838  return new NGramFstMatcher<A>(*this, safe);
839  }
840 
841  MatchType Type(bool test) const override { return match_type_; }
842 
843  const Fst<A> &GetFst() const override { return fst_; }
844 
845  uint64 Properties(uint64 props) const override { return props; }
846 
847  void SetState(StateId s) final {
848  fst_.GetImpl()->SetInstFuture(s, &inst_);
849  current_loop_ = false;
850  }
851 
852  bool Find(Label label) final {
853  const Label nolabel = kNoLabel;
854  done_ = true;
855  if (label == 0 || label == nolabel) {
856  if (label == 0) {
857  current_loop_ = true;
858  loop_.nextstate = inst_.state_;
859  }
860  // The unigram state has no epsilon arc.
861  if (inst_.state_ != 0) {
862  arc_.ilabel = arc_.olabel = 0;
863  fst_.GetImpl()->SetInstNode(&inst_);
864  arc_.nextstate = fst_.GetImpl()->context_index_.Rank1(
865  fst_.GetImpl()->context_index_.Select1(
866  fst_.GetImpl()->context_index_.Rank0(inst_.node_) - 1));
867  arc_.weight = fst_.GetImpl()->backoff_[inst_.state_];
868  done_ = false;
869  }
870  } else {
871  current_loop_ = false;
872  const Label *start = fst_.GetImpl()->future_words_ + inst_.offset_;
873  const Label *end = start + inst_.num_futures_;
874  const Label *search = std::lower_bound(start, end, label);
875  if (search != end && *search == label) {
876  size_t state = search - start;
877  arc_.ilabel = arc_.olabel = label;
878  arc_.weight = fst_.GetImpl()->future_probs_[inst_.offset_ + state];
879  fst_.GetImpl()->SetInstContext(&inst_);
880  arc_.nextstate = fst_.GetImpl()->Transition(inst_.context_, label);
881  done_ = false;
882  }
883  }
884  return !Done();
885  }
886 
887  bool Done() const final { return !current_loop_ && done_; }
888 
889  const Arc &Value() const final { return (current_loop_) ? loop_ : arc_; }
890 
891  void Next() final {
892  if (current_loop_) {
893  current_loop_ = false;
894  } else {
895  done_ = true;
896  }
897  }
898 
899  ssize_t Priority(StateId s) final { return fst_.NumArcs(s); }
900 
901  private:
902  std::unique_ptr<NGramFst<A>> owned_fst_;
903  const NGramFst<A> &fst_;
904  NGramFstInst<A> inst_;
905  MatchType match_type_; // Supplied by caller
906  bool done_;
907  Arc arc_;
908  bool current_loop_; // Current arc is the implicit loop
909  Arc loop_;
910 };
911 
912 /*****************************************************************************/
913 // Specialization for NGramFst; see generic version in fst.h
914 // for sample usage (but use the ProdLmFst type!). This version
915 // should inline.
916 template <class A>
917 class StateIterator<NGramFst<A>> : public StateIteratorBase<A> {
918  public:
919  typedef typename A::StateId StateId;
920 
921  explicit StateIterator(const NGramFst<A> &fst)
922  : s_(0), num_states_(fst.NumStates()) {}
923 
924  bool Done() const final { return s_ >= num_states_; }
925 
926  StateId Value() const final { return s_; }
927 
928  void Next() final { ++s_; }
929 
930  void Reset() final { s_ = 0; }
931 
932  private:
933  StateId s_;
934  StateId num_states_;
935 };
936 
937 /*****************************************************************************/
938 template <class A>
939 class ArcIterator<NGramFst<A>> : public ArcIteratorBase<A> {
940  public:
941  typedef A Arc;
942  typedef typename A::Label Label;
943  typedef typename A::StateId StateId;
944  typedef typename A::Weight Weight;
945 
946  ArcIterator(const NGramFst<A> &fst, StateId state)
947  : lazy_(~0), impl_(fst.GetImpl()), i_(0), flags_(kArcValueFlags) {
948  inst_ = fst.inst_;
949  impl_->SetInstFuture(state, &inst_);
950  impl_->SetInstNode(&inst_);
951  }
952 
953  bool Done() const final {
954  return i_ >=
955  ((inst_.node_ == 0) ? inst_.num_futures_ : inst_.num_futures_ + 1);
956  }
957 
958  const Arc &Value() const final {
959  bool eps = (inst_.node_ != 0 && i_ == 0);
960  StateId state = (inst_.node_ == 0) ? i_ : i_ - 1;
961  if (flags_ & lazy_ & (kArcILabelValue | kArcOLabelValue)) {
962  arc_.ilabel = arc_.olabel =
963  eps ? 0 : impl_->future_words_[inst_.offset_ + state];
964  lazy_ &= ~(kArcILabelValue | kArcOLabelValue);
965  }
966  if (flags_ & lazy_ & kArcNextStateValue) {
967  if (eps) {
968  arc_.nextstate =
969  impl_->context_index_.Rank1(impl_->context_index_.Select1(
970  impl_->context_index_.Rank0(inst_.node_) - 1));
971  } else {
972  if (lazy_ & kArcNextStateValue) {
973  impl_->SetInstContext(&inst_); // first time only.
974  }
975  arc_.nextstate = impl_->Transition(
976  inst_.context_, impl_->future_words_[inst_.offset_ + state]);
977  }
978  lazy_ &= ~kArcNextStateValue;
979  }
980  if (flags_ & lazy_ & kArcWeightValue) {
981  arc_.weight = eps ? impl_->backoff_[inst_.state_]
982  : impl_->future_probs_[inst_.offset_ + state];
983  lazy_ &= ~kArcWeightValue;
984  }
985  return arc_;
986  }
987 
988  void Next() final {
989  ++i_;
990  lazy_ = ~0;
991  }
992 
993  size_t Position() const final { return i_; }
994 
995  void Reset() final {
996  i_ = 0;
997  lazy_ = ~0;
998  }
999 
1000  void Seek(size_t a) final {
1001  if (i_ != a) {
1002  i_ = a;
1003  lazy_ = ~0;
1004  }
1005  }
1006 
1007  uint32 Flags() const final { return flags_; }
1008 
1009  void SetFlags(uint32 flags, uint32 mask) final {
1010  flags_ &= ~mask;
1011  flags_ |= (flags & kArcValueFlags);
1012  }
1013 
1014  private:
1015  mutable Arc arc_;
1016  mutable uint32 lazy_;
1017  const internal::NGramFstImpl<A> *impl_; // Borrowed reference.
1018  mutable NGramFstInst<A> inst_;
1019 
1020  size_t i_;
1021  uint32 flags_;
1022 };
1023 
1024 } // namespace fst
1025 #endif // FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
internal::NGramFstImpl< A > Impl
Definition: ngram-fst.h:333
void GetStates(const std::vector< Label > &context, std::vector< StateId > *state) const
Definition: ngram-fst.h:365
void AddArc(StateId s, const Arc &arc) override
Definition: mutable-fst.h:312
void * mutable_data() const
Definition: mapped-file.h:34
NGramFst< A > * Copy(bool safe=false) const override
Definition: ngram-fst.h:374
void SetStart(StateId s) override
Definition: mutable-fst.h:284
size_t num_futures_
Definition: ngram-fst.h:41
constexpr uint64 kInitialAcyclic
Definition: properties.h:98
void Next() final
Definition: ngram-fst.h:891
constexpr int kNoLabel
Definition: fst.h:178
size_t NumInputEpsilons(StateId state) const
Definition: ngram-fst.h:155
uint64_t uint64
Definition: types.h:32
const Arc & Value() const final
Definition: ngram-fst.h:889
void SetFinal(StateId s, Weight weight=Weight::One()) override
Definition: mutable-fst.h:289
virtual size_t NumArcs(StateId) const =0
void SetInstContext(NGramFstInst< A > *inst) const
Definition: ngram-fst.h:208
uint64 Properties(uint64 props) const override
Definition: ngram-fst.h:845
MatchType
Definition: fst.h:170
StateId state_
Definition: ngram-fst.h:40
constexpr uint64 kILabelSorted
Definition: properties.h:76
size_t Position() const final
Definition: ngram-fst.h:993
size_t StorageSize() const
Definition: ngram-fst.h:416
static NGramFstImpl< A > * Read(std::istream &strm, const FstReadOptions &opts)
Definition: ngram-fst.h:98
#define LOG(type)
Definition: log.h:46
StateId Start() const override
Definition: fst.h:874
virtual Weight Final(StateId) const =0
SetType
Definition: set-weight.h:36
StateId node_state_
Definition: ngram-fst.h:44
bool WriteFile(const std::string &filename) const
Definition: fst.h:302
const Arc & Value() const final
Definition: ngram-fst.h:958
constexpr int kNoStateId
Definition: fst.h:179
NGramFstMatcher< A > * Copy(bool safe=false) const override
Definition: ngram-fst.h:837
constexpr uint64 kExpanded
Definition: properties.h:28
const Arc & Value() const
Definition: fst.h:504
NGramFstMatcher(const NGramFstMatcher< A > &matcher, bool safe=false)
Definition: ngram-fst.h:825
static MappedFile * Allocate(size_t size, size_t align=kArchAlignment)
Definition: mapped-file.cc:103
virtual uint64 Properties(uint64 mask, bool test) const =0
void InitArcIterator(StateId s, ArcIteratorData< A > *data) const override
Definition: ngram-fst.h:462
NGramFst(const Fst< A > &fst, std::vector< StateId > *order_out)
Definition: ngram-fst.h:338
const Fst< A > & GetFst() const override
Definition: ngram-fst.h:843
virtual size_t NumInputEpsilons(StateId) const =0
static NGramFst< A > * Read(const std::string &filename)
Definition: ngram-fst.h:383
int64_t int64
Definition: types.h:27
constexpr uint64 kEpsilons
Definition: properties.h:61
#define FSTERROR()
Definition: util.h:35
StateId AddState() override
Definition: mutable-fst.h:302
NGramFst(const Fst< A > &dst)
Definition: ngram-fst.h:335
size_t NumArcs(StateId state, NGramFstInst< A > *inst=nullptr) const
Definition: ngram-fst.h:145
const std::vector< Label > & GetContext(StateId s, NGramFstInst< A > *inst) const
Definition: ngram-fst.h:229
void SetOutputSymbols(const SymbolTable *osyms) override
Definition: mutable-fst.h:382
NGramFst(const NGramFst< A > &fst, bool safe=false)
Definition: ngram-fst.h:343
StateIteratorBase< Arc > * base
Definition: fst.h:352
A::Weight Weight
Definition: ngram-fst.h:39
static void Set(uint64 *bits, size_t index)
Definition: bitmap-index.h:41
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:94
const char * GetData(size_t *data_size) const
Definition: ngram-fst.h:355
A::StateId StateId
Definition: ngram-fst.h:38
void ArcSort(MutableFst< Arc > *fst, Compare comp)
Definition: arcsort.h:87
void SetInstNode(NGramFstInst< A > *inst) const
Definition: ngram-fst.h:201
constexpr uint64 kNotString
Definition: properties.h:121
static NGramFst< A > * Read(std::istream &strm, const FstReadOptions &opts)
Definition: ngram-fst.h:378
void SetInputSymbols(const SymbolTable *isyms) override
Definition: mutable-fst.h:377
std::vector< Label > context_
Definition: ngram-fst.h:45
MatcherBase< A > * InitMatcher(MatchType match_type) const override
Definition: ngram-fst.h:412
constexpr uint64 kOLabelSorted
Definition: properties.h:81
constexpr uint64 kIEpsilons
Definition: properties.h:66
StateId Value() const final
Definition: ngram-fst.h:926
void SetNumStates(int64 numstates)
Definition: fst.h:147
StateId NumStates() const override
Definition: expanded-fst.h:120
NGramFstMatcher(const NGramFst< A > &fst, MatchType match_type)
Definition: ngram-fst.h:800
StateId nstates
Definition: fst.h:354
constexpr uint64 kIDeterministic
Definition: properties.h:51
const char * GetData(size_t *data_size) const
Definition: ngram-fst.h:222
MatchType Type(bool test) const override
Definition: ngram-fst.h:841
size_t NumArcs(StateId s) const override
Definition: ngram-fst.h:370
void SetInstFuture(StateId state, NGramFstInst< A > *inst) const
Definition: ngram-fst.h:192
virtual StateId Start() const =0
bool Done() const
Definition: fst.h:500
static size_t StorageSize(size_t size)
Definition: bitmap-index.h:30
A::StateId StateId
Definition: ngram-fst.h:796
constexpr uint64 kAcceptor
Definition: properties.h:46
NGramFst(const char *data, bool owned)
Definition: ngram-fst.h:349
void SetState(StateId s) final
Definition: ngram-fst.h:847
uint32 Flags() const final
Definition: ngram-fst.h:1007
NGramFstImpl(const Fst< A > &fst)
Definition: ngram-fst.h:85
const std::vector< Label > GetContext(StateId s) const
Definition: ngram-fst.h:359
void Seek(size_t a) final
Definition: ngram-fst.h:1000
StateId NumStates() const
Definition: ngram-fst.h:165
A::Label Label
Definition: ngram-fst.h:331
size_t StorageSize() const
Definition: ngram-fst.h:235
uint32_t uint32
Definition: types.h:31
constexpr uint64 kCoAccessible
Definition: properties.h:111
void SetStart(int64 start)
Definition: fst.h:145
bool Write(std::ostream &strm, const FstWriteOptions &opts) const override
Definition: ngram-fst.h:397
StateId context_state_
Definition: ngram-fst.h:46
virtual const SymbolTable * InputSymbols() const =0
Arc::StateId CountStates(const Fst< Arc > &fst)
Definition: expanded-fst.h:155
ssize_t Priority(StateId s) final
Definition: ngram-fst.h:899
ArcIterator(const NGramFst< A > &fst, StateId state)
Definition: ngram-fst.h:946
static size_t Storage(uint64 num_states, uint64 num_futures, uint64 num_final)
Definition: ngram-fst.h:172
Weight Final(StateId s) const override
Definition: fst.h:876
size_t NumOutputEpsilons(StateId state) const
Definition: ngram-fst.h:161
NGramFstImpl(const NGramFstImpl &other)
Definition: ngram-fst.h:87
bool Done() const final
Definition: ngram-fst.h:887
constexpr uint64 kWeighted
Definition: properties.h:86
constexpr uint64 kError
Definition: properties.h:34
StateId Start() const
Definition: ngram-fst.h:135
NGramFstMatcher(const NGramFst< A > *fst, MatchType match_type)
Definition: ngram-fst.h:813
A::Label Label
Definition: ngram-fst.h:37
constexpr uint64 kNotTopSorted
Definition: properties.h:103
ArcIteratorBase< Arc > * base
Definition: fst.h:462
void SetFlags(uint32 flags, uint32 mask) final
Definition: ngram-fst.h:1009
A::Weight Weight
Definition: ngram-fst.h:332
void InitStateIterator(StateIteratorData< A > *data) const override
Definition: ngram-fst.h:405
Weight Final(StateId state) const
Definition: ngram-fst.h:137
void InitStateIterator(StateIteratorData< A > *data) const
Definition: ngram-fst.h:167
bool Write(std::ostream &strm, const FstWriteOptions &opts) const
Definition: ngram-fst.h:125
bool Write(const std::string &filename) const override
Definition: ngram-fst.h:401
A::StateId StateId
Definition: ngram-fst.h:330
uint64 Properties(uint64 mask, bool test) const override
Definition: fst.h:888
bool Find(Label label) final
Definition: ngram-fst.h:852
constexpr uint64 kODeterministic
Definition: properties.h:56
static bool HasRequiredStructure(const Fst< A > &fst)
Definition: ngram-fst.h:424
StateIterator(const NGramFst< A > &fst)
Definition: ngram-fst.h:921
constexpr uint64 kOEpsilons
Definition: properties.h:71
constexpr uint64 kAccessible
Definition: properties.h:106
static bool HasRequiredProps(const Fst< A > &fst)
Definition: ngram-fst.h:418
constexpr uint64 kCyclic
Definition: properties.h:91
virtual const SymbolTable * OutputSymbols() const =0
void Next()
Definition: fst.h:508