FST  openfst-1.7.1
OpenFst Library
linear-fst-data.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Data structures for storing and looking up the actual feature weights.
5 
6 #ifndef FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_H_
7 #define FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_H_
8 
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <fst/compat.h>
#include <fst/fst.h>
#include <fst/extensions/linear/trie.h>
17 
19 
20 namespace fst {
21 
// Forward declarations. `LinearFstData` befriends its builder and stores
// `FeatureGroup` pointers before either class template is defined.
template <class A>
class LinearFstDataBuilder;
template <class A>
class FeatureGroup;

28 // Immutable data storage of the feature weights in a linear
29 // model. Produces state tuples that represent internal states of a
30 // LinearTaggerFst. Object of this class can only be constructed via
31 // either `LinearFstDataBuilder::Dump()` or `LinearFstData::Read()`
32 // and usually used as refcount'd object shared across mutiple
33 // `LinearTaggerFst` copies.
34 //
35 // TODO(wuke): more efficient trie implementation
36 template <class A>
38  public:
39  friend class LinearFstDataBuilder<A>; // For builder access
40 
41  typedef typename A::Label Label;
42  typedef typename A::Weight Weight;
43 
44  // Sentence boundary labels. Both of them are negative labels other
45  // than `kNoLabel`.
46  static const Label kStartOfSentence;
47  static const Label kEndOfSentence;
48 
49  // Constructs empty data; for non-trivial ways of construction see
50  // `Read()` and `LinearFstDataBuilder`.
52  : max_future_size_(0), max_input_label_(1), input_attribs_(1) {}
53 
54  // Appends the state tuple of the start state to `output`, where
55  // each tuple holds the node ids of a trie for each feature group.
56  void EncodeStartState(std::vector<Label> *output) const {
57  for (int i = 0; i < NumGroups(); ++i) output->push_back(GroupStartState(i));
58  }
59 
60  // Takes a transition from the trie states stored in
61  // `(trie_state_begin, trie_state_end)` with input label `ilabel`
62  // and output label `olabel`; appends the destination state tuple to
63  // `next` and multiplies the weight of the transition onto
64  // `weight`. `next` should be the shifted input buffer of the caller
65  // in `LinearTaggerFstImpl` (i.e. of size `LinearTaggerFstImpl::delay_`;
66  // the last element is `ilabel`).
67  template <class Iterator>
68  void TakeTransition(Iterator buffer_end, Iterator trie_state_begin,
69  Iterator trie_state_end, Label ilabel, Label olabel,
70  std::vector<Label> *next, Weight *weight) const;
71 
72  // Returns the final weight of the given trie state sequence.
73  template <class Iterator>
74  Weight FinalWeight(Iterator trie_state_begin, Iterator trie_state_end) const;
75 
76  // Returns the start trie state of the given group.
77  Label GroupStartState(int group_id) const {
78  return groups_[group_id]->Start();
79  }
80 
81  // Takes a transition only within the given group. Returns the
82  // destination trie state and multiplies the weight onto `weight`.
83  Label GroupTransition(int group_id, int trie_state, Label ilabel,
84  Label olabel, Weight *weight) const;
85 
86  // Returns the final weight of the given trie state in the given group.
87  Weight GroupFinalWeight(int group_id, int trie_state) const {
88  return groups_[group_id]->FinalWeight(trie_state);
89  }
90 
91  Label MinInputLabel() const { return 1; }
92 
93  Label MaxInputLabel() const { return max_input_label_; }
94 
95  // Returns the maximum future size of all feature groups. Future is
96  // the look-ahead window of a feature, e.g. if a feature looks at
97  // the next 2 words after the current input, then the future size is
98  // 2. There is no look-ahead for output. Features inside a single
99  // `FeatureGroup` must have equal future size.
100  size_t MaxFutureSize() const { return max_future_size_; }
101 
102  // Returns the number of feature groups
103  size_t NumGroups() const { return groups_.size(); }
104 
105  // Returns the range of possible output labels for an input label.
106  std::pair<typename std::vector<Label>::const_iterator,
107  typename std::vector<Label>::const_iterator>
108  PossibleOutputLabels(Label word) const;
109 
110  static LinearFstData<A> *Read(std::istream &strm); // NOLINT
111  std::ostream &Write(std::ostream &strm) const; // NOLINT
112 
113  private:
114  // Offsets in `output_pool_`
115  struct InputAttribute {
116  size_t output_begin, output_length;
117 
118  std::istream &Read(std::istream &strm); // NOLINT
119  std::ostream &Write(std::ostream &strm) const; // NOLINT
120  };
121 
122  // Mapping from input label to per-group feature label
123  class GroupFeatureMap;
124 
125  // Translates the input label into input feature label of group
126  // `group`; returns `kNoLabel` when there is no feature for that
127  // group.
128  Label FindFeature(size_t group, Label word) const;
129 
130  size_t max_future_size_;
131  Label max_input_label_;
132  std::vector<std::unique_ptr<const FeatureGroup<A>>> groups_;
133  std::vector<InputAttribute> input_attribs_;
134  std::vector<Label> output_pool_, output_set_;
135  GroupFeatureMap group_feat_map_;
136 
137  LinearFstData(const LinearFstData &) = delete;
138  LinearFstData &operator=(const LinearFstData &) = delete;
139 };
140 
141 template <class A>
142 const typename A::Label LinearFstData<A>::kStartOfSentence = -3;
143 template <class A>
144 const typename A::Label LinearFstData<A>::kEndOfSentence = -2;
145 
146 template <class A>
147 template <class Iterator>
148 void LinearFstData<A>::TakeTransition(Iterator buffer_end,
149  Iterator trie_state_begin,
150  Iterator trie_state_end, Label ilabel,
151  Label olabel, std::vector<Label> *next,
152  Weight *weight) const {
153  DCHECK_EQ(trie_state_end - trie_state_begin, groups_.size());
154  DCHECK(ilabel > 0 || ilabel == kEndOfSentence);
155  DCHECK(olabel > 0 || olabel == kStartOfSentence);
156  size_t group_id = 0;
157  for (Iterator it = trie_state_begin; it != trie_state_end; ++it, ++group_id) {
158  size_t delay = groups_[group_id]->Delay();
159  // On the buffer, there may also be `kStartOfSentence` from the
160  // initial empty buffer.
161  Label real_ilabel = delay == 0 ? ilabel : *(buffer_end - delay);
162  next->push_back(
163  GroupTransition(group_id, *it, real_ilabel, olabel, weight));
164  }
165 }
166 
167 template <class A>
168 typename A::Label LinearFstData<A>::GroupTransition(int group_id,
169  int trie_state,
170  Label ilabel, Label olabel,
171  Weight *weight) const {
172  Label group_ilabel = FindFeature(group_id, ilabel);
173  return groups_[group_id]->Walk(trie_state, group_ilabel, olabel, weight);
174 }
175 
176 template <class A>
177 template <class Iterator>
178 inline typename A::Weight LinearFstData<A>::FinalWeight(
179  Iterator trie_state_begin, Iterator trie_state_end) const {
180  DCHECK_EQ(trie_state_end - trie_state_begin, groups_.size());
181  size_t group_id = 0;
182  Weight accum = Weight::One();
183  for (Iterator it = trie_state_begin; it != trie_state_end; ++it, ++group_id)
184  accum = Times(accum, GroupFinalWeight(group_id, *it));
185  return accum;
186 }
187 
188 template <class A>
189 inline std::pair<typename std::vector<typename A::Label>::const_iterator,
190  typename std::vector<typename A::Label>::const_iterator>
192  const InputAttribute &attrib = input_attribs_[word];
193  if (attrib.output_length == 0)
194  return std::make_pair(output_set_.begin(), output_set_.end());
195  else
196  return std::make_pair(
197  output_pool_.begin() + attrib.output_begin,
198  output_pool_.begin() + attrib.output_begin + attrib.output_length);
199 }
200 
201 template <class A>
202 inline LinearFstData<A> *LinearFstData<A>::Read(std::istream &strm) { // NOLINT
203  std::unique_ptr<LinearFstData<A>> data(new LinearFstData<A>());
204  ReadType(strm, &(data->max_future_size_));
205  ReadType(strm, &(data->max_input_label_));
206  // Feature groups
207  size_t num_groups = 0;
208  ReadType(strm, &num_groups);
209  data->groups_.resize(num_groups);
210  for (size_t i = 0; i < num_groups; ++i)
211  data->groups_[i].reset(FeatureGroup<A>::Read(strm));
212  // Other data
213  ReadType(strm, &(data->input_attribs_));
214  ReadType(strm, &(data->output_pool_));
215  ReadType(strm, &(data->output_set_));
216  ReadType(strm, &(data->group_feat_map_));
217  if (strm) {
218  return data.release();
219  } else {
220  return nullptr;
221  }
222 }
223 
224 template <class A>
225 inline std::ostream &LinearFstData<A>::Write(
226  std::ostream &strm) const { // NOLINT
227  WriteType(strm, max_future_size_);
228  WriteType(strm, max_input_label_);
229  // Feature groups
230  WriteType(strm, groups_.size());
231  for (size_t i = 0; i < groups_.size(); ++i) {
232  groups_[i]->Write(strm);
233  }
234  // Other data
235  WriteType(strm, input_attribs_);
236  WriteType(strm, output_pool_);
237  WriteType(strm, output_set_);
238  WriteType(strm, group_feat_map_);
239  return strm;
240 }
241 
242 template <class A>
243 typename A::Label LinearFstData<A>::FindFeature(size_t group,
244  Label word) const {
245  DCHECK(word > 0 || word == kStartOfSentence || word == kEndOfSentence);
246  if (word == kStartOfSentence || word == kEndOfSentence)
247  return word;
248  else
249  return group_feat_map_.Find(group, word);
250 }
251 
252 template <class A>
253 inline std::istream &LinearFstData<A>::InputAttribute::Read(
254  std::istream &strm) { // NOLINT
255  ReadType(strm, &output_begin);
256  ReadType(strm, &output_length);
257  return strm;
258 }
259 
260 template <class A>
261 inline std::ostream &LinearFstData<A>::InputAttribute::Write(
262  std::ostream &strm) const { // NOLINT
263  WriteType(strm, output_begin);
264  WriteType(strm, output_length);
265  return strm;
266 }
267 
// Forward declaration: `FeatureGroup` befriends its builder before the
// builder is defined.
template <typename A>
class FeatureGroupBuilder;
271 
272 // An immutable grouping of features with similar context shape. Like
273 // `LinearFstData`, this can only be constructed via `Read()` or
274 // via its builder.
275 //
276 // Internally it uses a trie to store all feature n-grams and their
277 // weights. The label of a trie edge is a pair (feat, olabel) of
278 // labels. They can be either positive (ordinary label), `kNoLabel`,
279 // `kStartOfSentence`, or `kEndOfSentence`. `kNoLabel` usually means
280 // matching anything, with one exception: from the root of the trie,
281 // there is a special (kNoLabel, kNoLabel) that leads to the implicit
282 // start-of-sentence state. This edge is never actually matched
283 // (`FindFirstMatch()` ensures this).
284 template <class A>
285 class FeatureGroup {
286  public:
287  friend class FeatureGroupBuilder<A>; // for builder access
288 
289  typedef typename A::Label Label;
290  typedef typename A::Weight Weight;
291 
292  int Start() const { return start_; }
293 
294  // Finds destination node from `cur` by consuming `ilabel` and
295  // `olabel`. The transition weight is multiplied onto `weight`.
296  int Walk(int cur, Label ilabel, Label olabel, Weight *weight) const;
297 
298  // Returns the final weight of the current trie state. Only valid if
299  // the state is already known to be part of a final state (see
300  // `LinearFstData<>::CanBeFinal()`).
301  Weight FinalWeight(int trie_state) const {
302  return trie_[trie_state].final_weight;
303  }
304 
305  static FeatureGroup<A> *Read(std::istream &strm) { // NOLINT
306  size_t delay;
307  ReadType(strm, &delay);
308  int start;
309  ReadType(strm, &start);
310  Trie trie;
311  ReadType(strm, &trie);
312  std::unique_ptr<FeatureGroup<A>> ret(new FeatureGroup<A>(delay, start));
313  ret->trie_.swap(trie);
314  ReadType(strm, &ret->next_state_);
315  if (strm) {
316  return ret.release();
317  } else {
318  return nullptr;
319  }
320  }
321 
322  std::ostream &Write(std::ostream &strm) const { // NOLINT
323  WriteType(strm, delay_);
324  WriteType(strm, start_);
325  WriteType(strm, trie_);
326  WriteType(strm, next_state_);
327  return strm;
328  }
329 
330  size_t Delay() const { return delay_; }
331 
332  string Stats() const;
333 
334  private:
335  // Label along the arcs on the trie. `kNoLabel` means anything
336  // (non-negative label) can match; both sides holding `kNoLabel`
337  // is not allow; otherwise the label is > 0 (enforced by
338  // `LinearFstDataBuilder::AddWeight()`).
339  struct InputOutputLabel;
340  struct InputOutputLabelHash;
341 
342  // Data to be stored on the trie
343  struct WeightBackLink {
344  int back_link;
345  Weight weight, final_weight;
346 
347  WeightBackLink()
348  : back_link(kNoTrieNodeId),
349  weight(Weight::One()),
350  final_weight(Weight::One()) {}
351 
352  std::istream &Read(std::istream &strm) { // NOLINT
353  ReadType(strm, &back_link);
354  ReadType(strm, &weight);
355  ReadType(strm, &final_weight);
356  return strm;
357  }
358 
359  std::ostream &Write(std::ostream &strm) const { // NOLINT
360  WriteType(strm, back_link);
361  WriteType(strm, weight);
362  WriteType(strm, final_weight);
363  return strm;
364  }
365  };
366 
369 
370  explicit FeatureGroup(size_t delay, int start)
371  : delay_(delay), start_(start) {}
372 
373  // Finds the first node with an arc with `label` following the
374  // back-off chain of `parent`. Returns the node index or
375  // `kNoTrieNodeId` when not found.
376  int FindFirstMatch(InputOutputLabel label, int parent) const;
377 
378  size_t delay_;
379  int start_;
380  Trie trie_;
381  // Where to go after hitting this state. When we reach a state with
382  // no child and with no additional final weight (i.e. its final
383  // weight is the same as its back-off), we can immediately go to its
384  // back-off state.
385  std::vector<int> next_state_;
386 
387  FeatureGroup(const FeatureGroup &) = delete;
388  FeatureGroup &operator=(const FeatureGroup &) = delete;
389 };
390 
391 template <class A>
393  Label input, output;
394 
396  : input(i), output(o) {}
397 
398  bool operator==(InputOutputLabel that) const {
399  return input == that.input && output == that.output;
400  }
401 
402  std::istream &Read(std::istream &strm) { // NOLINT
403  ReadType(strm, &input);
404  ReadType(strm, &output);
405  return strm;
406  }
407 
408  std::ostream &Write(std::ostream &strm) const { // NOLINT
409  WriteType(strm, input);
410  WriteType(strm, output);
411  return strm;
412  }
413 };
414 
415 template <class A>
417  size_t operator()(InputOutputLabel label) const {
418  return static_cast<size_t>(label.input * 7853 + label.output);
419  }
420 };
421 
422 template <class A>
423 int FeatureGroup<A>::Walk(int cur, Label ilabel, Label olabel,
424  Weight *weight) const {
425  // Note: user of this method need to ensure `ilabel` and `olabel`
426  // are valid (e.g. see DCHECKs in
427  // `LinearFstData<>::TakeTransition()` and
428  // `LinearFstData<>::FindFeature()`).
429  int next;
430  if (ilabel == LinearFstData<A>::kStartOfSentence) {
431  // An observed start-of-sentence only occurs in the beginning of
432  // the input, when this feature group is delayed (i.e. there is
433  // another feature group with a larger future size). The actual
434  // input hasn't arrived so stay at the start state.
435  DCHECK_EQ(cur, start_);
436  next = start_;
437  } else {
438  // First, try exact match
439  next = FindFirstMatch(InputOutputLabel(ilabel, olabel), cur);
440  // Then try with don't cares
441  if (next == kNoTrieNodeId)
442  next = FindFirstMatch(InputOutputLabel(ilabel, kNoLabel), cur);
443  if (next == kNoTrieNodeId)
444  next = FindFirstMatch(InputOutputLabel(kNoLabel, olabel), cur);
445  // All failed, go to empty context
446  if (next == kNoTrieNodeId) next = trie_.Root();
447  *weight = Times(*weight, trie_[next].weight);
448  next = next_state_[next];
449  }
450  return next;
451 }
452 
453 template <class A>
455  int parent) const {
456  if (label.input == kNoLabel && label.output == kNoLabel)
457  return kNoTrieNodeId; // very important; see class doc.
458  for (; parent != kNoTrieNodeId; parent = trie_[parent].back_link) {
459  int next = trie_.Find(parent, label);
460  if (next != kNoTrieNodeId) return next;
461  }
462  return kNoTrieNodeId;
463 }
464 
465 template <class A>
466 inline string FeatureGroup<A>::Stats() const {
467  std::ostringstream strm;
468  int num_states = 2;
469  for (int i = 2; i < next_state_.size(); ++i)
470  num_states += i == next_state_[i];
471  strm << trie_.NumNodes() << " node(s); " << num_states << " state(s)";
472  return strm.str();
473 }
474 
475 template <class A>
477  public:
479 
480  void Init(size_t num_groups, size_t num_words) {
481  num_groups_ = num_groups;
482  pool_.clear();
483  pool_.resize(num_groups * num_words, kNoLabel);
484  }
485 
486  Label Find(size_t group_id, Label ilabel) const {
487  return pool_[IndexOf(group_id, ilabel)];
488  }
489 
490  bool Set(size_t group_id, Label ilabel, Label feat) {
491  size_t i = IndexOf(group_id, ilabel);
492  if (pool_[i] != kNoLabel && pool_[i] != feat) {
493  FSTERROR() << "Feature group " << group_id
494  << " already has feature for word " << ilabel;
495  return false;
496  }
497  pool_[i] = feat;
498  return true;
499  }
500 
501  std::istream &Read(std::istream &strm) { // NOLINT
502  ReadType(strm, &num_groups_);
503  ReadType(strm, &pool_);
504  return strm;
505  }
506 
507  std::ostream &Write(std::ostream &strm) const { // NOLINT
508  WriteType(strm, num_groups_);
509  WriteType(strm, pool_);
510  return strm;
511  }
512 
513  private:
514  size_t IndexOf(size_t group_id, Label ilabel) const {
515  return ilabel * num_groups_ + group_id;
516  }
517 
518  size_t num_groups_;
519  // `pool_[ilabel * num_groups_ + group_id]` is the feature active
520  // for group `group_id` with input `ilabel`
521  std::vector<Label> pool_;
522 };
523 
524 } // namespace fst
525 
526 #endif // FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_H_
std::ostream & Write(std::ostream &strm) const
constexpr int kNoLabel
Definition: fst.h:179
static const Label kEndOfSentence
bool operator==(InputOutputLabel that) const
bool Set(size_t group_id, Label ilabel, Label feat)
std::ostream & Write(std::ostream &strm) const
Label GroupStartState(int group_id) const
string Stats() const
std::pair< typename std::vector< Label >::const_iterator, typename std::vector< Label >::const_iterator > PossibleOutputLabels(Label word) const
void TakeTransition(Iterator buffer_end, Iterator trie_state_begin, Iterator trie_state_end, Label ilabel, Label olabel, std::vector< Label > *next, Weight *weight) const
int Walk(int cur, Label ilabel, Label olabel, Weight *weight) const
size_t Delay() const
std::istream & Read(std::istream &strm)
Weight GroupFinalWeight(int group_id, int trie_state) const
ExpectationWeight< X1, X2 > Times(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
static const Label kStartOfSentence
void Init(size_t num_groups, size_t num_words)
#define DCHECK_EQ(x, y)
Definition: log.h:71
void EncodeStartState(std::vector< Label > *output) const
Label GroupTransition(int group_id, int trie_state, Label ilabel, Label olabel, Weight *weight) const
Label Find(size_t group_id, Label ilabel) const
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:155
#define FSTERROR()
Definition: util.h:35
int Start() const
Weight FinalWeight(Iterator trie_state_begin, Iterator trie_state_end) const
const int kNoTrieNodeId
Definition: trie.h:16
std::ostream & Write(std::ostream &strm) const
static LinearFstData< A > * Read(std::istream &strm)
Weight FinalWeight(int trie_state) const
size_t NumGroups() const
Label MinInputLabel() const
size_t MaxFutureSize() const
size_t operator()(InputOutputLabel label) const
std::istream & Read(std::istream &strm)
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:47
#define DCHECK(x)
Definition: log.h:70
Label MaxInputLabel() const
std::ostream & Write(std::ostream &strm) const
static FeatureGroup< A > * Read(std::istream &strm)
InputOutputLabel(Label i=kNoLabel, Label o=kNoLabel)