FST  openfst-1.8.2
OpenFst Library
ngram-fst.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // NgramFst implements an n-gram language model based upon the LOUDS data
19 // structure. Please refer to "Unary Data Structures for Language Models"
20 // http://research.google.com/pubs/archive/37218.pdf
21 
22 #ifndef FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
23 #define FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
24 
25 #include <algorithm>
26 #include <cstddef>
27 #include <cstdint>
28 #include <cstring>
29 #include <iostream>
30 #include <memory>
31 #include <string>
32 #include <utility>
33 #include <vector>
34 
35 #include <fst/compat.h>
36 #include <fst/log.h>
38 #include <fstream>
39 #include <fst/fstlib.h>
40 #include <fst/mapped-file.h>
41 
42 namespace fst {
43 template <class A>
44 class NGramFst;
45 
46 template <class A>
48 
49 // Instance data containing mutable state for bookkeeping repeated access to
50 // the same state.
51 template <class A>
52 struct NGramFstInst {
53  typedef typename A::Label Label;
54  typedef typename A::StateId StateId;
55  typedef typename A::Weight Weight;
56  StateId state_;
57  size_t num_futures_;
58  size_t offset_;
59  size_t node_;
60  StateId node_state_;
61  std::vector<Label> context_;
62  StateId context_state_;
64  : state_(kNoStateId),
65  node_state_(kNoStateId),
66  context_state_(kNoStateId) {}
67 };
68 
69 namespace internal {
70 
71 // Implementation class for LOUDS based NgramFst interface.
72 template <class A>
73 class NGramFstImpl : public FstImpl<A> {
76  using FstImpl<A>::SetType;
78 
79  friend class ArcIterator<NGramFst<A>>;
80  friend class NGramFstMatcher<A>;
81 
82  public:
86 
87  typedef A Arc;
88  typedef typename A::Label Label;
89  typedef typename A::StateId StateId;
90  typedef typename A::Weight Weight;
91 
93  SetType("ngram");
94  SetInputSymbols(nullptr);
95  SetOutputSymbols(nullptr);
96  SetProperties(kStaticProperties);
97  }
98 
99  NGramFstImpl(const Fst<A> &fst, std::vector<StateId> *order_out);
100 
101  explicit NGramFstImpl(const Fst<A> &fst) : NGramFstImpl(fst, nullptr) {}
102 
103  NGramFstImpl(const NGramFstImpl &other) {
104  FSTERROR() << "Copying NGramFst Impls is not supported, use safe = false.";
105  SetProperties(kError, kError);
106  }
107 
108  static NGramFstImpl<A> *Read(std::istream &strm, const FstReadOptions &opts) {
109  auto impl = std::make_unique<NGramFstImpl<A>>();
110  FstHeader hdr;
111  if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return nullptr;
112  uint64_t num_states, num_futures, num_final;
113  const size_t offset =
114  sizeof(num_states) + sizeof(num_futures) + sizeof(num_final);
115  // Peek at num_states and num_futures to see how much more needs to be read.
116  strm.read(reinterpret_cast<char *>(&num_states), sizeof(num_states));
117  strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures));
118  strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final));
119  size_t size = Storage(num_states, num_futures, num_final);
120  std::unique_ptr<MappedFile> data_region(MappedFile::Allocate(size));
121  char *data = static_cast<char *>(data_region->mutable_data());
122  // Copy num_states, num_futures and num_final back into data.
123  memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states));
124  memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures),
125  sizeof(num_futures));
126  memcpy(data + sizeof(num_states) + sizeof(num_futures),
127  reinterpret_cast<char *>(&num_final), sizeof(num_final));
128  strm.read(data + offset, size - offset);
129  if (strm.fail()) return nullptr;
130  impl->Init(data, std::move(data_region));
131  return impl.release();
132  }
133 
134  bool Write(std::ostream &strm, const FstWriteOptions &opts) const {
135  FstHeader hdr;
136  hdr.SetStart(Start());
137  hdr.SetNumStates(num_states_);
138  WriteHeader(strm, opts, kFileVersion, &hdr);
139  strm.write(data_, StorageSize());
140  return !strm.fail();
141  }
142 
143  StateId Start() const { return start_; }
144 
145  Weight Final(StateId state) const {
146  if (final_index_.Get(state)) {
147  return final_probs_[final_index_.Rank1(state)];
148  } else {
149  return Weight::Zero();
150  }
151  }
152 
153  size_t NumArcs(StateId state, NGramFstInst<A> *inst = nullptr) const {
154  if (inst == nullptr) {
155  const std::pair<size_t, size_t> zeros =
156  (state == 0) ? select_root_ : future_index_.Select0s(state);
157  return zeros.second - zeros.first - 1;
158  }
159  SetInstFuture(state, inst);
160  return inst->num_futures_ + ((state == 0) ? 0 : 1);
161  }
162 
163  size_t NumInputEpsilons(StateId state) const {
164  // State 0 has no parent, thus no backoff.
165  if (state == 0) return 0;
166  return 1;
167  }
168 
169  size_t NumOutputEpsilons(StateId state) const {
170  return NumInputEpsilons(state);
171  }
172 
173  StateId NumStates() const { return num_states_; }
174 
176  data->base = nullptr;
177  data->nstates = num_states_;
178  }
179 
180  static size_t Storage(uint64_t num_states, uint64_t num_futures,
181  uint64_t num_final) {
182  uint64_t b64;
183  Weight weight;
184  Label label;
185  size_t offset =
186  sizeof(num_states) + sizeof(num_futures) + sizeof(num_final);
187  offset +=
188  sizeof(b64) * (BitmapIndex::StorageSize(num_states * 2 + 1) +
189  BitmapIndex::StorageSize(num_futures + num_states + 1) +
190  BitmapIndex::StorageSize(num_states));
191  offset += (num_states + 1) * sizeof(label) + num_futures * sizeof(label);
192  // Pad for alignemnt, see
193  // http://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
194  offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
195  offset += (num_states + 1) * sizeof(weight) + num_final * sizeof(weight) +
196  (num_futures + 1) * sizeof(weight);
197  return offset;
198  }
199 
200  void SetInstFuture(StateId state, NGramFstInst<A> *inst) const {
201  if (inst->state_ != state) {
202  inst->state_ = state;
203  const std::pair<size_t, size_t> zeros = future_index_.Select0s(state);
204  inst->num_futures_ = zeros.second - zeros.first - 1;
205  inst->offset_ = future_index_.Rank1(zeros.first + 1);
206  }
207  }
208 
209  void SetInstNode(NGramFstInst<A> *inst) const {
210  if (inst->node_state_ != inst->state_) {
211  inst->node_state_ = inst->state_;
212  inst->node_ = context_index_.Select1(inst->state_);
213  }
214  }
215 
216  void SetInstContext(NGramFstInst<A> *inst) const {
217  SetInstNode(inst);
218  if (inst->context_state_ != inst->state_) {
219  inst->context_state_ = inst->state_;
220  inst->context_.clear();
221  size_t node = inst->node_;
222  while (node != 0) {
223  inst->context_.push_back(context_words_[context_index_.Rank1(node)]);
224  node = context_index_.Select1(context_index_.Rank0(node) - 1);
225  }
226  }
227  }
228 
229  // Access to the underlying representation
230  const char *GetData(size_t *data_size) const {
231  *data_size = StorageSize();
232  return data_;
233  }
234 
235  void Init(const char *data,
236  std::unique_ptr<MappedFile> data_region);
237 
238  const std::vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const {
239  SetInstFuture(s, inst);
240  SetInstContext(inst);
241  return inst->context_;
242  }
243 
244  size_t StorageSize() const {
245  return Storage(num_states_, num_futures_, num_final_);
246  }
247 
248  void GetStates(const std::vector<Label> &context,
249  std::vector<StateId> *states) const;
250 
251  private:
252  StateId Transition(const std::vector<Label> &context, Label future) const;
253 
254  // Properties always true for this Fst class.
255  static constexpr uint64_t kStaticProperties =
260  // Current file format version.
261  static constexpr int kFileVersion = 4;
262  // Minimum file format version supported.
263  static constexpr int kMinFileVersion = 4;
264 
265  std::unique_ptr<MappedFile> data_region_;
266  const char *data_ = nullptr; // Not owned.
267  StateId start_ = fst::kNoStateId;
268  uint64_t num_states_ = 0;
269  uint64_t num_futures_ = 0;
270  uint64_t num_final_ = 0;
271  std::pair<size_t, size_t> select_root_;
272  const Label *root_children_ = nullptr;
273  // borrowed references
274  const uint64_t *context_ = nullptr;
275  const uint64_t *future_ = nullptr;
276  const uint64_t *final_ = nullptr;
277  const Label *context_words_ = nullptr;
278  const Label *future_words_ = nullptr;
279  const Weight *backoff_ = nullptr;
280  const Weight *final_probs_ = nullptr;
281  const Weight *future_probs_ = nullptr;
282  // Uses all operations.
283  BitmapIndex context_index_;
284  // Uses Select0 and Rank1.
285  BitmapIndex future_index_;
286  // Uses Get and Rank1. This wastes space if there are no or few final
287  // states, but it's also small. TODO(jrosenstock): Look at EliasFanoArray.
288  BitmapIndex final_index_;
289 };
290 
291 template <typename A>
293  const std::vector<Label> &context,
294  std::vector<typename A::StateId> *states) const {
295  states->clear();
296  states->push_back(0);
297  typename std::vector<Label>::const_reverse_iterator cit = context.rbegin();
298  const Label *children = root_children_;
299  size_t num_children = select_root_.second - 2;
300  const Label *loc = std::lower_bound(children, children + num_children, *cit);
301  if (loc == children + num_children || *loc != *cit) return;
302  size_t node = 2 + loc - children;
303  states->push_back(context_index_.Rank1(node));
304  if (context.size() == 1) return;
305  size_t node_rank = context_index_.Rank1(node);
306  std::pair<size_t, size_t> zeros =
307  node_rank == 0 ? select_root_ : context_index_.Select0s(node_rank);
308  size_t first_child = zeros.first + 1;
309  ++cit;
310  if (context_index_.Get(first_child) != false) {
311  size_t last_child = zeros.second - 1;
312  while (cit != context.rend()) {
313  children = context_words_ + context_index_.Rank1(first_child);
314  loc = std::lower_bound(children, children + last_child - first_child + 1,
315  *cit);
316  if (loc == children + last_child - first_child + 1 || *loc != *cit) {
317  break;
318  }
319  ++cit;
320  node = first_child + loc - children;
321  states->push_back(context_index_.Rank1(node));
322  node_rank = context_index_.Rank1(node);
323  zeros =
324  node_rank == 0 ? select_root_ : context_index_.Select0s(node_rank);
325  first_child = zeros.first + 1;
326  if (context_index_.Get(first_child) == false) break;
327  last_child = zeros.second - 1;
328  }
329  }
330 }
331 
332 } // namespace internal
333 
334 /*****************************************************************************/
335 template <class A>
336 class NGramFst : public ImplToExpandedFst<internal::NGramFstImpl<A>> {
337  friend class ArcIterator<NGramFst<A>>;
338  friend class NGramFstMatcher<A>;
339 
340  public:
341  typedef A Arc;
342  typedef typename A::StateId StateId;
343  typedef typename A::Label Label;
344  typedef typename A::Weight Weight;
346 
347  explicit NGramFst(const Fst<A> &dst)
348  : ImplToExpandedFst<Impl>(std::make_shared<Impl>(dst, nullptr)) {}
349 
350  NGramFst(const Fst<A> &fst, std::vector<StateId> *order_out)
351  : ImplToExpandedFst<Impl>(std::make_shared<Impl>(fst, order_out)) {}
352 
353  // Because the NGramFstImpl is a const stateless data structure, there
354  // is never a need to do anything beside copy the reference.
355  NGramFst(const NGramFst<A> &fst, bool safe = false)
356  : ImplToExpandedFst<Impl>(fst, false) {}
357 
358  NGramFst() : ImplToExpandedFst<Impl>(std::make_shared<Impl>()) {}
359 
360  // Non-standard constructor to initialize NGramFst directly from data. Caller
361  // maintains ownership of data, which must outlive the NGramFst.
362  explicit NGramFst(const char *data)
363  : ImplToExpandedFst<Impl>(std::make_shared<Impl>()) {
364  GetMutableImpl()->Init(data, /*data_region=*/nullptr);
365  }
366 
367  // Get method that gets the data associated with Init().
368  const char *GetData(size_t *data_size) const {
369  return GetImpl()->GetData(data_size);
370  }
371 
372  const std::vector<Label> GetContext(StateId s) const {
373  return GetImpl()->GetContext(s, &inst_);
374  }
375 
376  // Consumes as much as possible of context from right to left, returns the
377  // the states corresponding to the increasingly conditioned input sequence.
378  void GetStates(const std::vector<Label> &context,
379  std::vector<StateId> *state) const {
380  return GetImpl()->GetStates(context, state);
381  }
382 
383  size_t NumArcs(StateId s) const override {
384  return GetImpl()->NumArcs(s, &inst_);
385  }
386 
387  NGramFst<A> *Copy(bool safe = false) const override {
388  return new NGramFst(*this, safe);
389  }
390 
391  static NGramFst<A> *Read(std::istream &strm, const FstReadOptions &opts) {
392  Impl *impl = Impl::Read(strm, opts);
393  return impl ? new NGramFst<A>(std::shared_ptr<Impl>(impl)) : nullptr;
394  }
395 
396  static NGramFst<A> *Read(const std::string &source) {
397  if (!source.empty()) {
398  std::ifstream strm(source,
399  std::ios_base::in | std::ios_base::binary);
400  if (!strm.good()) {
401  LOG(ERROR) << "NGramFst::Read: Can't open file: " << source;
402  return nullptr;
403  }
404  return Read(strm, FstReadOptions(source));
405  } else {
406  return Read(std::cin, FstReadOptions("standard input"));
407  }
408  }
409 
410  bool Write(std::ostream &strm, const FstWriteOptions &opts) const override {
411  return GetImpl()->Write(strm, opts);
412  }
413 
414  bool Write(const std::string &source) const override {
415  return Fst<A>::WriteFile(source);
416  }
417 
418  inline void InitStateIterator(StateIteratorData<A> *data) const override {
419  GetImpl()->InitStateIterator(data);
420  }
421 
422  inline void InitArcIterator(StateId s,
423  ArcIteratorData<A> *data) const override;
424 
425  MatcherBase<A> *InitMatcher(MatchType match_type) const override {
426  return new NGramFstMatcher<A>(this, match_type);
427  }
428 
429  size_t StorageSize() const { return GetImpl()->StorageSize(); }
430 
431  static bool HasRequiredProps(const Fst<A> &fst) {
432  static const auto props =
434  return fst.Properties(props, true) == props;
435  }
436 
437  static bool HasRequiredStructure(const Fst<A> &fst) {
438  if (!HasRequiredProps(fst)) {
439  return false;
440  }
441  typename A::StateId unigram = fst.Start();
442  while (true) { // Follows epsilon arc chain to find unigram state.
443  if (unigram == fst::kNoStateId) return false; // No unigram state.
444  typename fst::ArcIterator<Fst<A>> aiter(fst, unigram);
445  if (aiter.Done() || aiter.Value().ilabel != 0) break;
446  unigram = aiter.Value().nextstate;
447  aiter.Next();
448  }
449  // Other requirement: all states other than unigram an epsilon arc.
450  for (fst::StateIterator<Fst<A>> siter(fst); !siter.Done();
451  siter.Next()) {
452  const typename A::StateId &state = siter.Value();
453  fst::ArcIterator<Fst<A>> aiter(fst, state);
454  if (state != unigram) {
455  if (aiter.Done()) return false;
456  if (aiter.Value().ilabel != 0) return false;
457  aiter.Next();
458  if (!aiter.Done() && aiter.Value().ilabel == 0) return false;
459  }
460  }
461  return true;
462  }
463 
464  private:
466  using ImplToExpandedFst<Impl, ExpandedFst<A>>::GetMutableImpl;
467 
468  explicit NGramFst(std::shared_ptr<Impl> impl)
469  : ImplToExpandedFst<Impl>(impl) {}
470 
471  mutable NGramFstInst<A> inst_;
472 };
473 
474 template <class A>
476  ArcIteratorData<A> *data) const {
477  GetImpl()->SetInstFuture(s, &inst_);
478  GetImpl()->SetInstNode(&inst_);
479  data->base = std::make_unique<ArcIterator<NGramFst<A>>>(*this, s);
480 }
481 
482 namespace internal {
483 
484 template <typename A>
486  std::vector<StateId> *order_out) {
487  typedef A Arc;
488  typedef typename Arc::Label Label;
489  typedef typename Arc::Weight Weight;
490  typedef typename Arc::StateId StateId;
491  SetType("ngram");
492  SetInputSymbols(fst.InputSymbols());
493  SetOutputSymbols(fst.OutputSymbols());
494  SetProperties(kStaticProperties);
495 
496  // Check basic requirements for an OpenGrm language model Fst.
497  if (!NGramFst<A>::HasRequiredProps(fst)) {
498  FSTERROR() << "NGramFst only accepts OpenGrm language models as input";
499  SetProperties(kError, kError);
500  return;
501  }
502 
503  int64_t num_states = CountStates(fst);
504  std::vector<Label> context(num_states, 0);
505 
506  // Find the unigram state by starting from the start state, following
507  // epsilons.
508  StateId unigram = fst.Start();
509  while (true) {
510  if (unigram == kNoStateId) {
511  FSTERROR() << "Could not identify unigram state";
512  SetProperties(kError, kError);
513  return;
514  }
515  ArcIterator<Fst<A>> aiter(fst, unigram);
516  if (aiter.Done()) {
517  LOG(WARNING) << "Unigram state " << unigram << " has no arcs.";
518  break;
519  }
520  if (aiter.Value().ilabel != 0) break;
521  unigram = aiter.Value().nextstate;
522  }
523 
524  // Each state's context is determined by the subtree it is under from the
525  // unigram state.
526  std::queue<std::pair<StateId, Label>> label_queue;
527  std::vector<bool> visited(num_states);
528  // Force an epsilon link to the start state.
529  label_queue.push(std::make_pair(fst.Start(), 0));
530  for (ArcIterator<Fst<A>> aiter(fst, unigram); !aiter.Done(); aiter.Next()) {
531  label_queue.push(
532  std::make_pair(aiter.Value().nextstate, aiter.Value().ilabel));
533  }
534  // investigate states in breadth first fashion to assign context words.
535  while (!label_queue.empty()) {
536  std::pair<StateId, Label> &now = label_queue.front();
537  if (!visited[now.first]) {
538  context[now.first] = now.second;
539  visited[now.first] = true;
540  for (ArcIterator<Fst<A>> aiter(fst, now.first); !aiter.Done();
541  aiter.Next()) {
542  const Arc &arc = aiter.Value();
543  if (arc.ilabel != 0) {
544  label_queue.push(std::make_pair(arc.nextstate, now.second));
545  }
546  }
547  }
548  label_queue.pop();
549  }
550  visited.clear();
551 
552  // The arc from the start state should be assigned an epsilon to put it
553  // in front of the all other labels (which makes Start state 1 after
554  // unigram which is state 0).
555  context[fst.Start()] = 0;
556 
557  // Build the tree of contexts fst by reversing the epsilon arcs from fst.
558  VectorFst<Arc> context_fst;
559  uint64_t num_final = 0;
560  for (int i = 0; i < num_states; ++i) {
561  if (fst.Final(i) != Weight::Zero()) {
562  ++num_final;
563  }
564  context_fst.SetFinal(context_fst.AddState(), fst.Final(i));
565  }
566  context_fst.SetStart(unigram);
567  context_fst.SetInputSymbols(fst.InputSymbols());
568  context_fst.SetOutputSymbols(fst.OutputSymbols());
569  int64_t num_context_arcs = 0;
570  int64_t num_futures = 0;
571  for (StateIterator<Fst<A>> siter(fst); !siter.Done(); siter.Next()) {
572  const StateId &state = siter.Value();
573  num_futures += fst.NumArcs(state) - fst.NumInputEpsilons(state);
574  ArcIterator<Fst<A>> aiter(fst, state);
575  if (!aiter.Done()) {
576  const Arc &arc = aiter.Value();
577  // this arc goes from state to arc.nextstate, so create an arc from
578  // arc.nextstate to state to reverse it.
579  if (arc.ilabel == 0) {
580  context_fst.AddArc(arc.nextstate, Arc(context[state], context[state],
581  arc.weight, state));
582  num_context_arcs++;
583  }
584  }
585  }
586  if (num_context_arcs != context_fst.NumStates() - 1) {
587  FSTERROR() << "Number of contexts arcs != number of states - 1";
588  SetProperties(kError, kError);
589  return;
590  }
591  if (context_fst.NumStates() != num_states) {
592  FSTERROR() << "Number of contexts != number of states";
593  SetProperties(kError, kError);
594  return;
595  }
596  int64_t context_props =
597  context_fst.Properties(kIDeterministic | kILabelSorted, true);
598  if (!(context_props & kIDeterministic)) {
599  FSTERROR() << "Input Fst is not structured properly";
600  SetProperties(kError, kError);
601  return;
602  }
603  if (!(context_props & kILabelSorted)) {
604  ArcSort(&context_fst, ILabelCompare<Arc>());
605  }
606 
607  uint64_t b64;
608  Weight weight;
609  Label label = kNoLabel;
610  const size_t storage = Storage(num_states, num_futures, num_final);
611  std::unique_ptr<MappedFile> data_region(MappedFile::Allocate(storage));
612  char *data = static_cast<char *>(data_region->mutable_data());
613  memset(data, 0, storage);
614  size_t offset = 0;
615  memcpy(data + offset, reinterpret_cast<char *>(&num_states),
616  sizeof(num_states));
617  offset += sizeof(num_states);
618  memcpy(data + offset, reinterpret_cast<char *>(&num_futures),
619  sizeof(num_futures));
620  offset += sizeof(num_futures);
621  memcpy(data + offset, reinterpret_cast<char *>(&num_final),
622  sizeof(num_final));
623  offset += sizeof(num_final);
624  uint64_t *context_bits = reinterpret_cast<uint64_t *>(data + offset);
625  offset += BitmapIndex::StorageSize(num_states * 2 + 1) * sizeof(b64);
626  uint64_t *future_bits = reinterpret_cast<uint64_t *>(data + offset);
627  offset +=
628  BitmapIndex::StorageSize(num_futures + num_states + 1) * sizeof(b64);
629  uint64_t *final_bits = reinterpret_cast<uint64_t *>(data + offset);
630  offset += BitmapIndex::StorageSize(num_states) * sizeof(b64);
631  Label *context_words = reinterpret_cast<Label *>(data + offset);
632  offset += (num_states + 1) * sizeof(label);
633  Label *future_words = reinterpret_cast<Label *>(data + offset);
634  offset += num_futures * sizeof(label);
635  offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
636  Weight *backoff = reinterpret_cast<Weight *>(data + offset);
637  offset += (num_states + 1) * sizeof(weight);
638  Weight *final_probs = reinterpret_cast<Weight *>(data + offset);
639  offset += num_final * sizeof(weight);
640  Weight *future_probs = reinterpret_cast<Weight *>(data + offset);
641  int64_t context_arc = 0, future_arc = 0, context_bit = 0, future_bit = 0,
642  final_bit = 0;
643 
644  // pseudo-root bits
645  BitmapIndex::Set(context_bits, context_bit++);
646  ++context_bit;
647  context_words[context_arc] = label;
648  backoff[context_arc] = Weight::Zero();
649  context_arc++;
650 
651  ++future_bit;
652  if (order_out) {
653  order_out->clear();
654  order_out->resize(num_states);
655  }
656 
657  std::queue<StateId> context_q;
658  context_q.push(context_fst.Start());
659  StateId state_number = 0;
660  while (!context_q.empty()) {
661  const StateId &state = context_q.front();
662  if (order_out) {
663  (*order_out)[state] = state_number;
664  }
665 
666  const Weight final_weight = context_fst.Final(state);
667  if (final_weight != Weight::Zero()) {
668  BitmapIndex::Set(final_bits, state_number);
669  final_probs[final_bit] = final_weight;
670  ++final_bit;
671  }
672 
673  for (ArcIterator<VectorFst<A>> aiter(context_fst, state); !aiter.Done();
674  aiter.Next()) {
675  const Arc &arc = aiter.Value();
676  context_words[context_arc] = arc.ilabel;
677  backoff[context_arc] = arc.weight;
678  ++context_arc;
679  BitmapIndex::Set(context_bits, context_bit++);
680  context_q.push(arc.nextstate);
681  }
682  ++context_bit;
683 
684  for (ArcIterator<Fst<A>> aiter(fst, state); !aiter.Done(); aiter.Next()) {
685  const Arc &arc = aiter.Value();
686  if (arc.ilabel != 0) {
687  future_words[future_arc] = arc.ilabel;
688  future_probs[future_arc] = arc.weight;
689  ++future_arc;
690  BitmapIndex::Set(future_bits, future_bit++);
691  }
692  }
693  ++future_bit;
694  ++state_number;
695  context_q.pop();
696  }
697 
698  if ((state_number != num_states) || (context_bit != num_states * 2 + 1) ||
699  (context_arc != num_states) || (future_arc != num_futures) ||
700  (future_bit != num_futures + num_states + 1) ||
701  (final_bit != num_final)) {
702  FSTERROR() << "Structure problems detected during construction";
703  SetProperties(kError, kError);
704  return;
705  }
706 
707  Init(data, std::move(data_region));
708 }
709 
710 template <typename A>
711 inline void NGramFstImpl<A>::Init(const char *data,
712  std::unique_ptr<MappedFile> data_region) {
713  data_region_ = std::move(data_region);
714  data_ = data;
715  size_t offset = 0;
716  num_states_ = *(reinterpret_cast<const uint64_t *>(data_ + offset));
717  offset += sizeof(num_states_);
718  num_futures_ = *(reinterpret_cast<const uint64_t *>(data_ + offset));
719  offset += sizeof(num_futures_);
720  num_final_ = *(reinterpret_cast<const uint64_t *>(data_ + offset));
721  offset += sizeof(num_final_);
722  uint64_t bits;
723  size_t context_bits = num_states_ * 2 + 1;
724  size_t future_bits = num_futures_ + num_states_ + 1;
725  context_ = reinterpret_cast<const uint64_t *>(data_ + offset);
726  offset += BitmapIndex::StorageSize(context_bits) * sizeof(bits);
727  future_ = reinterpret_cast<const uint64_t *>(data_ + offset);
728  offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits);
729  final_ = reinterpret_cast<const uint64_t *>(data_ + offset);
730  offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits);
731  context_words_ = reinterpret_cast<const Label *>(data_ + offset);
732  offset += (num_states_ + 1) * sizeof(*context_words_);
733  future_words_ = reinterpret_cast<const Label *>(data_ + offset);
734  offset += num_futures_ * sizeof(*future_words_);
735  offset = (offset + sizeof(*backoff_) - 1) & ~(sizeof(*backoff_) - 1);
736  backoff_ = reinterpret_cast<const Weight *>(data_ + offset);
737  offset += (num_states_ + 1) * sizeof(*backoff_);
738  final_probs_ = reinterpret_cast<const Weight *>(data_ + offset);
739  offset += num_final_ * sizeof(*final_probs_);
740  future_probs_ = reinterpret_cast<const Weight *>(data_ + offset);
741 
742  context_index_.BuildIndex(context_, context_bits,
743  /*enable_select_0_index=*/true,
744  /*enable_select_1_index=*/true);
745  future_index_.BuildIndex(future_, future_bits,
746  /*enable_select_0_index=*/true,
747  /*enable_select_1_index=*/false);
748  final_index_.BuildIndex(final_, num_states_);
749 
750  select_root_ = context_index_.Select0s(0);
751  if (context_index_.Rank1(0) != 0 || select_root_.first != 1 ||
752  context_index_.Get(2) == false) {
753  FSTERROR() << "Malformed file";
754  SetProperties(kError, kError);
755  return;
756  }
757  root_children_ = context_words_ + context_index_.Rank1(2);
758  start_ = 1;
759 }
760 
761 template <typename A>
762 inline typename A::StateId NGramFstImpl<A>::Transition(
763  const std::vector<Label> &context, Label future) const {
764  const Label *children = root_children_;
765  size_t num_children = select_root_.second - 2;
766  const Label *loc =
767  std::lower_bound(children, children + num_children, future);
768  if (loc == children + num_children || *loc != future) {
769  return context_index_.Rank1(0);
770  }
771  size_t node = 2 + loc - children;
772  size_t node_rank = context_index_.Rank1(node);
773  std::pair<size_t, size_t> zeros =
774  (node_rank == 0) ? select_root_ : context_index_.Select0s(node_rank);
775  size_t first_child = zeros.first + 1;
776  if (context_index_.Get(first_child) == false) {
777  return context_index_.Rank1(node);
778  }
779  size_t last_child = zeros.second - 1;
780  for (int word = context.size() - 1; word >= 0; --word) {
781  children = context_words_ + context_index_.Rank1(first_child);
782  loc = std::lower_bound(children, children + last_child - first_child + 1,
783  context[word]);
784  if (loc == children + last_child - first_child + 1 ||
785  *loc != context[word]) {
786  break;
787  }
788  node = first_child + loc - children;
789  node_rank = context_index_.Rank1(node);
790  zeros =
791  (node_rank == 0) ? select_root_ : context_index_.Select0s(node_rank);
792  first_child = zeros.first + 1;
793  if (context_index_.Get(first_child) == false) break;
794  last_child = zeros.second - 1;
795  }
796  return context_index_.Rank1(node);
797 }
798 
799 } // namespace internal
800 
801 /*****************************************************************************/
802 template <class A>
803 class NGramFstMatcher : public MatcherBase<A> {
804  public:
805  typedef A Arc;
806  typedef typename A::Label Label;
807  typedef typename A::StateId StateId;
808  typedef typename A::Weight Weight;
809 
810  // This makes a copy of the FST.
812  : owned_fst_(fst.Copy()),
813  fst_(*owned_fst_),
814  inst_(fst_.inst_),
815  match_type_(match_type),
816  current_loop_(false),
817  loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
818  if (match_type_ == MATCH_OUTPUT) {
819  std::swap(loop_.ilabel, loop_.olabel);
820  }
821  }
822 
823  // This doesn't copy the FST.
825  : fst_(*fst),
826  inst_(fst_.inst_),
827  match_type_(match_type),
828  current_loop_(false),
829  loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
830  if (match_type_ == MATCH_OUTPUT) {
831  std::swap(loop_.ilabel, loop_.olabel);
832  }
833  }
834 
835  // This makes a copy of the FST.
836  NGramFstMatcher(const NGramFstMatcher<A> &matcher, bool safe = false)
837  : owned_fst_(matcher.fst_.Copy(safe)),
838  fst_(*owned_fst_),
839  inst_(matcher.inst_),
840  match_type_(matcher.match_type_),
841  current_loop_(false),
842  loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
843  if (match_type_ == MATCH_OUTPUT) {
844  std::swap(loop_.ilabel, loop_.olabel);
845  }
846  }
847 
848  NGramFstMatcher<A> *Copy(bool safe = false) const override {
849  return new NGramFstMatcher<A>(*this, safe);
850  }
851 
852  MatchType Type(bool test) const override { return match_type_; }
853 
854  const Fst<A> &GetFst() const override { return fst_; }
855 
856  uint64_t Properties(uint64_t props) const override { return props; }
857 
858  void SetState(StateId s) final {
859  fst_.GetImpl()->SetInstFuture(s, &inst_);
860  current_loop_ = false;
861  }
862 
863  bool Find(Label label) final {
864  const Label nolabel = kNoLabel;
865  done_ = true;
866  if (label == 0 || label == nolabel) {
867  if (label == 0) {
868  current_loop_ = true;
869  loop_.nextstate = inst_.state_;
870  }
871  // The unigram state has no epsilon arc.
872  if (inst_.state_ != 0) {
873  arc_.ilabel = arc_.olabel = 0;
874  fst_.GetImpl()->SetInstNode(&inst_);
875  arc_.nextstate = fst_.GetImpl()->context_index_.Rank1(
876  fst_.GetImpl()->context_index_.Select1(
877  fst_.GetImpl()->context_index_.Rank0(inst_.node_) - 1));
878  arc_.weight = fst_.GetImpl()->backoff_[inst_.state_];
879  done_ = false;
880  }
881  } else {
882  current_loop_ = false;
883  const Label *start = fst_.GetImpl()->future_words_ + inst_.offset_;
884  const Label *end = start + inst_.num_futures_;
885  const Label *search = std::lower_bound(start, end, label);
886  if (search != end && *search == label) {
887  size_t state = search - start;
888  arc_.ilabel = arc_.olabel = label;
889  arc_.weight = fst_.GetImpl()->future_probs_[inst_.offset_ + state];
890  fst_.GetImpl()->SetInstContext(&inst_);
891  arc_.nextstate = fst_.GetImpl()->Transition(inst_.context_, label);
892  done_ = false;
893  }
894  }
895  return !Done();
896  }
897 
898  bool Done() const final { return !current_loop_ && done_; }
899 
900  const Arc &Value() const final { return (current_loop_) ? loop_ : arc_; }
901 
902  void Next() final {
903  if (current_loop_) {
904  current_loop_ = false;
905  } else {
906  done_ = true;
907  }
908  }
909 
910  ssize_t Priority(StateId s) final { return fst_.NumArcs(s); }
911 
912  private:
913  std::unique_ptr<NGramFst<A>> owned_fst_;
914  const NGramFst<A> &fst_;
915  NGramFstInst<A> inst_;
916  MatchType match_type_; // Supplied by caller
917  bool done_;
918  Arc arc_;
919  bool current_loop_; // Current arc is the implicit loop
920  Arc loop_;
921 };
922 
923 /*****************************************************************************/
924 // Specialization for NGramFst; see generic version in fst.h
925 // for sample usage (but use the ProdLmFst type!). This version
926 // should inline.
927 template <class A>
928 class StateIterator<NGramFst<A>> : public StateIteratorBase<A> {
929  public:
930  typedef typename A::StateId StateId;
931 
932  explicit StateIterator(const NGramFst<A> &fst)
933  : s_(0), num_states_(fst.NumStates()) {}
934 
935  bool Done() const final { return s_ >= num_states_; }
936 
937  StateId Value() const final { return s_; }
938 
939  void Next() final { ++s_; }
940 
941  void Reset() final { s_ = 0; }
942 
943  private:
944  StateId s_;
945  StateId num_states_;
946 };
947 
948 /*****************************************************************************/
949 template <class A>
950 class ArcIterator<NGramFst<A>> : public ArcIteratorBase<A> {
951  public:
952  typedef A Arc;
953  typedef typename A::Label Label;
954  typedef typename A::StateId StateId;
955  typedef typename A::Weight Weight;
956 
957  ArcIterator(const NGramFst<A> &fst, StateId state)
958  : lazy_(~0), impl_(fst.GetImpl()), i_(0), flags_(kArcValueFlags) {
959  inst_ = fst.inst_;
960  impl_->SetInstFuture(state, &inst_);
961  impl_->SetInstNode(&inst_);
962  }
963 
964  bool Done() const final {
965  return i_ >=
966  ((inst_.node_ == 0) ? inst_.num_futures_ : inst_.num_futures_ + 1);
967  }
968 
969  const Arc &Value() const final {
970  bool eps = (inst_.node_ != 0 && i_ == 0);
971  StateId state = (inst_.node_ == 0) ? i_ : i_ - 1;
972  if (flags_ & lazy_ & (kArcILabelValue | kArcOLabelValue)) {
973  arc_.ilabel = arc_.olabel =
974  eps ? 0 : impl_->future_words_[inst_.offset_ + state];
975  lazy_ &= ~(kArcILabelValue | kArcOLabelValue);
976  }
977  if (flags_ & lazy_ & kArcNextStateValue) {
978  if (eps) {
979  arc_.nextstate =
980  impl_->context_index_.Rank1(impl_->context_index_.Select1(
981  impl_->context_index_.Rank0(inst_.node_) - 1));
982  } else {
983  if (lazy_ & kArcNextStateValue) {
984  impl_->SetInstContext(&inst_); // first time only.
985  }
986  arc_.nextstate = impl_->Transition(
987  inst_.context_, impl_->future_words_[inst_.offset_ + state]);
988  }
989  lazy_ &= ~kArcNextStateValue;
990  }
991  if (flags_ & lazy_ & kArcWeightValue) {
992  arc_.weight = eps ? impl_->backoff_[inst_.state_]
993  : impl_->future_probs_[inst_.offset_ + state];
994  lazy_ &= ~kArcWeightValue;
995  }
996  return arc_;
997  }
998 
999  void Next() final {
1000  ++i_;
1001  lazy_ = ~0;
1002  }
1003 
1004  size_t Position() const final { return i_; }
1005 
1006  void Reset() final {
1007  i_ = 0;
1008  lazy_ = ~0;
1009  }
1010 
1011  void Seek(size_t a) final {
1012  if (i_ != a) {
1013  i_ = a;
1014  lazy_ = ~0;
1015  }
1016  }
1017 
1018  uint8_t Flags() const final { return flags_; }
1019 
1020  void SetFlags(uint8_t flags, uint8_t mask) final {
1021  flags_ &= ~mask;
1022  flags_ |= (flags & kArcValueFlags);
1023  }
1024 
1025  private:
1026  mutable Arc arc_;
1027  mutable uint8_t lazy_;
1028  const internal::NGramFstImpl<A> *impl_; // Borrowed reference.
1029  mutable NGramFstInst<A> inst_;
1030 
1031  size_t i_;
1032  uint8_t flags_;
1033 };
1034 
1035 } // namespace fst
1036 #endif // FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
constexpr uint64_t kCyclic
Definition: properties.h:108
internal::NGramFstImpl< A > Impl
Definition: ngram-fst.h:345
void GetStates(const std::vector< Label > &context, std::vector< StateId > *state) const
Definition: ngram-fst.h:378
constexpr uint64_t kNotString
Definition: properties.h:138
void AddArc(StateId s, const Arc &arc) override
Definition: mutable-fst.h:326
NGramFst< A > * Copy(bool safe=false) const override
Definition: ngram-fst.h:387
void SetStart(StateId s) override
Definition: mutable-fst.h:298
constexpr uint8_t kArcValueFlags
Definition: fst.h:453
size_t num_futures_
Definition: ngram-fst.h:57
void Next() final
Definition: ngram-fst.h:902
constexpr int kNoLabel
Definition: fst.h:201
size_t NumInputEpsilons(StateId state) const
Definition: ngram-fst.h:163
virtual uint64_t Properties(uint64_t mask, bool test) const =0
constexpr uint64_t kOEpsilons
Definition: properties.h:88
const Arc & Value() const final
Definition: ngram-fst.h:900
void SetFinal(StateId s, Weight weight=Weight::One()) override
Definition: mutable-fst.h:303
virtual size_t NumArcs(StateId) const =0
static size_t Storage(uint64_t num_states, uint64_t num_futures, uint64_t num_final)
Definition: ngram-fst.h:180
constexpr uint64_t kCoAccessible
Definition: properties.h:128
void SetInstContext(NGramFstInst< A > *inst) const
Definition: ngram-fst.h:216
uint8_t Flags() const final
Definition: ngram-fst.h:1018
static void Set(uint64_t *bits, size_t index)
Definition: bitmap-index.h:115
constexpr uint64_t kNotTopSorted
Definition: properties.h:120
MatchType
Definition: fst.h:193
StateId state_
Definition: ngram-fst.h:56
size_t Position() const final
Definition: ngram-fst.h:1004
constexpr uint64_t kError
Definition: properties.h:51
size_t StorageSize() const
Definition: ngram-fst.h:429
constexpr uint64_t kInitialAcyclic
Definition: properties.h:115
static NGramFstImpl< A > * Read(std::istream &strm, const FstReadOptions &opts)
Definition: ngram-fst.h:108
#define LOG(type)
Definition: log.h:49
StateId Start() const override
Definition: fst.h:950
virtual Weight Final(StateId) const =0
SetType
Definition: set-weight.h:52
StateId node_state_
Definition: ngram-fst.h:60
uint64_t Properties(uint64_t props) const override
Definition: ngram-fst.h:856
constexpr uint8_t kArcILabelValue
Definition: fst.h:444
constexpr uint64_t kEpsilons
Definition: properties.h:78
constexpr uint64_t kODeterministic
Definition: properties.h:73
const Arc & Value() const final
Definition: ngram-fst.h:969
constexpr int kNoStateId
Definition: fst.h:202
NGramFstMatcher< A > * Copy(bool safe=false) const override
Definition: ngram-fst.h:848
const Arc & Value() const
Definition: fst.h:537
NGramFstMatcher(const NGramFstMatcher< A > &matcher, bool safe=false)
Definition: ngram-fst.h:836
static MappedFile * Allocate(size_t size, size_t align=kArchAlignment)
Definition: mapped-file.cc:190
void InitArcIterator(StateId s, ArcIteratorData< A > *data) const override
Definition: ngram-fst.h:475
NGramFst(const Fst< A > &fst, std::vector< StateId > *order_out)
Definition: ngram-fst.h:350
const Fst< A > & GetFst() const override
Definition: ngram-fst.h:854
virtual size_t NumInputEpsilons(StateId) const =0
#define FSTERROR()
Definition: util.h:53
StateId AddState() override
Definition: mutable-fst.h:316
NGramFst(const Fst< A > &dst)
Definition: ngram-fst.h:347
size_t NumArcs(StateId state, NGramFstInst< A > *inst=nullptr) const
Definition: ngram-fst.h:153
const std::vector< Label > & GetContext(StateId s, NGramFstInst< A > *inst) const
Definition: ngram-fst.h:238
void SetOutputSymbols(const SymbolTable *osyms) override
Definition: mutable-fst.h:396
static size_t StorageSize(size_t num_bits)
Definition: bitmap-index.h:93
NGramFst(const NGramFst< A > &fst, bool safe=false)
Definition: ngram-fst.h:355
A::Weight Weight
Definition: ngram-fst.h:55
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:108
const char * GetData(size_t *data_size) const
Definition: ngram-fst.h:368
A::StateId StateId
Definition: ngram-fst.h:54
void ArcSort(MutableFst< Arc > *fst, Compare comp)
Definition: arcsort.h:102
constexpr uint64_t kOLabelSorted
Definition: properties.h:98
void SetInstNode(NGramFstInst< A > *inst) const
Definition: ngram-fst.h:209
void SetFlags(uint8_t flags, uint8_t mask) final
Definition: ngram-fst.h:1020
static NGramFst< A > * Read(std::istream &strm, const FstReadOptions &opts)
Definition: ngram-fst.h:391
void SetInputSymbols(const SymbolTable *isyms) override
Definition: mutable-fst.h:391
std::vector< Label > context_
Definition: ngram-fst.h:61
MatcherBase< A > * InitMatcher(MatchType match_type) const override
Definition: ngram-fst.h:425
bool Write(const std::string &source) const override
Definition: ngram-fst.h:414
StateId Value() const final
Definition: ngram-fst.h:937
StateId NumStates() const override
Definition: expanded-fst.h:134
NGramFstMatcher(const NGramFst< A > &fst, MatchType match_type)
Definition: ngram-fst.h:811
StateId nstates
Definition: fst.h:384
constexpr uint64_t kAccessible
Definition: properties.h:123
constexpr uint8_t kArcOLabelValue
Definition: fst.h:446
const char * GetData(size_t *data_size) const
Definition: ngram-fst.h:230
MatchType Type(bool test) const override
Definition: ngram-fst.h:852
size_t NumArcs(StateId s) const override
Definition: ngram-fst.h:383
void SetInstFuture(StateId state, NGramFstInst< A > *inst) const
Definition: ngram-fst.h:200
virtual StateId Start() const =0
bool Done() const
Definition: fst.h:533
std::unique_ptr< StateIteratorBase< Arc > > base
Definition: fst.h:382
A::StateId StateId
Definition: ngram-fst.h:807
constexpr uint64_t kIDeterministic
Definition: properties.h:68
void SetState(StateId s) final
Definition: ngram-fst.h:858
NGramFstImpl(const Fst< A > &fst)
Definition: ngram-fst.h:101
std::unique_ptr< ArcIteratorBase< Arc > > base
Definition: fst.h:495
const std::vector< Label > GetContext(StateId s) const
Definition: ngram-fst.h:372
constexpr uint64_t kIEpsilons
Definition: properties.h:83
void Seek(size_t a) final
Definition: ngram-fst.h:1011
bool WriteFile(const std::string &source) const
Definition: fst.h:332
StateId NumStates() const
Definition: ngram-fst.h:173
A::Label Label
Definition: ngram-fst.h:343
constexpr uint8_t kArcWeightValue
Definition: fst.h:448
size_t StorageSize() const
Definition: ngram-fst.h:244
constexpr uint64_t kILabelSorted
Definition: properties.h:93
void SetNumStates(int64_t numstates)
Definition: fst.h:170
bool Write(std::ostream &strm, const FstWriteOptions &opts) const override
Definition: ngram-fst.h:410
StateId context_state_
Definition: ngram-fst.h:62
virtual const SymbolTable * InputSymbols() const =0
Arc::StateId CountStates(const Fst< Arc > &fst)
Definition: expanded-fst.h:169
constexpr uint8_t kArcNextStateValue
Definition: fst.h:450
ssize_t Priority(StateId s) final
Definition: ngram-fst.h:910
ArcIterator(const NGramFst< A > &fst, StateId state)
Definition: ngram-fst.h:957
Weight Final(StateId s) const override
Definition: fst.h:952
size_t NumOutputEpsilons(StateId state) const
Definition: ngram-fst.h:169
NGramFstImpl(const NGramFstImpl &other)
Definition: ngram-fst.h:103
void SetStart(int64_t start)
Definition: fst.h:168
bool Done() const final
Definition: ngram-fst.h:898
StateId Start() const
Definition: ngram-fst.h:143
NGramFstMatcher(const NGramFst< A > *fst, MatchType match_type)
Definition: ngram-fst.h:824
A::Label Label
Definition: ngram-fst.h:53
NGramFst(const char *data)
Definition: ngram-fst.h:362
A::Weight Weight
Definition: ngram-fst.h:344
void InitStateIterator(StateIteratorData< A > *data) const override
Definition: ngram-fst.h:418
Weight Final(StateId state) const
Definition: ngram-fst.h:145
void InitStateIterator(StateIteratorData< A > *data) const
Definition: ngram-fst.h:175
bool Write(std::ostream &strm, const FstWriteOptions &opts) const
Definition: ngram-fst.h:134
A::StateId StateId
Definition: ngram-fst.h:342
constexpr uint64_t kWeighted
Definition: properties.h:103
bool Find(Label label) final
Definition: ngram-fst.h:863
constexpr uint64_t kExpanded
Definition: properties.h:45
static bool HasRequiredStructure(const Fst< A > &fst)
Definition: ngram-fst.h:437
StateIterator(const NGramFst< A > &fst)
Definition: ngram-fst.h:932
static NGramFst< A > * Read(const std::string &source)
Definition: ngram-fst.h:396
static bool HasRequiredProps(const Fst< A > &fst)
Definition: ngram-fst.h:431
uint64_t Properties(uint64_t mask, bool test) const override
Definition: fst.h:966
constexpr uint64_t kAcceptor
Definition: properties.h:63
virtual const SymbolTable * OutputSymbols() const =0
void Next()
Definition: fst.h:541