FST  openfst-1.7.1
OpenFst Library
linear-fst-data-builder.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 
4 #ifndef FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_BUILDER_H_
5 #define FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_BUILDER_H_
6 
7 #include <map>
8 #include <queue>
9 #include <set>
10 #include <sstream>
11 #include <stack>
12 #include <string>
13 #include <vector>
14 
15 #include <fst/compat.h>
16 #include <fst/log.h>
17 #include <fst/fst.h>
18 #include <fst/symbol-table.h>
19 #include <fst/util.h>
20 
22 
23 namespace fst {
24 
25 // Forward declaration
26 template <class A>
28 
29 // For logging purposes
30 inline string TranslateLabel(int64 label, const SymbolTable *syms);
31 template <class Iterator>
32 string JoinLabels(Iterator begin, Iterator end, const SymbolTable *syms);
33 template <class Label>
34 string JoinLabels(const std::vector<Label> &labels, const SymbolTable *syms);
35 
36 // Guesses the appropriate boundary label (start- or end-of-sentence)
37 // for all labels equal to `boundary` and modifies the `sequence`
38 // in-place. Returns the number of positions that are still uncertain.
39 template <class A>
40 typename A::Label GuessStartOrEnd(std::vector<typename A::Label> *sequence,
41  typename A::Label boundary);
42 
43 // Builds a `LinearFstData` object by adding words and feature
44 // weights. A few conventions:
45 //
46 // - Input labels forms a dense non-empty range from 1 to `MaxInputLabel()`.
47 // - Feature labels, output labels are > 0.
48 // - Being a discriminative linear model, it only makes sense to use tropical
49 // semirings.
50 template <class A>
52  public:
53  typedef typename A::Label Label;
54  typedef typename A::Weight Weight;
55 
 56  // Constructs a builder with associated symbol tables for diagnostic
57  // output. Each of these symbol tables may also be nullptr.
58  explicit LinearFstDataBuilder(const SymbolTable *isyms = nullptr,
59  const SymbolTable *fsyms = nullptr,
60  const SymbolTable *osyms = nullptr)
61  : error_(false),
62  max_future_size_(0),
63  max_input_label_(1),
64  isyms_(isyms),
65  fsyms_(fsyms),
66  osyms_(osyms) {}
67 
68  // Tests whether the builder has encountered any error. No operation
69  // is valid if the builder is already at error state. All other
70  // public methods should check this before any actual operations.
71  bool Error() const { return error_; }
72 
73  // Adds a word and its feature labels to the vocabulary; this
74  // version allows the word to have any output label. Returns true
75  // iff the word is added.
76  //
77  // This may fail if the word is added twice or if the feature labels
78  // are non-positive.
79  bool AddWord(Label word, const std::vector<Label> &features);
80 
81  // Adds a word and its feature labels to the vocabulary; this
82  // version puts constraint on possible output labels the word can
83  // have. `possible_output` must not be empty. Returns true iff the
84  // word is added.
85  //
86  // In addition to the reasons above in the two-parameter version,
87  // this may also fail if `possible_output` is empty or any output
88  // label in it is non-positive.
89  bool AddWord(Label word, const std::vector<Label> &word_features,
90  const std::vector<Label> &possible_output);
91 
92  // Creates a new feature group with specified future size (size of
93  // the look-ahead window), returns the group id to be used for
94  // adding actual feature weights or a negative number when called at
95  // error state.
96  //
97  // This does not fail unless called at error state.
98  int AddGroup(size_t future_size);
99 
100  // Adds an instance of feature weight to the specified feature
101  // group. If some weight has already been added with the same
102  // feature, the product of the old and new weights are
103  // stored. Returns true iff the weight is added. A weight is not
104  // added when the context has ill-formed context involving start-,
105  // end-of-sentence marks.
106  //
107  // For two features to be within the same group, it must satisfy
108  // that (1) they have the same future size; (2) the two either have
109  // disjoint context or one is the back-off context of the
110  // other. Furthermore, for all features in a single group, there
111  // must be one and only one other context (not necessarily an active
112  // feature) that the feature immediately backs off to (i.e. there is
113  // no other context that is the back-off of the first and backs off
114  // to the second).
115  //
116  // Consider for example features with zero look-ahead of the form
117  // (input, OUTPUT).
118  //
119  // - The following two features can be put in the same group because
120  // their context is disjoint: (a a a, A A), (b, B B);
121  //
122  // - The following two features can be put in the same group because
123  // one is the back-off context of the other: (a a a, A A), (a a, A
124  // A);
125  //
126  // - The following two features can NOT be put in the same group
127  // because there is overlap but neither is the other's back-off: (a
128  // a a, A), (a a, A A);
129  //
130  // - Finally, the following three features cannot be in a same group
131  // because the first one can immediately back off to either of the
132  // rest: (a a a, A A), (a a, A A), (a a a, A).
133  //
134  // The easiest way to satisfy the constraints is to create a feature
135  // group for each feature template. However, better feature grouping
136  // may help improve speed.
137  //
138  // This may fail if any of input or output labels are non-positive,
139  // or if any call to `FeatureGroupBuilder<>::AddWeight()` fails.
140  bool AddWeight(size_t group, const std::vector<Label> &input,
141  const std::vector<Label> &output, Weight weight);
142 
143  // Returns a newly created `LinearFstData` object or nullptr in case
144  // of failure. The caller takes the ownership of the memory. No
145  // other methods shall be called after this --- this is enforced by
146  // putting the builder at error state, even when a
147  // `LinearFstData<>` object is successfully built.
148  //
149  // This may fail if the call to any `FeatureGroupBuilder<>::Dump()`
150  // fails.
152 
153  private:
154  bool error_;
155  CompactSet<Label, kNoLabel> all_output_labels_;
156  std::map<Label, std::set<Label>> word_output_map_, word_feat_map_;
157  std::map<Label, std::set<size_t>> feat_groups_;
158  std::vector<std::unique_ptr<FeatureGroupBuilder<A>>> groups_;
159  size_t max_future_size_;
160  Label max_input_label_;
161  const SymbolTable *isyms_, *fsyms_, *osyms_;
162 
163  LinearFstDataBuilder(const LinearFstDataBuilder &) = delete;
164  LinearFstDataBuilder &operator=(const LinearFstDataBuilder &) = delete;
165 };
166 
167 // Builds a LinearFstData tailored for a LinearClassifierFst. The
168 // major difference between an ordinary LinearFstData that works on
169 // taggers and a LinearFstData that works on classifiers is that
170 // feature groups are divided into sections by the prediction class
171 // label. For a prediction label `pred` and a logical group id
172 // `group`, the actual group id is `group * num_classes + pred -
173 // 1`.
174 //
175 // This layout saves us from recording output labels in each single
176 // FeatureGroup. Because there is no need for any delaying, stripping
177 // the output allows features with different shapes but using the same
178 // set of feature label mapping to reside in a single FeatureGroup.
179 template <class A>
181  public:
182  typedef typename A::Label Label;
183  typedef typename A::Weight Weight;
184 
185  // Constructs a builder for a `num_classes`-class classifier,
 186  // optionally with associated symbol tables for diagnostic
187  // output. The output labels (i.e. prediction) must be in the range
188  // of [1, num_classes].
189  explicit LinearClassifierFstDataBuilder(size_t num_classes,
190  const SymbolTable *isyms = nullptr,
191  const SymbolTable *fsyms = nullptr,
192  const SymbolTable *osyms = nullptr)
193  : error_(false),
194  num_classes_(num_classes),
195  num_groups_(0),
196  builder_(isyms, fsyms, osyms) {}
197 
198  // Tests whether the builder has encountered any error. Similar to
199  // LinearFstDataBuilder<>::Error().
200  bool Error() const { return error_; }
201 
202  // Same as LinearFstDataBuilder<>::AddWord().
203  bool AddWord(Label word, const std::vector<Label> &features);
204 
205  // Adds a logical feature group. Similar to
206  // LinearFstDataBuilder<>::AddGroup(), with the exception that the
207  // returned group id is the logical group id. Also there is no need
208  // for "future" in a classifier.
209  int AddGroup();
210 
211  // Adds an instance of feature weight to the specified logical
212  // feature group. Instead of a vector of output, only a single
213  // prediction is needed as the output.
214  //
215  // This may fail if `pred` is not in the range of [1, num_classes_].
216  bool AddWeight(size_t group, const std::vector<Label> &input, Label pred,
217  Weight weight);
218 
219  // Returns a newly created `LinearFstData` object or nullptr in case of
220  // failure.
222 
223  private:
224  std::vector<Label> empty_;
225  bool error_;
226  size_t num_classes_, num_groups_;
227  LinearFstDataBuilder<A> builder_;
228 };
229 
230 // Builds a single feature group. Usually used in
231 // `LinearFstDataBuilder::AddWeight()`. See that method for the
232 // constraints on grouping features.
233 template <class A>
234 class FeatureGroupBuilder {
235  public:
236  typedef typename A::Label Label;
237  typedef typename A::Weight Weight;
238 
239  // Constructs a builder with the given future size. All features
240  // added to the group will have look-ahead windows of this size.
241  FeatureGroupBuilder(size_t future_size, const SymbolTable *fsyms,
242  const SymbolTable *osyms)
243  : error_(false), future_size_(future_size), fsyms_(fsyms), osyms_(osyms) {
244  // This edge is special; see doc of class `FeatureGroup` on the
245  // details.
246  start_ = trie_.Insert(trie_.Root(), InputOutputLabel(kNoLabel, kNoLabel));
247  }
248 
249  // Tests whether the builder has encountered any error. No operation
250  // is valid if the builder is already at error state. All other
251  // public methods should check this before any actual operations.
252  bool Error() const { return error_; }
253 
254  // Adds a feature weight with the given context. Returns true iff
255  // the weight is added. A weight is not added if it has ill-formed
256  // context involving start-, end-of-sentence marks.
257  //
258  // Note: `input` is the sequence of input
259  // features, instead of input labels themselves. `input` must be at
260  // least as long as `future_size`; `output` may be empty, but
261  // usually should be non-empty because an empty output context is
262  // useless in discriminative modelling. All labels in both `input`
263  // and `output` must be > 0 (this is checked in
264  // `LinearFstDataBuilder::AddWeight()`). See
265  // LinearFstDataBuilder<>::AddWeight for more details.
266  //
267  // This may fail if the input is smaller than the look-ahead window.
268  bool AddWeight(const std::vector<Label> &input,
269  const std::vector<Label> &output, Weight weight);
270 
271  // Creates an actual FeatureGroup<> object. Connects back-off links;
272  // pre-accumulates weights from back-off features. Returns nullptr if
273  // there is any violation in unique immediate back-off
274  // constraints.
275  //
276  // Regardless of whether the call succeeds or not, the error flag is
277  // always set before this returns, to prevent repeated dumping.
278  //
279  // TODO(wuke): check overlapping top-level contexts (see
280  // `DumpOverlappingContext()` in tests).
281  FeatureGroup<A> *Dump(size_t max_future_size);
282 
283  private:
286  typedef typename FeatureGroup<A>::WeightBackLink WeightBackLink;
287  // Nested trie topology uses more memory but we can traverse a
288  // node's children easily, which is required in `BuildBackLinks()`.
291 
292  // Finds the first node with an arc with `label` following the
293  // back-off chain of `parent`. Returns the node index or
294  // `kNoTrieNodeId` when not found. The number of hops is stored in
295  // `hop` when it is not `nullptr`.
296  //
297  // This does not fail.
298  int FindFirstMatch(InputOutputLabel label, int parent, int *hop) const;
299 
300  // Links each node to its immediate back-off. root is linked to -1.
301  //
302  // This may fail when the unique immediate back-off constraint is
303  // violated.
304  void BuildBackLinks();
305 
306  // Traces back on the back-chain for each node to multiply the
307  // weights from back-offs to the node itself.
308  //
309  // This does not fail.
310  void PreAccumulateWeights();
311 
312  // Reconstruct the path from trie root to given node for logging.
313  bool TrieDfs(const Topology &topology, int cur, int target,
314  std::vector<InputOutputLabel> *path) const;
315  string TriePath(int node, const Topology &topology) const;
316 
317  bool error_;
318  size_t future_size_;
319  Trie trie_;
320  int start_;
321  const SymbolTable *fsyms_, *osyms_;
322 
323  FeatureGroupBuilder(const FeatureGroupBuilder &) = delete;
324  FeatureGroupBuilder &operator=(const FeatureGroupBuilder &) = delete;
325 };
326 
327 //
328 // Implementation of methods in `LinearFstDataBuilder`
329 //
330 template <class A>
332  const std::vector<Label> &features) {
333  if (error_) {
334  FSTERROR() << "Calling LinearFstDataBuilder<>::AddWord() at error state";
335  return false;
336  }
339  LOG(WARNING) << "Ignored: adding boundary label: "
340  << TranslateLabel(word, isyms_)
341  << "(start-of-sentence=" << LinearFstData<A>::kStartOfSentence
342  << ", end-of-sentence=" << LinearFstData<A>::kEndOfSentence
343  << ")";
344  return false;
345  }
346  if (word <= 0) {
347  error_ = true;
348  FSTERROR() << "Word label must be > 0; got " << word;
349  return false;
350  }
351  if (word > max_input_label_) max_input_label_ = word;
352  // Make sure the word hasn't been added before
353  if (word_feat_map_.find(word) != word_feat_map_.end()) {
354  error_ = true;
355  FSTERROR() << "Input word " << TranslateLabel(word, isyms_)
356  << " is added twice";
357  return false;
358  }
359  // Store features
360  std::set<Label> *feats = &word_feat_map_[word];
361  for (size_t i = 0; i < features.size(); ++i) {
362  Label feat = features[i];
363  if (feat <= 0) {
364  error_ = true;
365  FSTERROR() << "Feature label must be > 0; got " << feat;
366  return false;
367  }
368  feats->insert(feat);
369  }
370  return true;
371 }
372 
373 template <class A>
375  Label word, const std::vector<Label> &word_features,
376  const std::vector<Label> &possible_output) {
377  if (error_) {
378  FSTERROR() << "Calling LinearFstDataBuilder<>::AddWord() at error state";
379  return false;
380  }
381  if (!AddWord(word, word_features)) return false;
382  // Store possible output constraint
383  if (possible_output.empty()) {
384  error_ = true;
385  FSTERROR() << "Empty possible output constraint; "
386  << "use the two-parameter version if no constraint is need.";
387  return false;
388  }
389  std::set<Label> *outputs = &word_output_map_[word];
390  for (size_t i = 0; i < possible_output.size(); ++i) {
391  Label output = possible_output[i];
392  if (output == LinearFstData<A>::kStartOfSentence ||
394  LOG(WARNING) << "Ignored: word = " << TranslateLabel(word, isyms_)
395  << ": adding boundary label as possible output: " << output
396  << "(start-of-sentence="
398  << ", end-of-sentence=" << LinearFstData<A>::kEndOfSentence
399  << ")";
400  continue;
401  }
402  if (output <= 0) {
403  error_ = true;
404  FSTERROR() << "Output label must be > 0; got " << output;
405  return false;
406  }
407  outputs->insert(output);
408  all_output_labels_.Insert(output);
409  }
410  return true;
411 }
412 
413 template <class A>
414 inline int LinearFstDataBuilder<A>::AddGroup(size_t future_size) {
415  if (error_) {
416  FSTERROR() << "Calling LinearFstDataBuilder<>::AddGroup() at error state";
417  return -1;
418  }
419  size_t ret = groups_.size();
420  groups_.emplace_back(new FeatureGroupBuilder<A>(future_size, fsyms_, osyms_));
421  if (future_size > max_future_size_) max_future_size_ = future_size;
422  return ret;
423 }
424 
425 template <class A>
427  const std::vector<Label> &input,
428  const std::vector<Label> &output,
429  Weight weight) {
430  if (error_) {
431  FSTERROR() << "Calling LinearFstDataBuilder<>::AddWeight() at error state";
432  return false;
433  }
434  // Check well-formedness of boundary marks on the input.
435  {
436  bool start_in_middle = false, end_in_middle = false;
437  for (int i = 1; i < input.size(); ++i) {
438  if (input[i] == LinearFstData<A>::kStartOfSentence &&
439  input[i - 1] != LinearFstData<A>::kStartOfSentence)
440  start_in_middle = true;
441  if (input[i - 1] == LinearFstData<A>::kEndOfSentence &&
443  end_in_middle = true;
444  }
445  if (start_in_middle) {
446  LOG(WARNING) << "Ignored: start-of-sentence in the middle of the input!";
447  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
448  LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_);
449  return false;
450  }
451  if (end_in_middle) {
452  LOG(WARNING) << "Ignored: end-of-sentence in the middle of the input!";
453  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
454  LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_);
455  return false;
456  }
457  }
458  // Check well-formedness of boundary marks on the output.
459  {
460  bool non_first_start = false, non_last_end = false;
461  for (int i = 1; i < output.size(); ++i) {
462  if (output[i] == LinearFstData<A>::kStartOfSentence)
463  non_first_start = true;
464  if (output[i - 1] == LinearFstData<A>::kEndOfSentence)
465  non_last_end = true;
466  }
467  if (non_first_start) {
468  LOG(WARNING) << "Ignored: start-of-sentence not appearing "
469  << "as the first label in the output!";
470  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
471  LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_);
472  return false;
473  }
474  if (non_last_end) {
475  LOG(WARNING) << "Ignored: end-of-sentence not appearing "
476  << "as the last label in the output!";
477  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
478  LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_);
479  return false;
480  }
481  }
482 
483  for (size_t i = 0; i < input.size(); ++i) {
484  Label feat = input[i];
486  feat != LinearFstData<A>::kEndOfSentence && feat <= 0) {
487  error_ = true;
488  FSTERROR() << "Feature label must be > 0; got " << feat;
489  return false;
490  }
491  feat_groups_[feat].insert(group);
492  }
493  for (size_t i = 0; i < output.size(); ++i) {
494  Label label = output[i];
495  if (label != LinearFstData<A>::kStartOfSentence &&
496  label != LinearFstData<A>::kEndOfSentence && label <= 0) {
497  error_ = true;
498  FSTERROR() << "Output label must be > 0; got " << label;
499  return false;
500  }
501  if (label != LinearFstData<A>::kStartOfSentence &&
503  all_output_labels_.Insert(label);
504  }
505 
506  // Everything looks good at this point (more checks on the way in
507  // the feature group). Add this feature weight.
508  bool added = groups_[group]->AddWeight(input, output, weight);
509  if (groups_[group]->Error()) {
510  error_ = true;
511  FSTERROR() << "FeatureGroupBuilder<>::AddWeight() failed";
512  return false;
513  }
514  return added;
515 }
516 
517 template <class A>
519  if (error_) {
520  FSTERROR() << "Calling LinearFstDataBuilder<>::Dump() at error state";
521  return nullptr;
522  }
523 
524  std::unique_ptr<LinearFstData<A>> data(new LinearFstData<A>());
525  data->max_future_size_ = max_future_size_;
526  data->max_input_label_ = max_input_label_;
527 
528  // Feature groups; free builders after it's dumped.
529  data->groups_.resize(groups_.size());
530  for (int group = 0; group != groups_.size(); ++group) {
531  FeatureGroup<A> *new_group = groups_[group]->Dump(max_future_size_);
532  if (new_group == nullptr) {
533  error_ = true;
534  FSTERROR() << "Error in dumping group " << group;
535  return nullptr;
536  }
537  data->groups_[group].reset(new_group);
538  groups_[group].reset();
539  VLOG(1) << "Group " << group << ": " << new_group->Stats();
540  }
541 
542  // Per-group feature mapping
543  data->group_feat_map_.Init(data->NumGroups(), max_input_label_ + 1);
544  for (Label word = 1; word <= max_input_label_; ++word) {
545  typename std::map<Label, std::set<Label>>::const_iterator it =
546  word_feat_map_.find(word);
547  if (it == word_feat_map_.end()) continue;
548  for (typename std::set<Label>::const_iterator oit = it->second.begin();
549  oit != it->second.end(); ++oit) {
550  Label feat = *oit;
551  typename std::map<Label, std::set<size_t>>::const_iterator jt =
552  feat_groups_.find(feat);
553  if (jt == feat_groups_.end()) continue;
554  for (std::set<size_t>::const_iterator git = jt->second.begin();
555  git != jt->second.end(); ++git) {
556  size_t group_id = *git;
557  if (!data->group_feat_map_.Set(group_id, word, feat)) {
558  error_ = true;
559  return nullptr;
560  }
561  }
562  }
563  }
564 
565  // Possible output labels
566  {
567  std::vector<typename LinearFstData<A>::InputAttribute> *input_attribs =
568  &data->input_attribs_;
569  std::vector<Label> *output_pool = &data->output_pool_;
570  input_attribs->resize(max_input_label_ + 1);
571  for (Label word = 0; word <= max_input_label_; ++word) {
572  typename std::map<Label, std::set<Label>>::const_iterator it =
573  word_output_map_.find(word);
574  if (it == word_output_map_.end()) {
575  (*input_attribs)[word].output_begin = 0;
576  (*input_attribs)[word].output_length = 0;
577  } else {
578  (*input_attribs)[word].output_begin = output_pool->size();
579  (*input_attribs)[word].output_length = it->second.size();
580  for (typename std::set<Label>::const_iterator oit = it->second.begin();
581  oit != it->second.end(); ++oit) {
582  Label olabel = *oit;
583  output_pool->push_back(olabel);
584  }
585  }
586  }
587  }
588 
590  all_output_labels_.Begin();
591  it != all_output_labels_.End(); ++it)
592  data->output_set_.push_back(*it);
593 
594  error_ = true; // prevent future calls on this object
595  return data.release();
596 }
597 
598 //
599 // Implementation of methods in `LinearClassifierFstDataBuilder`
600 //
601 template <class A>
603  Label word, const std::vector<Label> &features) {
604  if (error_) {
605  FSTERROR() << "Calling LinearClassifierFstDataBuilder<>::AddWord() at "
606  "error state";
607  return false;
608  }
609  bool added = builder_.AddWord(word, features);
610  if (builder_.Error()) error_ = true;
611  return added;
612 }
613 
614 template <class A>
616  if (error_) {
617  FSTERROR() << "Calling LinearClassifierFstDataBuilder<>::AddGroup() at "
618  "error state";
619  return -1;
620  }
621  for (int i = 0; i < num_classes_; ++i) builder_.AddGroup(0);
622  if (builder_.Error()) {
623  error_ = true;
624  return -1;
625  }
626  return num_groups_++;
627 }
628 
629 template <class A>
631  size_t group, const std::vector<Label> &input, Label pred, Weight weight) {
632  if (error_) {
633  FSTERROR() << "Calling LinearClassifierFstDataBuilder<>::AddWeight() at "
634  "error state";
635  return false;
636  }
637  if (pred <= 0 || pred > num_classes_) {
638  FSTERROR() << "Out-of-range prediction label: " << pred
639  << " (num classes = " << num_classes_ << ")";
640  error_ = true;
641  return false;
642  }
643  size_t real_group = group * num_classes_ + pred - 1;
644  bool added = builder_.AddWeight(real_group, input, empty_, weight);
645  if (builder_.Error()) error_ = true;
646  return added;
647 }
648 
649 template <class A>
651  if (error_) {
652  FSTERROR()
653  << "Calling LinearClassifierFstDataBuilder<>::Dump() at error state";
654  return nullptr;
655  }
656  LinearFstData<A> *data = builder_.Dump();
657  error_ = true;
658  return data;
659 }
660 
661 //
662 // Implementation of methods in `FeatureGroupBuilder`
663 //
664 template <class A>
665 bool FeatureGroupBuilder<A>::AddWeight(const std::vector<Label> &input,
666  const std::vector<Label> &output,
667  Weight weight) {
668  if (error_) {
669  FSTERROR() << "Calling FeatureGroupBuilder<>::AddWeight() at error state";
670  return false;
671  }
672 
673  // `LinearFstDataBuilder<>::AddWeight()` ensures prefix/suffix
674  // properties for us. We can directly count.
675  int num_input_start = 0;
676  while (num_input_start < input.size() &&
677  input[num_input_start] == LinearFstData<A>::kStartOfSentence)
678  ++num_input_start;
679  int num_output_start = 0;
680  while (num_output_start < output.size() &&
681  output[num_output_start] == LinearFstData<A>::kStartOfSentence)
682  ++num_output_start;
683  int num_input_end = 0;
684  for (int i = input.size() - 1;
685  i >= 0 && input[i] == LinearFstData<A>::kEndOfSentence; --i)
686  ++num_input_end;
687  int num_output_end = 0;
688  for (int i = output.size() - 1;
689  i >= 0 && output[i] == LinearFstData<A>::kEndOfSentence; --i)
690  ++num_output_end;
691 
692  DCHECK_LE(num_output_end, 1);
693 
694  if (input.size() - num_input_start < future_size_) {
695  LOG(WARNING) << "Ignored: start-of-sentence in the future!";
696  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
697  LOG(WARNING) << "\tOutput: " << JoinLabels(output, fsyms_);
698  return false;
699  }
700  if (num_input_start > 0 &&
701  input.size() - future_size_ - num_input_start <
702  output.size() - num_output_start) {
703  LOG(WARNING) << "Ignored: matching start-of-sentence with actual output!";
704  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
705  LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_);
706  return false;
707  }
708  if (num_output_start > 0 &&
709  input.size() - future_size_ - num_input_start >
710  output.size() - num_output_start) {
711  LOG(WARNING) << "Ignored: matching start-of-sentence with actual input!";
712  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
713  LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_);
714  return false;
715  }
716  // The following two require `num_output_end` <= 1.
717  if (num_input_end > future_size_ && num_input_end - future_size_ != 1) {
718  LOG(WARNING) << "Ignored: matching end-of-sentence with actual output!";
719  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
720  LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_);
721  return false;
722  }
723  if (num_output_end > 0 &&
724  ((input.size() == future_size_ && future_size_ != num_input_end) ||
725  (input.size() > future_size_ &&
726  num_input_end != future_size_ + num_output_end))) {
727  LOG(WARNING) << "Ignored: matching end-of-sentence with actual input!";
728  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
729  LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_);
730  return false;
731  }
732  // Check if the context has no other labels than boundary marks
733  // (such features are useless).
734  if (num_input_start + num_input_end == input.size() &&
735  num_output_start + num_output_end == output.size()) {
736  LOG(WARNING)
737  << "Ignored: feature context consisting of only boundary marks!";
738  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
739  LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_);
740  return false;
741  }
742 
743  // Start point for insertion in the trie. Insert at `start_` iff the
744  // beginning of the context is non-consumed start-of-sentence.
745  int cur = (num_input_start == 0 && num_output_start <= future_size_)
746  ? trie_.Root()
747  : start_;
748  // Skip all input start-of-sentence marks
749  size_t ipos = num_input_start;
750  // Skip to keep at most `future_size_` start-of-sentence marks
751  size_t opos =
752  num_output_start <= future_size_ ? 0 : num_output_start - future_size_;
753  // Skip `num_output_end` end-of-sentence marks on both input and output
754  size_t iend = !input.empty() ? input.size() - num_output_end : 0,
755  oend = output.size() - num_output_end;
756  // Further, when output is empty, keep at most `future_size_`
757  // end-of-sentence marks on input.
758  if (output.empty() && num_input_end > future_size_)
759  iend = input.size() - num_input_end + future_size_;
760 
761  // Actual feature context is (input[ipos:iend], output[opos:oend]).
762 
763  // Pad `kNoLabel` as don't cares on the shorter of actual `input`
764  // and `output`.
765  const size_t effective_input_size = iend - ipos,
766  effective_output_size = oend - opos;
767  if (effective_input_size > effective_output_size) {
768  for (size_t pad = effective_input_size - effective_output_size; pad != 0;
769  --pad, ++ipos)
770  cur = trie_.Insert(cur, InputOutputLabel(input[ipos], kNoLabel));
771  } else if (effective_input_size < effective_output_size) {
772  for (size_t pad = effective_output_size - effective_input_size; pad != 0;
773  --pad, ++opos)
774  cur = trie_.Insert(cur, InputOutputLabel(kNoLabel, output[opos]));
775  }
776  CHECK_EQ(iend - ipos, oend - opos);
777  for (; ipos != iend; ++ipos, ++opos)
778  cur = trie_.Insert(cur, InputOutputLabel(input[ipos], output[opos]));
779  // We only need to attach final weight when there is an output
780  // end-of-sentence. When there is only end-of-sentence on the input,
781  // they are all consumed as the end-of-sentence paddings from
782  // `LinearFstImpl<>::ShiftBuffer()`. `LinearFstImpl<>::Expand()`
783  // and `LinearFstImpl<>::MatchInput()` ensures no other
784  // transition takes place after consuming the padding.
785  if (num_output_end > 0 || (output.empty() && num_input_end > future_size_))
786  trie_[cur].final_weight = Times(trie_[cur].final_weight, weight);
787  else
788  trie_[cur].weight = Times(trie_[cur].weight, weight);
789 
790  return true;
791 }
792 
793 template <class A>
795  if (error_) {
796  FSTERROR() << "Calling FeatureGroupBuilder<>::PreAccumulateWeights() "
797  << "at error state";
798  return nullptr;
799  }
800 
801  if (max_future_size < future_size_) {
802  error_ = true;
803  FSTERROR() << "max_future_size (= " << max_future_size
804  << ") is smaller the builder's future_size (= " << future_size_
805  << ")";
806  return nullptr;
807  }
808 
809  BuildBackLinks();
810  if (error_) return nullptr;
811  PreAccumulateWeights(); // does not fail
812 
813  FeatureGroup<A> *ret =
814  new FeatureGroup<A>(max_future_size - future_size_, start_);
815 
816  // Walk around the trie to compute next states
817  ret->next_state_.resize(trie_.NumNodes());
818  const Topology &topology = trie_.TrieTopology();
819  for (int i = 0; i < topology.NumNodes(); ++i) {
820  int next = i;
821  while (next != topology.Root() && topology.ChildrenOf(next).empty() &&
822  trie_[next].final_weight ==
823  trie_[trie_[next].back_link].final_weight)
824  next = trie_[next].back_link;
825  ret->next_state_[i] = next;
826  }
827 
828  // Copy the trie
829  typename FeatureGroup<A>::Trie store_trie(trie_);
830  ret->trie_.swap(store_trie);
831 
832  // Put the builder at error state to prevent repeated call of `Dump()`.
833  error_ = true;
834  return ret;
835 }
836 
837 template <class A>
839  int *hop) const {
840  int hop_count = 0;
841  int ret = kNoTrieNodeId;
842  for (; parent >= 0; parent = trie_[parent].back_link, ++hop_count) {
843  int next = trie_.Find(parent, label);
844  if (next != kNoTrieNodeId) {
845  ret = next;
846  break;
847  }
848  }
849  if (hop != nullptr) *hop = hop_count;
850  return ret;
851 }
852 
853 template <class A>
855  // Breadth first search from the root. In the case where we only
 856  // have the input label, the immediate back-off is simply the longest
857  // suffix of the current node that is also in the trie. For a node
858  // reached from its parent with label L, we can simply walk through
859  // the parent's back-off chain to find the first state with an arc
860  // of the same label L. The uniqueness is always
861  // guanranteed. However, in the case with both input and output
862  // labels, it is possible to back off by removing first labels from
863  // either side, which in general causes non-uniqueness.
864 
865  const Topology &topology = trie_.TrieTopology();
866  std::queue<int> q; // all enqueued or visited nodes have known links
867 
868  // Note: nodes have back link initialized to -1 in their
869  // constructor.
870  q.push(trie_.Root());
871  while (!error_ && !q.empty()) {
872  int parent = q.front();
873  q.pop();
874  // Find links for every child
875  const typename Topology::NextMap &children = topology.ChildrenOf(parent);
876  for (typename Topology::NextMap::const_iterator eit = children.begin();
877  eit != children.end(); ++eit) {
878  const std::pair<InputOutputLabel, int> &edge = *eit;
879  InputOutputLabel label = edge.first;
880  int child = edge.second;
881  if (label.input == kNoLabel || label.output == kNoLabel) {
882  // Label pairs from root to here all have one and only one
883  // `kNoLabel` on the same side; equivalent to the
884  // "longest-suffix" case.
885  trie_[child].back_link =
886  FindFirstMatch(label, trie_[parent].back_link, nullptr);
887  } else {
888  // Neither side is `kNoLabel` at this point, there are
889  // three possible ways to back-off: if the parent backs
890  // off to some context with only one side non-empty, the
891  // empty side may remain empty; or else an exact match of
892  // both sides is needed. Try to find all three possible
893  // backs and look for the closest one (in terms of hops
894  // along the parent's back-off chain).
895  int only_input_hop, only_output_hop, full_hop;
896  int only_input_link =
897  FindFirstMatch(InputOutputLabel(label.input, kNoLabel), parent,
898  &only_input_hop),
899  only_output_link =
900  FindFirstMatch(InputOutputLabel(kNoLabel, label.output), parent,
901  &only_output_hop),
902  full_link =
903  FindFirstMatch(label, trie_[parent].back_link, &full_hop);
904  if (only_input_link != -1 && only_output_link != -1) {
905  error_ = true;
906  FSTERROR() << "Branching back-off chain:\n"
907  << "\tnode " << child << ": " << TriePath(child, topology)
908  << "\n"
909  << "\tcan back-off to node " << only_input_link << ": "
910  << TriePath(only_input_link, topology) << "\n"
911  << "\tcan back-off to node " << only_output_link << ": "
912  << TriePath(only_output_link, topology);
913  return;
914  } else if (full_link != -1) {
915  ++full_hop;
916  if (full_hop <= only_input_hop && full_hop <= only_output_hop) {
917  trie_[child].back_link = full_link;
918  } else {
919  error_ = true;
920  int problem_link = only_input_link != kNoTrieNodeId
921  ? only_input_link
922  : only_output_link;
923  CHECK_NE(problem_link, kNoTrieNodeId);
924  FSTERROR() << "Branching back-off chain:\n"
925  << "\tnode " << child << ": "
926  << TriePath(child, topology) << "\n"
927  << "\tcan back-off to node " << full_link << ": "
928  << TriePath(full_link, topology) << "\n"
929  << "tcan back-off to node " << problem_link << ": "
930  << TriePath(problem_link, topology);
931  return;
932  }
933  } else {
934  trie_[child].back_link =
935  only_input_link != -1 ? only_input_link : only_output_link;
936  }
937  }
938  if (error_) break;
939  // Point to empty context (root) when no back-off can be found
940  if (trie_[child].back_link == -1) trie_[child].back_link = 0;
941  q.push(child);
942  }
943  }
944 }
945 
// Folds (via Times) each node's weight and final weight with those of
// its back-off target, processing every node's back-off chain target
// first so that each fold already sees the target's accumulated value.
// Called once from Dump() after BuildBackLinks(); "does not fail".
// NOTE(review): the signature line (Doxygen line 947) was lost in
// extraction; from the call site in Dump() this is
// void FeatureGroupBuilder<A>::PreAccumulateWeights().
946 template <class A>
948  std::vector<bool> visited(trie_.NumNodes(), false);
// The root backs off to nothing; treat it as already accumulated.
949  visited[trie_.Root()] = true;
950 
// For each node, push the not-yet-visited prefix of its back-off chain
// onto a stack, then accumulate in back-off-first (top-down) order.
951  for (size_t i = 0; i != trie_.NumNodes(); ++i) {
952  std::stack<int> back_offs;
953  for (int j = i; !visited[j]; j = trie_[j].back_link) back_offs.push(j);
954  while (!back_offs.empty()) {
955  int j = back_offs.top();
956  back_offs.pop();
957  WeightBackLink &node = trie_[j];
// node.back_link is visited by construction, so its weights are final.
958  node.weight = Times(node.weight, trie_[node.back_link].weight);
959  node.final_weight =
960  Times(node.final_weight, trie_[node.back_link].final_weight);
961  visited[j] = true;
962  }
963  }
964 }
965 
// Depth-first search from `cur` for `target`, appending to `path` the
// edge labels along the way. Returns true iff `target` was reached; on
// a failed branch the pushed label is popped again, so on overall
// failure `path` is left as it was on entry.
// NOTE(review): the first line of the signature (Doxygen line 967) was
// lost in extraction; from the boolean returns and the call in
// TriePath() this is bool FeatureGroupBuilder<A>::TrieDfs(...) const.
966 template <class A>
968  const Topology &topology, int cur, int target,
969  std::vector<InputOutputLabel> *path) const {
970  if (cur == target) return true;
971  const typename Topology::NextMap &children = topology.ChildrenOf(cur);
972  for (typename Topology::NextMap::const_iterator eit = children.begin();
973  eit != children.end(); ++eit) {
974  const std::pair<InputOutputLabel, int> &edge = *eit;
// Tentatively extend the path with this edge and recurse.
975  path->push_back(edge.first);
976  if (TrieDfs(topology, edge.second, target, path)) return true;
// Backtrack: this edge does not lead to `target`.
977  path->pop_back();
978  }
979  return false;
980 }
981 
982 template <class A>
983 string FeatureGroupBuilder<A>::TriePath(int node,
984  const Topology &topology) const {
985  std::vector<InputOutputLabel> labels;
986  TrieDfs(topology, topology.Root(), node, &labels);
987  bool first = true;
988  std::ostringstream strm;
989  for (typename std::vector<InputOutputLabel>::const_iterator it =
990  labels.begin();
991  it != labels.end(); ++it) {
992  InputOutputLabel i = *it;
993  if (first)
994  first = false;
995  else
996  strm << ", ";
997  strm << "(" << TranslateLabel(i.input, fsyms_) << ", "
998  << TranslateLabel(i.output, osyms_) << ")";
999  }
1000  return strm.str();
1001 }
1002 
1003 inline string TranslateLabel(int64 label, const SymbolTable *syms) {
1004  string ret;
1005  if (syms != nullptr) ret += syms->Find(label);
1006  if (ret.empty()) {
1007  std::ostringstream strm;
1008  strm << '<' << label << '>';
1009  ret = strm.str();
1010  }
1011  return ret;
1012 }
1013 
1014 template <class Iterator>
1015 string JoinLabels(Iterator begin, Iterator end, const SymbolTable *syms) {
1016  if (begin == end) return "<empty>";
1017  std::ostringstream strm;
1018  bool first = true;
1019  for (Iterator it = begin; it != end; ++it) {
1020  if (first)
1021  first = false;
1022  else
1023  strm << '|';
1024  strm << TranslateLabel(*it, syms);
1025  }
1026  return strm.str();
1027 }
1028 
1029 template <class Label>
1030 string JoinLabels(const std::vector<Label> &labels, const SymbolTable *syms) {
1031  return JoinLabels(labels.begin(), labels.end(), syms);
1032 }
1033 
1034 template <class A>
1035 typename A::Label GuessStartOrEnd(std::vector<typename A::Label> *sequence,
1036  typename A::Label boundary) {
1037  const size_t length = sequence->size();
1038  std::vector<bool> non_boundary_on_left(length, false),
1039  non_boundary_on_right(length, false);
1040  for (size_t i = 1; i < length; ++i) {
1041  non_boundary_on_left[i] =
1042  non_boundary_on_left[i - 1] || (*sequence)[i - 1] != boundary;
1043  non_boundary_on_right[length - 1 - i] = non_boundary_on_right[length - i] ||
1044  (*sequence)[length - i] != boundary;
1045  }
1046  int unresolved = 0;
1047  for (size_t i = 0; i < length; ++i) {
1048  if ((*sequence)[i] != boundary) continue;
1049  const bool left = non_boundary_on_left[i], right = non_boundary_on_right[i];
1050  if (left && right) {
1051  // Boundary in the middle
1052  LOG(WARNING) << "Boundary label in the middle of the sequence! position: "
1053  << i << "; boundary: " << boundary
1054  << "; sequence: " << JoinLabels(*sequence, nullptr);
1055  LOG(WARNING)
1056  << "This is an invalid sequence anyway so I will set it to start.";
1057  (*sequence)[i] = LinearFstData<A>::kStartOfSentence;
1058  } else if (left && !right) {
1059  // Can only be end
1060  (*sequence)[i] = LinearFstData<A>::kEndOfSentence;
1061  } else if (!left && right) {
1062  // Can only be start
1063  (*sequence)[i] = LinearFstData<A>::kStartOfSentence;
1064  } else {
1065  // !left && !right; can't really tell
1066  ++unresolved;
1067  }
1068  }
1069  return unresolved;
1070 }
1071 
1072 } // namespace fst
1073 
1074 #endif // FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_BUILDER_H_
FeatureGroup< A > * Dump(size_t max_future_size)
constexpr int kNoLabel
Definition: fst.h:179
FeatureGroupBuilder(size_t future_size, const SymbolTable *fsyms, const SymbolTable *osyms)
string Stats() const
bool AddWeight(size_t group, const std::vector< Label > &input, const std::vector< Label > &output, Weight weight)
#define LOG(type)
Definition: log.h:48
bool AddWord(Label word, const std::vector< Label > &features)
ExpectationWeight< X1, X2 > Times(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
string TranslateLabel(int64 label, const SymbolTable *syms)
bool AddWeight(const std::vector< Label > &input, const std::vector< Label > &output, Weight weight)
int AddGroup(size_t future_size)
int64_t int64
Definition: types.h:27
#define FSTERROR()
Definition: util.h:35
LinearClassifierFstDataBuilder(size_t num_classes, const SymbolTable *isyms=nullptr, const SymbolTable *fsyms=nullptr, const SymbolTable *osyms=nullptr)
bool AddWord(Label word, const std::vector< Label > &features)
LinearFstDataBuilder(const SymbolTable *isyms=nullptr, const SymbolTable *fsyms=nullptr, const SymbolTable *osyms=nullptr)
#define VLOG(level)
Definition: log.h:49
const int kNoTrieNodeId
Definition: trie.h:16
#define CHECK_NE(x, y)
Definition: log.h:67
int Root() const
Definition: trie.h:130
virtual string Find(int64 key) const
Definition: symbol-table.h:299
size_t NumNodes() const
Definition: trie.h:131
bool AddWeight(size_t group, const std::vector< Label > &input, Label pred, Weight weight)
#define DCHECK_LE(x, y)
Definition: log.h:74
#define CHECK_EQ(x, y)
Definition: log.h:62
const NextMap & ChildrenOf(int parent) const
Definition: trie.h:134
std::unordered_map< L, int, H > NextMap
Definition: trie.h:64
const_iterator Begin() const
Definition: util.h:409
string JoinLabels(Iterator begin, Iterator end, const SymbolTable *syms)
A::Label GuessStartOrEnd(std::vector< typename A::Label > *sequence, typename A::Label boundary)