FST  openfst-1.8.2.post1
OpenFst Library
linear-fst-data.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Data structures for storing and looking up the actual feature weights.
19 
20 #ifndef FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_H_
21 #define FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_H_
22 
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <fst/compat.h>
#include <fst/extensions/linear/trie.h>
#include <fst/fst.h>
32 
33 namespace fst {
34 
// Forward declarations. `FeatureGroup` is referenced by `LinearFstData`
// (as the element type of its `groups_` member) before it is defined below.
template <class A>
class LinearFstDataBuilder;
template <class A>
class FeatureGroup;
41 // Immutable data storage of the feature weights in a linear
42 // model. Produces state tuples that represent internal states of a
43 // LinearTaggerFst. Object of this class can only be constructed via
44 // either `LinearFstDataBuilder::Dump()` or `LinearFstData::Read()`
45 // and usually used as refcount'd object shared across mutiple
46 // `LinearTaggerFst` copies.
47 //
48 // TODO(wuke): more efficient trie implementation
49 template <class A>
51  public:
52  friend class LinearFstDataBuilder<A>; // For builder access
53 
54  typedef typename A::Label Label;
55  typedef typename A::Weight Weight;
56 
57  // Sentence boundary labels. Both of them are negative labels other
58  // than `kNoLabel`.
59  static constexpr Label kStartOfSentence = -3;
60  static constexpr Label kEndOfSentence = -2;
61 
62  // Constructs empty data; for non-trivial ways of construction see
63  // `Read()` and `LinearFstDataBuilder`.
65  : max_future_size_(0), max_input_label_(1), input_attribs_(1) {}
66 
67  // Appends the state tuple of the start state to `output`, where
68  // each tuple holds the node ids of a trie for each feature group.
69  void EncodeStartState(std::vector<Label> *output) const {
70  for (int i = 0; i < NumGroups(); ++i) output->push_back(GroupStartState(i));
71  }
72 
73  // Takes a transition from the trie states stored in
74  // `(trie_state_begin, trie_state_end)` with input label `ilabel`
75  // and output label `olabel`; appends the destination state tuple to
76  // `next` and multiplies the weight of the transition onto
77  // `weight`. `next` should be the shifted input buffer of the caller
78  // in `LinearTaggerFstImpl` (i.e. of size `LinearTaggerFstImpl::delay_`;
79  // the last element is `ilabel`).
80  template <class Iterator>
81  void TakeTransition(Iterator buffer_end, Iterator trie_state_begin,
82  Iterator trie_state_end, Label ilabel, Label olabel,
83  std::vector<Label> *next, Weight *weight) const;
84 
85  // Returns the final weight of the given trie state sequence.
86  template <class Iterator>
87  Weight FinalWeight(Iterator trie_state_begin, Iterator trie_state_end) const;
88 
89  // Returns the start trie state of the given group.
90  Label GroupStartState(int group_id) const {
91  return groups_[group_id]->Start();
92  }
93 
94  // Takes a transition only within the given group. Returns the
95  // destination trie state and multiplies the weight onto `weight`.
96  Label GroupTransition(int group_id, int trie_state, Label ilabel,
97  Label olabel, Weight *weight) const;
98 
99  // Returns the final weight of the given trie state in the given group.
100  Weight GroupFinalWeight(int group_id, int trie_state) const {
101  return groups_[group_id]->FinalWeight(trie_state);
102  }
103 
104  Label MinInputLabel() const { return 1; }
105 
106  Label MaxInputLabel() const { return max_input_label_; }
107 
108  // Returns the maximum future size of all feature groups. Future is
109  // the look-ahead window of a feature, e.g. if a feature looks at
110  // the next 2 words after the current input, then the future size is
111  // 2. There is no look-ahead for output. Features inside a single
112  // `FeatureGroup` must have equal future size.
113  size_t MaxFutureSize() const { return max_future_size_; }
114 
115  // Returns the number of feature groups
116  size_t NumGroups() const { return groups_.size(); }
117 
118  // Returns the range of possible output labels for an input label.
119  std::pair<typename std::vector<Label>::const_iterator,
120  typename std::vector<Label>::const_iterator>
121  PossibleOutputLabels(Label word) const;
122 
123  static LinearFstData<A> *Read(std::istream &strm);
124  std::ostream &Write(std::ostream &strm) const;
125 
126  private:
127  // Offsets in `output_pool_`
128  struct InputAttribute {
129  size_t output_begin, output_length;
130 
131  std::istream &Read(std::istream &strm);
132  std::ostream &Write(std::ostream &strm) const;
133  };
134 
135  // Mapping from input label to per-group feature label
136  class GroupFeatureMap;
137 
138  // Translates the input label into input feature label of group
139  // `group`; returns `kNoLabel` when there is no feature for that
140  // group.
141  Label FindFeature(size_t group, Label word) const;
142 
143  size_t max_future_size_;
144  Label max_input_label_;
145  std::vector<std::unique_ptr<const FeatureGroup<A>>> groups_;
146  std::vector<InputAttribute> input_attribs_;
147  std::vector<Label> output_pool_, output_set_;
148  GroupFeatureMap group_feat_map_;
149 
150  LinearFstData(const LinearFstData &) = delete;
151  LinearFstData &operator=(const LinearFstData &) = delete;
152 };
153 
154 template <class A>
155 template <class Iterator>
156 void LinearFstData<A>::TakeTransition(Iterator buffer_end,
157  Iterator trie_state_begin,
158  Iterator trie_state_end, Label ilabel,
159  Label olabel, std::vector<Label> *next,
160  Weight *weight) const {
161  DCHECK_EQ(trie_state_end - trie_state_begin, groups_.size());
162  DCHECK(ilabel > 0 || ilabel == kEndOfSentence);
163  DCHECK(olabel > 0 || olabel == kStartOfSentence);
164  size_t group_id = 0;
165  for (Iterator it = trie_state_begin; it != trie_state_end; ++it, ++group_id) {
166  size_t delay = groups_[group_id]->Delay();
167  // On the buffer, there may also be `kStartOfSentence` from the
168  // initial empty buffer.
169  Label real_ilabel = delay == 0 ? ilabel : *(buffer_end - delay);
170  next->push_back(
171  GroupTransition(group_id, *it, real_ilabel, olabel, weight));
172  }
173 }
174 
175 template <class A>
176 typename A::Label LinearFstData<A>::GroupTransition(int group_id,
177  int trie_state,
178  Label ilabel, Label olabel,
179  Weight *weight) const {
180  Label group_ilabel = FindFeature(group_id, ilabel);
181  return groups_[group_id]->Walk(trie_state, group_ilabel, olabel, weight);
182 }
183 
184 template <class A>
185 template <class Iterator>
186 inline typename A::Weight LinearFstData<A>::FinalWeight(
187  Iterator trie_state_begin, Iterator trie_state_end) const {
188  DCHECK_EQ(trie_state_end - trie_state_begin, groups_.size());
189  size_t group_id = 0;
190  Weight accum = Weight::One();
191  for (Iterator it = trie_state_begin; it != trie_state_end; ++it, ++group_id)
192  accum = Times(accum, GroupFinalWeight(group_id, *it));
193  return accum;
194 }
195 
196 template <class A>
197 inline std::pair<typename std::vector<typename A::Label>::const_iterator,
198  typename std::vector<typename A::Label>::const_iterator>
200  const InputAttribute &attrib = input_attribs_[word];
201  if (attrib.output_length == 0)
202  return std::make_pair(output_set_.begin(), output_set_.end());
203  else
204  return std::make_pair(
205  output_pool_.begin() + attrib.output_begin,
206  output_pool_.begin() + attrib.output_begin + attrib.output_length);
207 }
208 
209 template <class A>
210 inline LinearFstData<A> *LinearFstData<A>::Read(std::istream &strm) {
211  std::unique_ptr<LinearFstData<A>> data(new LinearFstData<A>());
212  ReadType(strm, &(data->max_future_size_));
213  ReadType(strm, &(data->max_input_label_));
214  // Feature groups
215  size_t num_groups = 0;
216  ReadType(strm, &num_groups);
217  data->groups_.resize(num_groups);
218  for (size_t i = 0; i < num_groups; ++i)
219  data->groups_[i].reset(FeatureGroup<A>::Read(strm));
220  // Other data
221  ReadType(strm, &(data->input_attribs_));
222  ReadType(strm, &(data->output_pool_));
223  ReadType(strm, &(data->output_set_));
224  ReadType(strm, &(data->group_feat_map_));
225  if (strm) {
226  return data.release();
227  } else {
228  return nullptr;
229  }
230 }
231 
232 template <class A>
233 inline std::ostream &LinearFstData<A>::Write(std::ostream &strm) const {
234  WriteType(strm, max_future_size_);
235  WriteType(strm, max_input_label_);
236  // Feature groups
237  WriteType(strm, groups_.size());
238  for (size_t i = 0; i < groups_.size(); ++i) {
239  groups_[i]->Write(strm);
240  }
241  // Other data
242  WriteType(strm, input_attribs_);
243  WriteType(strm, output_pool_);
244  WriteType(strm, output_set_);
245  WriteType(strm, group_feat_map_);
246  return strm;
247 }
248 
249 template <class A>
250 typename A::Label LinearFstData<A>::FindFeature(size_t group,
251  Label word) const {
252  DCHECK(word > 0 || word == kStartOfSentence || word == kEndOfSentence);
253  if (word == kStartOfSentence || word == kEndOfSentence)
254  return word;
255  else
256  return group_feat_map_.Find(group, word);
257 }
258 
259 template <class A>
260 inline std::istream &LinearFstData<A>::InputAttribute::Read(
261  std::istream &strm) {
262  ReadType(strm, &output_begin);
263  ReadType(strm, &output_length);
264  return strm;
265 }
266 
267 template <class A>
268 inline std::ostream &LinearFstData<A>::InputAttribute::Write(
269  std::ostream &strm) const {
270  WriteType(strm, output_begin);
271  WriteType(strm, output_length);
272  return strm;
273 }
274 
275 // Forward declaration
276 template <class A>
277 class FeatureGroupBuilder;
278 
279 // An immutable grouping of features with similar context shape. Like
280 // `LinearFstData`, this can only be constructed via `Read()` or
281 // via its builder.
282 //
283 // Internally it uses a trie to store all feature n-grams and their
284 // weights. The label of a trie edge is a pair (feat, olabel) of
285 // labels. They can be either positive (ordinary label), `kNoLabel`,
286 // `kStartOfSentence`, or `kEndOfSentence`. `kNoLabel` usually means
287 // matching anything, with one exception: from the root of the trie,
288 // there is a special (kNoLabel, kNoLabel) that leads to the implicit
289 // start-of-sentence state. This edge is never actually matched
290 // (`FindFirstMatch()` ensures this).
291 template <class A>
292 class FeatureGroup {
293  public:
294  friend class FeatureGroupBuilder<A>; // for builder access
295 
296  typedef typename A::Label Label;
297  typedef typename A::Weight Weight;
298 
299  int Start() const { return start_; }
300 
301  // Finds destination node from `cur` by consuming `ilabel` and
302  // `olabel`. The transition weight is multiplied onto `weight`.
303  int Walk(int cur, Label ilabel, Label olabel, Weight *weight) const;
304 
305  // Returns the final weight of the current trie state. Only valid if
306  // the state is already known to be part of a final state (see
307  // `LinearFstData<>::CanBeFinal()`).
308  Weight FinalWeight(int trie_state) const {
309  return trie_[trie_state].final_weight;
310  }
311 
312  static FeatureGroup<A> *Read(std::istream &strm) {
313  size_t delay;
314  ReadType(strm, &delay);
315  int start;
316  ReadType(strm, &start);
317  Trie trie;
318  ReadType(strm, &trie);
319  std::unique_ptr<FeatureGroup<A>> ret(new FeatureGroup<A>(delay, start));
320  ret->trie_.swap(trie);
321  ReadType(strm, &ret->next_state_);
322  if (strm) {
323  return ret.release();
324  } else {
325  return nullptr;
326  }
327  }
328 
329  std::ostream &Write(std::ostream &strm) const {
330  WriteType(strm, delay_);
331  WriteType(strm, start_);
332  WriteType(strm, trie_);
333  WriteType(strm, next_state_);
334  return strm;
335  }
336 
337  size_t Delay() const { return delay_; }
338 
339  std::string Stats() const;
340 
341  private:
342  // Label along the arcs on the trie. `kNoLabel` means anything
343  // (non-negative label) can match; both sides holding `kNoLabel`
344  // is not allow; otherwise the label is > 0 (enforced by
345  // `LinearFstDataBuilder::AddWeight()`).
346  struct InputOutputLabel;
347  struct InputOutputLabelHash;
348 
349  // Data to be stored on the trie
350  struct WeightBackLink {
351  int back_link;
352  Weight weight, final_weight;
353 
354  WeightBackLink()
355  : back_link(kNoTrieNodeId),
356  weight(Weight::One()),
357  final_weight(Weight::One()) {}
358 
359  std::istream &Read(std::istream &strm) {
360  ReadType(strm, &back_link);
361  ReadType(strm, &weight);
362  ReadType(strm, &final_weight);
363  return strm;
364  }
365 
366  std::ostream &Write(std::ostream &strm) const {
367  WriteType(strm, back_link);
368  WriteType(strm, weight);
369  WriteType(strm, final_weight);
370  return strm;
371  }
372  };
373 
376 
377  explicit FeatureGroup(size_t delay, int start)
378  : delay_(delay), start_(start) {}
379 
380  // Finds the first node with an arc with `label` following the
381  // back-off chain of `parent`. Returns the node index or
382  // `kNoTrieNodeId` when not found.
383  int FindFirstMatch(InputOutputLabel label, int parent) const;
384 
385  size_t delay_;
386  int start_;
387  Trie trie_;
388  // Where to go after hitting this state. When we reach a state with
389  // no child and with no additional final weight (i.e. its final
390  // weight is the same as its back-off), we can immediately go to its
391  // back-off state.
392  std::vector<int> next_state_;
393 
394  FeatureGroup(const FeatureGroup &) = delete;
395  FeatureGroup &operator=(const FeatureGroup &) = delete;
396 };
397 
398 template <class A>
400  Label input, output;
401 
403  : input(i), output(o) {}
404 
405  bool operator==(InputOutputLabel that) const {
406  return input == that.input && output == that.output;
407  }
408 
409  std::istream &Read(std::istream &strm) {
410  ReadType(strm, &input);
411  ReadType(strm, &output);
412  return strm;
413  }
414 
415  std::ostream &Write(std::ostream &strm) const {
416  WriteType(strm, input);
417  WriteType(strm, output);
418  return strm;
419  }
420 };
421 
422 template <class A>
424  size_t operator()(InputOutputLabel label) const {
425  return static_cast<size_t>(label.input * 7853 + label.output);
426  }
427 };
428 
429 template <class A>
430 int FeatureGroup<A>::Walk(int cur, Label ilabel, Label olabel,
431  Weight *weight) const {
432  // Note: user of this method need to ensure `ilabel` and `olabel`
433  // are valid (e.g. see DCHECKs in
434  // `LinearFstData<>::TakeTransition()` and
435  // `LinearFstData<>::FindFeature()`).
436  int next;
437  if (ilabel == LinearFstData<A>::kStartOfSentence) {
438  // An observed start-of-sentence only occurs in the beginning of
439  // the input, when this feature group is delayed (i.e. there is
440  // another feature group with a larger future size). The actual
441  // input hasn't arrived so stay at the start state.
442  DCHECK_EQ(cur, start_);
443  next = start_;
444  } else {
445  // First, try exact match
446  next = FindFirstMatch(InputOutputLabel(ilabel, olabel), cur);
447  // Then try with don't cares
448  if (next == kNoTrieNodeId)
449  next = FindFirstMatch(InputOutputLabel(ilabel, kNoLabel), cur);
450  if (next == kNoTrieNodeId)
451  next = FindFirstMatch(InputOutputLabel(kNoLabel, olabel), cur);
452  // All failed, go to empty context
453  if (next == kNoTrieNodeId) next = trie_.Root();
454  *weight = Times(*weight, trie_[next].weight);
455  next = next_state_[next];
456  }
457  return next;
458 }
459 
460 template <class A>
462  int parent) const {
463  if (label.input == kNoLabel && label.output == kNoLabel)
464  return kNoTrieNodeId; // very important; see class doc.
465  for (; parent != kNoTrieNodeId; parent = trie_[parent].back_link) {
466  int next = trie_.Find(parent, label);
467  if (next != kNoTrieNodeId) return next;
468  }
469  return kNoTrieNodeId;
470 }
471 
472 template <class A>
473 inline std::string FeatureGroup<A>::Stats() const {
474  std::ostringstream strm;
475  int num_states = 2;
476  for (int i = 2; i < next_state_.size(); ++i)
477  num_states += i == next_state_[i];
478  strm << trie_.NumNodes() << " node(s); " << num_states << " state(s)";
479  return strm.str();
480 }
481 
482 template <class A>
484  public:
486 
487  void Init(size_t num_groups, size_t num_words) {
488  num_groups_ = num_groups;
489  pool_.clear();
490  pool_.resize(num_groups * num_words, kNoLabel);
491  }
492 
493  Label Find(size_t group_id, Label ilabel) const {
494  return pool_[IndexOf(group_id, ilabel)];
495  }
496 
497  bool Set(size_t group_id, Label ilabel, Label feat) {
498  size_t i = IndexOf(group_id, ilabel);
499  if (pool_[i] != kNoLabel && pool_[i] != feat) {
500  FSTERROR() << "Feature group " << group_id
501  << " already has feature for word " << ilabel;
502  return false;
503  }
504  pool_[i] = feat;
505  return true;
506  }
507 
508  std::istream &Read(std::istream &strm) {
509  ReadType(strm, &num_groups_);
510  ReadType(strm, &pool_);
511  return strm;
512  }
513 
514  std::ostream &Write(std::ostream &strm) const {
515  WriteType(strm, num_groups_);
516  WriteType(strm, pool_);
517  return strm;
518  }
519 
520  private:
521  size_t IndexOf(size_t group_id, Label ilabel) const {
522  return ilabel * num_groups_ + group_id;
523  }
524 
525  size_t num_groups_;
526  // `pool_[ilabel * num_groups_ + group_id]` is the feature active
527  // for group `group_id` with input `ilabel`
528  std::vector<Label> pool_;
529 };
530 
531 } // namespace fst
532 
533 #endif // FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_H_
std::string Stats() const
std::ostream & Write(std::ostream &strm) const
constexpr int kNoLabel
Definition: fst.h:201
bool operator==(InputOutputLabel that) const
bool Set(size_t group_id, Label ilabel, Label feat)
std::ostream & Write(std::ostream &strm) const
Label GroupStartState(int group_id) const
std::pair< typename std::vector< Label >::const_iterator, typename std::vector< Label >::const_iterator > PossibleOutputLabels(Label word) const
void TakeTransition(Iterator buffer_end, Iterator trie_state_begin, Iterator trie_state_end, Label ilabel, Label olabel, std::vector< Label > *next, Weight *weight) const
int Walk(int cur, Label ilabel, Label olabel, Weight *weight) const
size_t Delay() const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:63
std::istream & Read(std::istream &strm)
Weight GroupFinalWeight(int group_id, int trie_state) const
void Init(size_t num_groups, size_t num_words)
#define DCHECK_EQ(x, y)
Definition: log.h:71
void EncodeStartState(std::vector< Label > *output) const
static constexpr Label kEndOfSentence
Label GroupTransition(int group_id, int trie_state, Label ilabel, Label olabel, Weight *weight) const
Label Find(size_t group_id, Label ilabel) const
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:214
#define FSTERROR()
Definition: util.h:53
int Start() const
Weight FinalWeight(Iterator trie_state_begin, Iterator trie_state_end) const
constexpr int kNoTrieNodeId
Definition: trie.h:31
std::ostream & Write(std::ostream &strm) const
static constexpr Label kStartOfSentence
static LinearFstData< A > * Read(std::istream &strm)
Weight FinalWeight(int trie_state) const
size_t NumGroups() const
Label MinInputLabel() const
size_t MaxFutureSize() const
size_t operator()(InputOutputLabel label) const
std::istream & Read(std::istream &strm)
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:68
#define DCHECK(x)
Definition: log.h:70
Label MaxInputLabel() const
std::ostream & Write(std::ostream &strm) const
static FeatureGroup< A > * Read(std::istream &strm)
InputOutputLabel(Label i=kNoLabel, Label o=kNoLabel)