FST  openfst-1.8.3
OpenFst Library
linear-fst-data.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Data structures for storing and looking up the actual feature weights.
19 
20 #ifndef FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_H_
21 #define FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_H_
22 
#include <cstddef>
#include <istream>
#include <memory>
#include <numeric>
#include <ostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <fst/compat.h>
#include <fst/log.h>
#include <fst/extensions/linear/trie.h>
#include <fst/fst.h>
#include <fst/util.h>
38 
39 namespace fst {
40 
// Forward declarations
template <class A>
class FeatureGroup;
template <class A>
class LinearFstDataBuilder;

47 // Immutable data storage of the feature weights in a linear
48 // model. Produces state tuples that represent internal states of a
49 // LinearTaggerFst. Object of this class can only be constructed via
50 // either `LinearFstDataBuilder::Dump()` or `LinearFstData::Read()`
51 // and usually used as refcount'd object shared across mutiple
52 // `LinearTaggerFst` copies.
53 //
54 // TODO(wuke): more efficient trie implementation
55 template <class A>
57  public:
58  friend class LinearFstDataBuilder<A>; // For builder access
59 
60  typedef typename A::Label Label;
61  typedef typename A::Weight Weight;
62 
63  // Sentence boundary labels. Both of them are negative labels other
64  // than `kNoLabel`.
65  static constexpr Label kStartOfSentence = -3;
66  static constexpr Label kEndOfSentence = -2;
67 
68  // Constructs empty data; for non-trivial ways of construction see
69  // `Read()` and `LinearFstDataBuilder`.
71  : max_future_size_(0), max_input_label_(1), input_attribs_(1) {}
72 
73  // Appends the state tuple of the start state to `output`, where
74  // each tuple holds the node ids of a trie for each feature group.
75  void EncodeStartState(std::vector<Label> *output) const {
76  for (int i = 0; i < NumGroups(); ++i) output->push_back(GroupStartState(i));
77  }
78 
79  // Takes a transition from the trie states stored in
80  // `(trie_state_begin, trie_state_end)` with input label `ilabel`
81  // and output label `olabel`; appends the destination state tuple to
82  // `next` and multiplies the weight of the transition onto
83  // `weight`. `next` should be the shifted input buffer of the caller
84  // in `LinearTaggerFstImpl` (i.e. of size `LinearTaggerFstImpl::delay_`;
85  // the last element is `ilabel`).
86  template <class Iterator>
87  void TakeTransition(Iterator buffer_end, Iterator trie_state_begin,
88  Iterator trie_state_end, Label ilabel, Label olabel,
89  std::vector<Label> *next, Weight *weight) const;
90 
91  // Returns the final weight of the given trie state sequence.
92  template <class Iterator>
93  Weight FinalWeight(Iterator trie_state_begin, Iterator trie_state_end) const;
94 
95  // Returns the start trie state of the given group.
96  Label GroupStartState(int group_id) const {
97  return groups_[group_id]->Start();
98  }
99 
100  // Takes a transition only within the given group. Returns the
101  // destination trie state and multiplies the weight onto `weight`.
102  Label GroupTransition(int group_id, int trie_state, Label ilabel,
103  Label olabel, Weight *weight) const;
104 
105  // Returns the final weight of the given trie state in the given group.
106  Weight GroupFinalWeight(int group_id, int trie_state) const {
107  return groups_[group_id]->FinalWeight(trie_state);
108  }
109 
110  Label MinInputLabel() const { return 1; }
111 
112  Label MaxInputLabel() const { return max_input_label_; }
113 
114  // Returns the maximum future size of all feature groups. Future is
115  // the look-ahead window of a feature, e.g. if a feature looks at
116  // the next 2 words after the current input, then the future size is
117  // 2. There is no look-ahead for output. Features inside a single
118  // `FeatureGroup` must have equal future size.
119  size_t MaxFutureSize() const { return max_future_size_; }
120 
121  // Returns the number of feature groups
122  size_t NumGroups() const { return groups_.size(); }
123 
124  // Returns the range of possible output labels for an input label.
125  std::pair<typename std::vector<Label>::const_iterator,
126  typename std::vector<Label>::const_iterator>
127  PossibleOutputLabels(Label word) const;
128 
129  static LinearFstData<A> *Read(std::istream &strm);
130  std::ostream &Write(std::ostream &strm) const;
131 
132  private:
133  // Offsets in `output_pool_`
134  struct InputAttribute {
135  size_t output_begin, output_length;
136 
137  std::istream &Read(std::istream &strm);
138  std::ostream &Write(std::ostream &strm) const;
139  };
140 
141  // Mapping from input label to per-group feature label
142  class GroupFeatureMap;
143 
144  // Translates the input label into input feature label of group
145  // `group`; returns `kNoLabel` when there is no feature for that
146  // group.
147  Label FindFeature(size_t group, Label word) const;
148 
149  size_t max_future_size_;
150  Label max_input_label_;
151  std::vector<std::unique_ptr<const FeatureGroup<A>>> groups_;
152  std::vector<InputAttribute> input_attribs_;
153  std::vector<Label> output_pool_, output_set_;
154  GroupFeatureMap group_feat_map_;
155 
156  LinearFstData(const LinearFstData &) = delete;
157  LinearFstData &operator=(const LinearFstData &) = delete;
158 };
159 
160 template <class A>
161 template <class Iterator>
162 void LinearFstData<A>::TakeTransition(Iterator buffer_end,
163  Iterator trie_state_begin,
164  Iterator trie_state_end, Label ilabel,
165  Label olabel, std::vector<Label> *next,
166  Weight *weight) const {
167  DCHECK_EQ(trie_state_end - trie_state_begin, groups_.size());
168  DCHECK(ilabel > 0 || ilabel == kEndOfSentence);
169  DCHECK(olabel > 0 || olabel == kStartOfSentence);
170  size_t group_id = 0;
171  for (Iterator it = trie_state_begin; it != trie_state_end; ++it, ++group_id) {
172  size_t delay = groups_[group_id]->Delay();
173  // On the buffer, there may also be `kStartOfSentence` from the
174  // initial empty buffer.
175  Label real_ilabel = delay == 0 ? ilabel : *(buffer_end - delay);
176  next->push_back(
177  GroupTransition(group_id, *it, real_ilabel, olabel, weight));
178  }
179 }
180 
181 template <class A>
182 typename A::Label LinearFstData<A>::GroupTransition(int group_id,
183  int trie_state,
184  Label ilabel, Label olabel,
185  Weight *weight) const {
186  Label group_ilabel = FindFeature(group_id, ilabel);
187  return groups_[group_id]->Walk(trie_state, group_ilabel, olabel, weight);
188 }
189 
190 template <class A>
191 template <class Iterator>
192 inline typename A::Weight LinearFstData<A>::FinalWeight(
193  Iterator trie_state_begin, Iterator trie_state_end) const {
194  DCHECK_EQ(trie_state_end - trie_state_begin, groups_.size());
195  size_t group_id = 0;
196  Weight accum = Weight::One();
197  for (Iterator it = trie_state_begin; it != trie_state_end; ++it, ++group_id)
198  accum = Times(accum, GroupFinalWeight(group_id, *it));
199  return accum;
200 }
201 
202 template <class A>
203 inline std::pair<typename std::vector<typename A::Label>::const_iterator,
204  typename std::vector<typename A::Label>::const_iterator>
206  const InputAttribute &attrib = input_attribs_[word];
207  if (attrib.output_length == 0)
208  return std::make_pair(output_set_.begin(), output_set_.end());
209  else
210  return std::make_pair(
211  output_pool_.begin() + attrib.output_begin,
212  output_pool_.begin() + attrib.output_begin + attrib.output_length);
213 }
214 
215 template <class A>
216 inline LinearFstData<A> *LinearFstData<A>::Read(std::istream &strm) {
217  std::unique_ptr<LinearFstData<A>> data(new LinearFstData<A>());
218  ReadType(strm, &(data->max_future_size_));
219  ReadType(strm, &(data->max_input_label_));
220  // Feature groups
221  size_t num_groups = 0;
222  ReadType(strm, &num_groups);
223  data->groups_.resize(num_groups);
224  for (size_t i = 0; i < num_groups; ++i)
225  data->groups_[i].reset(FeatureGroup<A>::Read(strm));
226  // Other data
227  ReadType(strm, &(data->input_attribs_));
228  ReadType(strm, &(data->output_pool_));
229  ReadType(strm, &(data->output_set_));
230  ReadType(strm, &(data->group_feat_map_));
231  if (strm) {
232  return data.release();
233  } else {
234  return nullptr;
235  }
236 }
237 
238 template <class A>
239 inline std::ostream &LinearFstData<A>::Write(std::ostream &strm) const {
240  WriteType(strm, max_future_size_);
241  WriteType(strm, max_input_label_);
242  // Feature groups
243  WriteType(strm, groups_.size());
244  for (size_t i = 0; i < groups_.size(); ++i) {
245  groups_[i]->Write(strm);
246  }
247  // Other data
248  WriteType(strm, input_attribs_);
249  WriteType(strm, output_pool_);
250  WriteType(strm, output_set_);
251  WriteType(strm, group_feat_map_);
252  return strm;
253 }
254 
255 template <class A>
256 typename A::Label LinearFstData<A>::FindFeature(size_t group,
257  Label word) const {
258  DCHECK(word > 0 || word == kStartOfSentence || word == kEndOfSentence);
259  if (word == kStartOfSentence || word == kEndOfSentence)
260  return word;
261  else
262  return group_feat_map_.Find(group, word);
263 }
264 
265 template <class A>
266 inline std::istream &LinearFstData<A>::InputAttribute::Read(
267  std::istream &strm) {
268  ReadType(strm, &output_begin);
269  ReadType(strm, &output_length);
270  return strm;
271 }
272 
273 template <class A>
274 inline std::ostream &LinearFstData<A>::InputAttribute::Write(
275  std::ostream &strm) const {
276  WriteType(strm, output_begin);
277  WriteType(strm, output_length);
278  return strm;
279 }
280 
281 // Forward declaration
282 template <class A>
283 class FeatureGroupBuilder;
284 
285 // An immutable grouping of features with similar context shape. Like
286 // `LinearFstData`, this can only be constructed via `Read()` or
287 // via its builder.
288 //
289 // Internally it uses a trie to store all feature n-grams and their
290 // weights. The label of a trie edge is a pair (feat, olabel) of
291 // labels. They can be either positive (ordinary label), `kNoLabel`,
292 // `kStartOfSentence`, or `kEndOfSentence`. `kNoLabel` usually means
293 // matching anything, with one exception: from the root of the trie,
294 // there is a special (kNoLabel, kNoLabel) that leads to the implicit
295 // start-of-sentence state. This edge is never actually matched
296 // (`FindFirstMatch()` ensures this).
297 template <class A>
298 class FeatureGroup {
299  public:
300  friend class FeatureGroupBuilder<A>; // for builder access
301 
302  typedef typename A::Label Label;
303  typedef typename A::Weight Weight;
304 
305  int Start() const { return start_; }
306 
307  // Finds destination node from `cur` by consuming `ilabel` and
308  // `olabel`. The transition weight is multiplied onto `weight`.
309  int Walk(int cur, Label ilabel, Label olabel, Weight *weight) const;
310 
311  // Returns the final weight of the current trie state. Only valid if
312  // the state is already known to be part of a final state (see
313  // `LinearFstData<>::CanBeFinal()`).
314  Weight FinalWeight(int trie_state) const {
315  return trie_[trie_state].final_weight;
316  }
317 
318  static FeatureGroup<A> *Read(std::istream &strm) {
319  size_t delay;
320  ReadType(strm, &delay);
321  int start;
322  ReadType(strm, &start);
323  Trie trie;
324  ReadType(strm, &trie);
325  std::unique_ptr<FeatureGroup<A>> ret(new FeatureGroup<A>(delay, start));
326  ret->trie_.swap(trie);
327  ReadType(strm, &ret->next_state_);
328  if (strm) {
329  return ret.release();
330  } else {
331  return nullptr;
332  }
333  }
334 
335  std::ostream &Write(std::ostream &strm) const {
336  WriteType(strm, delay_);
337  WriteType(strm, start_);
338  WriteType(strm, trie_);
339  WriteType(strm, next_state_);
340  return strm;
341  }
342 
343  size_t Delay() const { return delay_; }
344 
345  std::string Stats() const;
346 
347  private:
348  // Label along the arcs on the trie. `kNoLabel` means anything
349  // (non-negative label) can match; both sides holding `kNoLabel`
350  // is not allow; otherwise the label is > 0 (enforced by
351  // `LinearFstDataBuilder::AddWeight()`).
352  struct InputOutputLabel;
353  struct InputOutputLabelHash;
354 
355  // Data to be stored on the trie
356  struct WeightBackLink {
357  int back_link;
358  Weight weight, final_weight;
359 
360  WeightBackLink()
361  : back_link(kNoTrieNodeId),
362  weight(Weight::One()),
363  final_weight(Weight::One()) {}
364 
365  std::istream &Read(std::istream &strm) {
366  ReadType(strm, &back_link);
367  ReadType(strm, &weight);
368  ReadType(strm, &final_weight);
369  return strm;
370  }
371 
372  std::ostream &Write(std::ostream &strm) const {
373  WriteType(strm, back_link);
374  WriteType(strm, weight);
375  WriteType(strm, final_weight);
376  return strm;
377  }
378  };
379 
382 
383  explicit FeatureGroup(size_t delay, int start)
384  : delay_(delay), start_(start) {}
385 
386  // Finds the first node with an arc with `label` following the
387  // back-off chain of `parent`. Returns the node index or
388  // `kNoTrieNodeId` when not found.
389  int FindFirstMatch(InputOutputLabel label, int parent) const;
390 
391  size_t delay_;
392  int start_;
393  Trie trie_;
394  // Where to go after hitting this state. When we reach a state with
395  // no child and with no additional final weight (i.e. its final
396  // weight is the same as its back-off), we can immediately go to its
397  // back-off state.
398  std::vector<int> next_state_;
399 
400  FeatureGroup(const FeatureGroup &) = delete;
401  FeatureGroup &operator=(const FeatureGroup &) = delete;
402 };
403 
404 template <class A>
406  Label input, output;
407 
409  : input(i), output(o) {}
410 
411  bool operator==(InputOutputLabel that) const {
412  return input == that.input && output == that.output;
413  }
414 
415  std::istream &Read(std::istream &strm) {
416  ReadType(strm, &input);
417  ReadType(strm, &output);
418  return strm;
419  }
420 
421  std::ostream &Write(std::ostream &strm) const {
422  WriteType(strm, input);
423  WriteType(strm, output);
424  return strm;
425  }
426 };
427 
428 template <class A>
430  size_t operator()(InputOutputLabel label) const {
431  return static_cast<size_t>(label.input * 7853 + label.output);
432  }
433 };
434 
435 template <class A>
436 int FeatureGroup<A>::Walk(int cur, Label ilabel, Label olabel,
437  Weight *weight) const {
438  // Note: user of this method need to ensure `ilabel` and `olabel`
439  // are valid (e.g. see DCHECKs in
440  // `LinearFstData<>::TakeTransition()` and
441  // `LinearFstData<>::FindFeature()`).
442  int next;
443  if (ilabel == LinearFstData<A>::kStartOfSentence) {
444  // An observed start-of-sentence only occurs in the beginning of
445  // the input, when this feature group is delayed (i.e. there is
446  // another feature group with a larger future size). The actual
447  // input hasn't arrived so stay at the start state.
448  DCHECK_EQ(cur, start_);
449  next = start_;
450  } else {
451  // First, try exact match
452  next = FindFirstMatch(InputOutputLabel(ilabel, olabel), cur);
453  // Then try with don't cares
454  if (next == kNoTrieNodeId)
455  next = FindFirstMatch(InputOutputLabel(ilabel, kNoLabel), cur);
456  if (next == kNoTrieNodeId)
457  next = FindFirstMatch(InputOutputLabel(kNoLabel, olabel), cur);
458  // All failed, go to empty context
459  if (next == kNoTrieNodeId) next = trie_.Root();
460  *weight = Times(*weight, trie_[next].weight);
461  next = next_state_[next];
462  }
463  return next;
464 }
465 
466 template <class A>
468  int parent) const {
469  if (label.input == kNoLabel && label.output == kNoLabel)
470  return kNoTrieNodeId; // very important; see class doc.
471  for (; parent != kNoTrieNodeId; parent = trie_[parent].back_link) {
472  int next = trie_.Find(parent, label);
473  if (next != kNoTrieNodeId) return next;
474  }
475  return kNoTrieNodeId;
476 }
477 
478 template <class A>
479 inline std::string FeatureGroup<A>::Stats() const {
480  std::ostringstream strm;
481  int num_states = 2;
482  for (int i = 2; i < next_state_.size(); ++i)
483  num_states += i == next_state_[i];
484  strm << trie_.NumNodes() << " node(s); " << num_states << " state(s)";
485  return strm.str();
486 }
487 
488 template <class A>
490  public:
491  GroupFeatureMap() = default;
492 
493  void Init(size_t num_groups, size_t num_words) {
494  num_groups_ = num_groups;
495  pool_.clear();
496  pool_.resize(num_groups * num_words, kNoLabel);
497  }
498 
499  Label Find(size_t group_id, Label ilabel) const {
500  return pool_[IndexOf(group_id, ilabel)];
501  }
502 
503  bool Set(size_t group_id, Label ilabel, Label feat) {
504  size_t i = IndexOf(group_id, ilabel);
505  if (pool_[i] != kNoLabel && pool_[i] != feat) {
506  FSTERROR() << "Feature group " << group_id
507  << " already has feature for word " << ilabel;
508  return false;
509  }
510  pool_[i] = feat;
511  return true;
512  }
513 
514  std::istream &Read(std::istream &strm) {
515  ReadType(strm, &num_groups_);
516  ReadType(strm, &pool_);
517  return strm;
518  }
519 
520  std::ostream &Write(std::ostream &strm) const {
521  WriteType(strm, num_groups_);
522  WriteType(strm, pool_);
523  return strm;
524  }
525 
526  private:
527  size_t IndexOf(size_t group_id, Label ilabel) const {
528  return ilabel * num_groups_ + group_id;
529  }
530 
531  size_t num_groups_;
532  // `pool_[ilabel * num_groups_ + group_id]` is the feature active
533  // for group `group_id` with input `ilabel`
534  std::vector<Label> pool_;
535 };
536 
537 } // namespace fst
538 
539 #endif // FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_H_
std::string Stats() const
std::ostream & Write(std::ostream &strm) const
constexpr int kNoLabel
Definition: fst.h:195
bool operator==(InputOutputLabel that) const
bool Set(size_t group_id, Label ilabel, Label feat)
std::ostream & Write(std::ostream &strm) const
Label GroupStartState(int group_id) const
std::pair< typename std::vector< Label >::const_iterator, typename std::vector< Label >::const_iterator > PossibleOutputLabels(Label word) const
void TakeTransition(Iterator buffer_end, Iterator trie_state_begin, Iterator trie_state_end, Label ilabel, Label olabel, std::vector< Label > *next, Weight *weight) const
int Walk(int cur, Label ilabel, Label olabel, Weight *weight) const
size_t Delay() const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:64
std::istream & Read(std::istream &strm)
Weight GroupFinalWeight(int group_id, int trie_state) const
void Init(size_t num_groups, size_t num_words)
#define DCHECK_EQ(x, y)
Definition: log.h:75
void EncodeStartState(std::vector< Label > *output) const
static constexpr Label kEndOfSentence
Label GroupTransition(int group_id, int trie_state, Label ilabel, Label olabel, Weight *weight) const
Label Find(size_t group_id, Label ilabel) const
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:228
#define FSTERROR()
Definition: util.h:56
int Start() const
Weight FinalWeight(Iterator trie_state_begin, Iterator trie_state_end) const
constexpr int kNoTrieNodeId
Definition: trie.h:35
std::ostream & Write(std::ostream &strm) const
static constexpr Label kStartOfSentence
static LinearFstData< A > * Read(std::istream &strm)
Weight FinalWeight(int trie_state) const
size_t NumGroups() const
Label MinInputLabel() const
size_t MaxFutureSize() const
size_t operator()(InputOutputLabel label) const
std::istream & Read(std::istream &strm)
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:80
#define DCHECK(x)
Definition: log.h:74
Label MaxInputLabel() const
std::ostream & Write(std::ostream &strm) const
static FeatureGroup< A > * Read(std::istream &strm)
InputOutputLabel(Label i=kNoLabel, Label o=kNoLabel)