FST  openfst-1.8.3
OpenFst Library
linear-fst-data-builder.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 
#ifndef FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_BUILDER_H_
#define FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_BUILDER_H_

#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <utility>
#include <vector>

#include <fst/compat.h>
#include <fst/log.h>
#include <fst/extensions/linear/linear-fst-data.h>
#include <fst/extensions/linear/trie.h>
#include <fst/fst.h>
#include <fst/symbol-table.h>
#include <fst/util.h>

40 
41 namespace fst {
42 
43 // Forward declaration
44 template <class A>
46 
47 // For logging purposes
48 inline std::string TranslateLabel(int64_t label, const SymbolTable *syms);
49 template <class Iterator>
50 std::string JoinLabels(Iterator begin, Iterator end, const SymbolTable *syms);
51 template <class Label>
52 std::string JoinLabels(const std::vector<Label> &labels,
53  const SymbolTable *syms);
54 
55 // Guesses the appropriate boundary label (start- or end-of-sentence)
56 // for all labels equal to `boundary` and modifies the `sequence`
57 // in-place. Returns the number of positions that are still uncertain.
58 template <class A>
59 typename A::Label GuessStartOrEnd(std::vector<typename A::Label> *sequence,
60  typename A::Label boundary);
61 
62 // Builds a `LinearFstData` object by adding words and feature
63 // weights. A few conventions:
64 //
65 // - Input labels forms a dense non-empty range from 1 to `MaxInputLabel()`.
66 // - Feature labels, output labels are > 0.
67 // - Being a discriminative linear model, it only makes sense to use tropical
68 // semirings.
69 template <class A>
71  public:
72  typedef typename A::Label Label;
73  typedef typename A::Weight Weight;
74 
75  // Constructs a builder with associated symbol tables for diagonstic
76  // output. Each of these symbol tables may also be nullptr.
77  explicit LinearFstDataBuilder(const SymbolTable *isyms = nullptr,
78  const SymbolTable *fsyms = nullptr,
79  const SymbolTable *osyms = nullptr)
80  : error_(false),
81  max_future_size_(0),
82  max_input_label_(1),
83  isyms_(isyms),
84  fsyms_(fsyms),
85  osyms_(osyms) {}
86 
87  // Tests whether the builder has encountered any error. No operation
88  // is valid if the builder is already at error state. All other
89  // public methods should check this before any actual operations.
90  bool Error() const { return error_; }
91 
92  // Adds a word and its feature labels to the vocabulary; this
93  // version allows the word to have any output label. Returns true
94  // iff the word is added.
95  //
96  // This may fail if the word is added twice or if the feature labels
97  // are non-positive.
98  bool AddWord(Label word, const std::vector<Label> &features);
99 
100  // Adds a word and its feature labels to the vocabulary; this
101  // version puts constraint on possible output labels the word can
102  // have. `possible_output` must not be empty. Returns true iff the
103  // word is added.
104  //
105  // In addition to the reasons above in the two-parameter version,
106  // this may also fail if `possible_output` is empty or any output
107  // label in it is non-positive.
108  bool AddWord(Label word, const std::vector<Label> &word_features,
109  const std::vector<Label> &possible_output);
110 
111  // Creates a new feature group with specified future size (size of
112  // the look-ahead window), returns the group id to be used for
113  // adding actual feature weights or a negative number when called at
114  // error state.
115  //
116  // This does not fail unless called at error state.
117  int AddGroup(size_t future_size);
118 
119  // Adds an instance of feature weight to the specified feature
120  // group. If some weight has already been added with the same
121  // feature, the product of the old and new weights are
122  // stored. Returns true iff the weight is added. A weight is not
123  // added when the context has ill-formed context involving start-,
124  // end-of-sentence marks.
125  //
126  // For two features to be within the same group, it must satisfy
127  // that (1) they have the same future size; (2) the two either have
128  // disjoint context or one is the back-off context of the
129  // other. Furthermore, for all features in a single group, there
130  // must be one and only one other context (not necessarily an active
131  // feature) that the feature immediately backs off to (i.e. there is
132  // no other context that is the back-off of the first and backs off
133  // to the second).
134  //
135  // Consider for example features with zero look-ahead of the form
136  // (input, OUTPUT).
137  //
138  // - The following two features can be put in the same group because
139  // their context is disjoint: (a a a, A A), (b, B B);
140  //
141  // - The following two features can be put in the same group because
142  // one is the back-off context of the other: (a a a, A A), (a a, A
143  // A);
144  //
145  // - The following two features can NOT be put in the same group
146  // because there is overlap but neither is the other's back-off: (a
147  // a a, A), (a a, A A);
148  //
149  // - Finally, the following three features cannot be in a same group
150  // because the first one can immediately back off to either of the
151  // rest: (a a a, A A), (a a, A A), (a a a, A).
152  //
153  // The easiest way to satisfy the constraints is to create a feature
154  // group for each feature template. However, better feature grouping
155  // may help improve speed.
156  //
157  // This may fail if any of input or output labels are non-positive,
158  // or if any call to `FeatureGroupBuilder<>::AddWeight()` fails.
159  bool AddWeight(size_t group, const std::vector<Label> &input,
160  const std::vector<Label> &output, Weight weight);
161 
162  // Returns a newly created `LinearFstData` object or nullptr in case
163  // of failure. The caller takes the ownership of the memory. No
164  // other methods shall be called after this --- this is enforced by
165  // putting the builder at error state, even when a
166  // `LinearFstData<>` object is successfully built.
167  //
168  // This may fail if the call to any `FeatureGroupBuilder<>::Dump()`
169  // fails.
171 
172  private:
173  bool error_;
174  CompactSet<Label, kNoLabel> all_output_labels_;
175  std::map<Label, std::set<Label>> word_output_map_, word_feat_map_;
176  std::map<Label, std::set<size_t>> feat_groups_;
177  std::vector<std::unique_ptr<FeatureGroupBuilder<A>>> groups_;
178  size_t max_future_size_;
179  Label max_input_label_;
180  const SymbolTable *isyms_, *fsyms_, *osyms_;
181 
182  LinearFstDataBuilder(const LinearFstDataBuilder &) = delete;
183  LinearFstDataBuilder &operator=(const LinearFstDataBuilder &) = delete;
184 };
185 
186 // Builds a LinearFstData tailored for a LinearClassifierFst. The
187 // major difference between an ordinary LinearFstData that works on
188 // taggers and a LinearFstData that works on classifiers is that
189 // feature groups are divided into sections by the prediction class
190 // label. For a prediction label `pred` and a logical group id
191 // `group`, the actual group id is `group * num_classes + pred -
192 // 1`.
193 //
194 // This layout saves us from recording output labels in each single
195 // FeatureGroup. Because there is no need for any delaying, stripping
196 // the output allows features with different shapes but using the same
197 // set of feature label mapping to reside in a single FeatureGroup.
198 template <class A>
200  public:
201  typedef typename A::Label Label;
202  typedef typename A::Weight Weight;
203 
204  // Constructs a builder for a `num_classes`-class classifier,
205  // optinally with associated symbol tables for diagnostic
206  // output. The output labels (i.e. prediction) must be in the range
207  // of [1, num_classes].
208  explicit LinearClassifierFstDataBuilder(size_t num_classes,
209  const SymbolTable *isyms = nullptr,
210  const SymbolTable *fsyms = nullptr,
211  const SymbolTable *osyms = nullptr)
212  : error_(false),
213  num_classes_(num_classes),
214  num_groups_(0),
215  builder_(isyms, fsyms, osyms) {}
216 
217  // Tests whether the builder has encountered any error. Similar to
218  // LinearFstDataBuilder<>::Error().
219  bool Error() const { return error_; }
220 
221  // Same as LinearFstDataBuilder<>::AddWord().
222  bool AddWord(Label word, const std::vector<Label> &features);
223 
224  // Adds a logical feature group. Similar to
225  // LinearFstDataBuilder<>::AddGroup(), with the exception that the
226  // returned group id is the logical group id. Also there is no need
227  // for "future" in a classifier.
228  int AddGroup();
229 
230  // Adds an instance of feature weight to the specified logical
231  // feature group. Instead of a vector of output, only a single
232  // prediction is needed as the output.
233  //
234  // This may fail if `pred` is not in the range of [1, num_classes_].
235  bool AddWeight(size_t group, const std::vector<Label> &input, Label pred,
236  Weight weight);
237 
238  // Returns a newly created `LinearFstData` object or nullptr in case of
239  // failure.
241 
242  private:
243  std::vector<Label> empty_;
244  bool error_;
245  size_t num_classes_, num_groups_;
246  LinearFstDataBuilder<A> builder_;
247 };
248 
249 // Builds a single feature group. Usually used in
250 // `LinearFstDataBuilder::AddWeight()`. See that method for the
251 // constraints on grouping features.
252 template <class A>
253 class FeatureGroupBuilder {
254  public:
255  typedef typename A::Label Label;
256  typedef typename A::Weight Weight;
257 
258  // Constructs a builder with the given future size. All features
259  // added to the group will have look-ahead windows of this size.
260  FeatureGroupBuilder(size_t future_size, const SymbolTable *fsyms,
261  const SymbolTable *osyms)
262  : error_(false), future_size_(future_size), fsyms_(fsyms), osyms_(osyms) {
263  // This edge is special; see doc of class `FeatureGroup` on the
264  // details.
265  start_ = trie_.Insert(trie_.Root(), InputOutputLabel(kNoLabel, kNoLabel));
266  }
267 
268  // Tests whether the builder has encountered any error. No operation
269  // is valid if the builder is already at error state. All other
270  // public methods should check this before any actual operations.
271  bool Error() const { return error_; }
272 
273  // Adds a feature weight with the given context. Returns true iff
274  // the weight is added. A weight is not added if it has ill-formed
275  // context involving start-, end-of-sentence marks.
276  //
277  // Note: `input` is the sequence of input
278  // features, instead of input labels themselves. `input` must be at
279  // least as long as `future_size`; `output` may be empty, but
280  // usually should be non-empty because an empty output context is
281  // useless in discriminative modelling. All labels in both `input`
282  // and `output` must be > 0 (this is checked in
283  // `LinearFstDataBuilder::AddWeight()`). See
284  // LinearFstDataBuilder<>::AddWeight for more details.
285  //
286  // This may fail if the input is smaller than the look-ahead window.
287  bool AddWeight(const std::vector<Label> &input,
288  const std::vector<Label> &output, Weight weight);
289 
290  // Creates an actual FeatureGroup<> object. Connects back-off links;
291  // pre-accumulates weights from back-off features. Returns nullptr if
292  // there is any violation in unique immediate back-off
293  // constraints.
294  //
295  // Regardless of whether the call succeeds or not, the error flag is
296  // always set before this returns, to prevent repeated dumping.
297  //
298  // TODO(wuke): check overlapping top-level contexts (see
299  // `DumpOverlappingContext()` in tests).
300  FeatureGroup<A> *Dump(size_t max_future_size);
301 
302  private:
305  typedef typename FeatureGroup<A>::WeightBackLink WeightBackLink;
306  // Nested trie topology uses more memory but we can traverse a
307  // node's children easily, which is required in `BuildBackLinks()`.
310 
311  // Finds the first node with an arc with `label` following the
312  // back-off chain of `parent`. Returns the node index or
313  // `kNoTrieNodeId` when not found. The number of hops is stored in
314  // `hop` when it is not `nullptr`.
315  //
316  // This does not fail.
317  int FindFirstMatch(InputOutputLabel label, int parent, int *hop) const;
318 
319  // Links each node to its immediate back-off. root is linked to -1.
320  //
321  // This may fail when the unique immediate back-off constraint is
322  // violated.
323  void BuildBackLinks();
324 
325  // Traces back on the back-chain for each node to multiply the
326  // weights from back-offs to the node itself.
327  //
328  // This does not fail.
329  void PreAccumulateWeights();
330 
331  // Reconstruct the path from trie root to given node for logging.
332  bool TrieDfs(const Topology &topology, int cur, int target,
333  std::vector<InputOutputLabel> *path) const;
334  std::string TriePath(int node, const Topology &topology) const;
335 
336  bool error_;
337  size_t future_size_;
338  Trie trie_;
339  int start_;
340  const SymbolTable *fsyms_, *osyms_;
341 
342  FeatureGroupBuilder(const FeatureGroupBuilder &) = delete;
343  FeatureGroupBuilder &operator=(const FeatureGroupBuilder &) = delete;
344 };
345 
346 //
347 // Implementation of methods in `LinearFstDataBuilder`
348 //
349 template <class A>
351  const std::vector<Label> &features) {
352  if (error_) {
353  FSTERROR() << "Calling LinearFstDataBuilder<>::AddWord() at error state";
354  return false;
355  }
358  LOG(WARNING) << "Ignored: adding boundary label: "
359  << TranslateLabel(word, isyms_)
360  << "(start-of-sentence=" << LinearFstData<A>::kStartOfSentence
361  << ", end-of-sentence=" << LinearFstData<A>::kEndOfSentence
362  << ")";
363  return false;
364  }
365  if (word <= 0) {
366  error_ = true;
367  FSTERROR() << "Word label must be > 0; got " << word;
368  return false;
369  }
370  if (word > max_input_label_) max_input_label_ = word;
371  // Make sure the word hasn't been added before
372  if (word_feat_map_.find(word) != word_feat_map_.end()) {
373  error_ = true;
374  FSTERROR() << "Input word " << TranslateLabel(word, isyms_)
375  << " is added twice";
376  return false;
377  }
378  // Store features
379  std::set<Label> *feats = &word_feat_map_[word];
380  for (size_t i = 0; i < features.size(); ++i) {
381  Label feat = features[i];
382  if (feat <= 0) {
383  error_ = true;
384  FSTERROR() << "Feature label must be > 0; got " << feat;
385  return false;
386  }
387  feats->insert(feat);
388  }
389  return true;
390 }
391 
392 template <class A>
394  Label word, const std::vector<Label> &word_features,
395  const std::vector<Label> &possible_output) {
396  if (error_) {
397  FSTERROR() << "Calling LinearFstDataBuilder<>::AddWord() at error state";
398  return false;
399  }
400  if (!AddWord(word, word_features)) return false;
401  // Store possible output constraint
402  if (possible_output.empty()) {
403  error_ = true;
404  FSTERROR() << "Empty possible output constraint; "
405  << "use the two-parameter version if no constraint is need.";
406  return false;
407  }
408  std::set<Label> *outputs = &word_output_map_[word];
409  for (size_t i = 0; i < possible_output.size(); ++i) {
410  Label output = possible_output[i];
411  if (output == LinearFstData<A>::kStartOfSentence ||
413  LOG(WARNING) << "Ignored: word = " << TranslateLabel(word, isyms_)
414  << ": adding boundary label as possible output: " << output
415  << "(start-of-sentence="
417  << ", end-of-sentence=" << LinearFstData<A>::kEndOfSentence
418  << ")";
419  continue;
420  }
421  if (output <= 0) {
422  error_ = true;
423  FSTERROR() << "Output label must be > 0; got " << output;
424  return false;
425  }
426  outputs->insert(output);
427  all_output_labels_.Insert(output);
428  }
429  return true;
430 }
431 
432 template <class A>
433 inline int LinearFstDataBuilder<A>::AddGroup(size_t future_size) {
434  if (error_) {
435  FSTERROR() << "Calling LinearFstDataBuilder<>::AddGroup() at error state";
436  return -1;
437  }
438  size_t ret = groups_.size();
439  groups_.emplace_back(new FeatureGroupBuilder<A>(future_size, fsyms_, osyms_));
440  if (future_size > max_future_size_) max_future_size_ = future_size;
441  return ret;
442 }
443 
444 template <class A>
446  const std::vector<Label> &input,
447  const std::vector<Label> &output,
448  Weight weight) {
449  if (error_) {
450  FSTERROR() << "Calling LinearFstDataBuilder<>::AddWeight() at error state";
451  return false;
452  }
453  // Check well-formedness of boundary marks on the input.
454  {
455  bool start_in_middle = false, end_in_middle = false;
456  for (int i = 1; i < input.size(); ++i) {
457  if (input[i] == LinearFstData<A>::kStartOfSentence &&
458  input[i - 1] != LinearFstData<A>::kStartOfSentence)
459  start_in_middle = true;
460  if (input[i - 1] == LinearFstData<A>::kEndOfSentence &&
462  end_in_middle = true;
463  }
464  if (start_in_middle) {
465  LOG(WARNING) << "Ignored: start-of-sentence in the middle of the input!";
466  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
467  LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_);
468  return false;
469  }
470  if (end_in_middle) {
471  LOG(WARNING) << "Ignored: end-of-sentence in the middle of the input!";
472  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
473  LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_);
474  return false;
475  }
476  }
477  // Check well-formedness of boundary marks on the output.
478  {
479  bool non_first_start = false, non_last_end = false;
480  for (int i = 1; i < output.size(); ++i) {
481  if (output[i] == LinearFstData<A>::kStartOfSentence)
482  non_first_start = true;
483  if (output[i - 1] == LinearFstData<A>::kEndOfSentence)
484  non_last_end = true;
485  }
486  if (non_first_start) {
487  LOG(WARNING) << "Ignored: start-of-sentence not appearing "
488  << "as the first label in the output!";
489  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
490  LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_);
491  return false;
492  }
493  if (non_last_end) {
494  LOG(WARNING) << "Ignored: end-of-sentence not appearing "
495  << "as the last label in the output!";
496  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
497  LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_);
498  return false;
499  }
500  }
501 
502  for (size_t i = 0; i < input.size(); ++i) {
503  Label feat = input[i];
505  feat != LinearFstData<A>::kEndOfSentence && feat <= 0) {
506  error_ = true;
507  FSTERROR() << "Feature label must be > 0; got " << feat;
508  return false;
509  }
510  feat_groups_[feat].insert(group);
511  }
512  for (size_t i = 0; i < output.size(); ++i) {
513  Label label = output[i];
514  if (label != LinearFstData<A>::kStartOfSentence &&
515  label != LinearFstData<A>::kEndOfSentence && label <= 0) {
516  error_ = true;
517  FSTERROR() << "Output label must be > 0; got " << label;
518  return false;
519  }
520  if (label != LinearFstData<A>::kStartOfSentence &&
522  all_output_labels_.Insert(label);
523  }
524 
525  // Everything looks good at this point (more checks on the way in
526  // the feature group). Add this feature weight.
527  bool added = groups_[group]->AddWeight(input, output, weight);
528  if (groups_[group]->Error()) {
529  error_ = true;
530  FSTERROR() << "FeatureGroupBuilder<>::AddWeight() failed";
531  return false;
532  }
533  return added;
534 }
535 
536 template <class A>
538  if (error_) {
539  FSTERROR() << "Calling LinearFstDataBuilder<>::Dump() at error state";
540  return nullptr;
541  }
542 
543  std::unique_ptr<LinearFstData<A>> data(new LinearFstData<A>());
544  data->max_future_size_ = max_future_size_;
545  data->max_input_label_ = max_input_label_;
546 
547  // Feature groups; free builders after it's dumped.
548  data->groups_.resize(groups_.size());
549  for (int group = 0; group != groups_.size(); ++group) {
550  FeatureGroup<A> *new_group = groups_[group]->Dump(max_future_size_);
551  if (new_group == nullptr) {
552  error_ = true;
553  FSTERROR() << "Error in dumping group " << group;
554  return nullptr;
555  }
556  data->groups_[group].reset(new_group);
557  groups_[group].reset();
558  VLOG(1) << "Group " << group << ": " << new_group->Stats();
559  }
560 
561  // Per-group feature mapping
562  data->group_feat_map_.Init(data->NumGroups(), max_input_label_ + 1);
563  for (Label word = 1; word <= max_input_label_; ++word) {
564  typename std::map<Label, std::set<Label>>::const_iterator it =
565  word_feat_map_.find(word);
566  if (it == word_feat_map_.end()) continue;
567  for (typename std::set<Label>::const_iterator oit = it->second.begin();
568  oit != it->second.end(); ++oit) {
569  Label feat = *oit;
570  typename std::map<Label, std::set<size_t>>::const_iterator jt =
571  feat_groups_.find(feat);
572  if (jt == feat_groups_.end()) continue;
573  for (std::set<size_t>::const_iterator git = jt->second.begin();
574  git != jt->second.end(); ++git) {
575  size_t group_id = *git;
576  if (!data->group_feat_map_.Set(group_id, word, feat)) {
577  error_ = true;
578  return nullptr;
579  }
580  }
581  }
582  }
583 
584  // Possible output labels
585  {
586  std::vector<typename LinearFstData<A>::InputAttribute> *input_attribs =
587  &data->input_attribs_;
588  std::vector<Label> *output_pool = &data->output_pool_;
589  input_attribs->resize(max_input_label_ + 1);
590  for (Label word = 0; word <= max_input_label_; ++word) {
591  typename std::map<Label, std::set<Label>>::const_iterator it =
592  word_output_map_.find(word);
593  if (it == word_output_map_.end()) {
594  (*input_attribs)[word].output_begin = 0;
595  (*input_attribs)[word].output_length = 0;
596  } else {
597  (*input_attribs)[word].output_begin = output_pool->size();
598  (*input_attribs)[word].output_length = it->second.size();
599  for (typename std::set<Label>::const_iterator oit = it->second.begin();
600  oit != it->second.end(); ++oit) {
601  Label olabel = *oit;
602  output_pool->push_back(olabel);
603  }
604  }
605  }
606  }
607 
609  all_output_labels_.Begin();
610  it != all_output_labels_.End(); ++it)
611  data->output_set_.push_back(*it);
612 
613  error_ = true; // prevent future calls on this object
614  return data.release();
615 }
616 
617 //
618 // Implementation of methods in `LinearClassifierFstDataBuilder`
619 //
620 template <class A>
622  Label word, const std::vector<Label> &features) {
623  if (error_) {
624  FSTERROR() << "Calling LinearClassifierFstDataBuilder<>::AddWord() at "
625  "error state";
626  return false;
627  }
628  bool added = builder_.AddWord(word, features);
629  if (builder_.Error()) error_ = true;
630  return added;
631 }
632 
633 template <class A>
635  if (error_) {
636  FSTERROR() << "Calling LinearClassifierFstDataBuilder<>::AddGroup() at "
637  "error state";
638  return -1;
639  }
640  for (int i = 0; i < num_classes_; ++i) builder_.AddGroup(0);
641  if (builder_.Error()) {
642  error_ = true;
643  return -1;
644  }
645  return num_groups_++;
646 }
647 
648 template <class A>
650  size_t group, const std::vector<Label> &input, Label pred, Weight weight) {
651  if (error_) {
652  FSTERROR() << "Calling LinearClassifierFstDataBuilder<>::AddWeight() at "
653  "error state";
654  return false;
655  }
656  if (pred <= 0 || pred > num_classes_) {
657  FSTERROR() << "Out-of-range prediction label: " << pred
658  << " (num classes = " << num_classes_ << ")";
659  error_ = true;
660  return false;
661  }
662  size_t real_group = group * num_classes_ + pred - 1;
663  bool added = builder_.AddWeight(real_group, input, empty_, weight);
664  if (builder_.Error()) error_ = true;
665  return added;
666 }
667 
668 template <class A>
670  if (error_) {
671  FSTERROR()
672  << "Calling LinearClassifierFstDataBuilder<>::Dump() at error state";
673  return nullptr;
674  }
675  LinearFstData<A> *data = builder_.Dump();
676  error_ = true;
677  return data;
678 }
679 
680 //
681 // Implementation of methods in `FeatureGroupBuilder`
682 //
683 template <class A>
684 bool FeatureGroupBuilder<A>::AddWeight(const std::vector<Label> &input,
685  const std::vector<Label> &output,
686  Weight weight) {
687  if (error_) {
688  FSTERROR() << "Calling FeatureGroupBuilder<>::AddWeight() at error state";
689  return false;
690  }
691 
692  // `LinearFstDataBuilder<>::AddWeight()` ensures prefix/suffix
693  // properties for us. We can directly count.
694  int num_input_start = 0;
695  while (num_input_start < input.size() &&
696  input[num_input_start] == LinearFstData<A>::kStartOfSentence)
697  ++num_input_start;
698  int num_output_start = 0;
699  while (num_output_start < output.size() &&
700  output[num_output_start] == LinearFstData<A>::kStartOfSentence)
701  ++num_output_start;
702  int num_input_end = 0;
703  for (int i = input.size() - 1;
704  i >= 0 && input[i] == LinearFstData<A>::kEndOfSentence; --i)
705  ++num_input_end;
706  int num_output_end = 0;
707  for (int i = output.size() - 1;
708  i >= 0 && output[i] == LinearFstData<A>::kEndOfSentence; --i)
709  ++num_output_end;
710 
711  DCHECK_LE(num_output_end, 1);
712 
713  if (input.size() - num_input_start < future_size_) {
714  LOG(WARNING) << "Ignored: start-of-sentence in the future!";
715  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
716  LOG(WARNING) << "\tOutput: " << JoinLabels(output, fsyms_);
717  return false;
718  }
719  if (num_input_start > 0 && input.size() - future_size_ - num_input_start <
720  output.size() - num_output_start) {
721  LOG(WARNING) << "Ignored: matching start-of-sentence with actual output!";
722  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
723  LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_);
724  return false;
725  }
726  if (num_output_start > 0 && input.size() - future_size_ - num_input_start >
727  output.size() - num_output_start) {
728  LOG(WARNING) << "Ignored: matching start-of-sentence with actual input!";
729  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
730  LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_);
731  return false;
732  }
733  // The following two require `num_output_end` <= 1.
734  if (num_input_end > future_size_ && num_input_end - future_size_ != 1) {
735  LOG(WARNING) << "Ignored: matching end-of-sentence with actual output!";
736  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
737  LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_);
738  return false;
739  }
740  if (num_output_end > 0 &&
741  ((input.size() == future_size_ && future_size_ != num_input_end) ||
742  (input.size() > future_size_ &&
743  num_input_end != future_size_ + num_output_end))) {
744  LOG(WARNING) << "Ignored: matching end-of-sentence with actual input!";
745  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
746  LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_);
747  return false;
748  }
749  // Check if the context has no other labels than boundary marks
750  // (such features are useless).
751  if (num_input_start + num_input_end == input.size() &&
752  num_output_start + num_output_end == output.size()) {
753  LOG(WARNING)
754  << "Ignored: feature context consisting of only boundary marks!";
755  LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_);
756  LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_);
757  return false;
758  }
759 
760  // Start point for insertion in the trie. Insert at `start_` iff the
761  // beginning of the context is non-consumed start-of-sentence.
762  int cur = (num_input_start == 0 && num_output_start <= future_size_)
763  ? trie_.Root()
764  : start_;
765  // Skip all input start-of-sentence marks
766  size_t ipos = num_input_start;
767  // Skip to keep at most `future_size_` start-of-sentence marks
768  size_t opos =
769  num_output_start <= future_size_ ? 0 : num_output_start - future_size_;
770  // Skip `num_output_end` end-of-sentence marks on both input and output
771  size_t iend = !input.empty() ? input.size() - num_output_end : 0,
772  oend = output.size() - num_output_end;
773  // Further, when output is empty, keep at most `future_size_`
774  // end-of-sentence marks on input.
775  if (output.empty() && num_input_end > future_size_)
776  iend = input.size() - num_input_end + future_size_;
777 
778  // Actual feature context is (input[ipos:iend], output[opos:oend]).
779 
780  // Pad `kNoLabel` as don't cares on the shorter of actual `input`
781  // and `output`.
782  const size_t effective_input_size = iend - ipos,
783  effective_output_size = oend - opos;
784  if (effective_input_size > effective_output_size) {
785  for (size_t pad = effective_input_size - effective_output_size; pad != 0;
786  --pad, ++ipos)
787  cur = trie_.Insert(cur, InputOutputLabel(input[ipos], kNoLabel));
788  } else if (effective_input_size < effective_output_size) {
789  for (size_t pad = effective_output_size - effective_input_size; pad != 0;
790  --pad, ++opos)
791  cur = trie_.Insert(cur, InputOutputLabel(kNoLabel, output[opos]));
792  }
793  CHECK_EQ(iend - ipos, oend - opos);
794  for (; ipos != iend; ++ipos, ++opos)
795  cur = trie_.Insert(cur, InputOutputLabel(input[ipos], output[opos]));
796  // We only need to attach final weight when there is an output
797  // end-of-sentence. When there is only end-of-sentence on the input,
798  // they are all consumed as the end-of-sentence paddings from
799  // `LinearFstImpl<>::ShiftBuffer()`. `LinearFstImpl<>::Expand()`
800  // and `LinearFstImpl<>::MatchInput()` ensures no other
801  // transition takes place after consuming the padding.
802  if (num_output_end > 0 || (output.empty() && num_input_end > future_size_))
803  trie_[cur].final_weight = Times(trie_[cur].final_weight, weight);
804  else
805  trie_[cur].weight = Times(trie_[cur].weight, weight);
806 
807  return true;
808 }
809 
810 template <class A>
812  if (error_) {
813  FSTERROR() << "Calling FeatureGroupBuilder<>::PreAccumulateWeights() "
814  << "at error state";
815  return nullptr;
816  }
817 
818  if (max_future_size < future_size_) {
819  error_ = true;
820  FSTERROR() << "max_future_size (= " << max_future_size
821  << ") is smaller the builder's future_size (= " << future_size_
822  << ")";
823  return nullptr;
824  }
825 
826  BuildBackLinks();
827  if (error_) return nullptr;
828  PreAccumulateWeights(); // does not fail
829 
830  FeatureGroup<A> *ret =
831  new FeatureGroup<A>(max_future_size - future_size_, start_);
832 
833  // Walk around the trie to compute next states
834  ret->next_state_.resize(trie_.NumNodes());
835  const Topology &topology = trie_.TrieTopology();
836  for (int i = 0; i < topology.NumNodes(); ++i) {
837  int next = i;
838  while (next != topology.Root() && topology.ChildrenOf(next).empty() &&
839  trie_[next].final_weight ==
840  trie_[trie_[next].back_link].final_weight)
841  next = trie_[next].back_link;
842  ret->next_state_[i] = next;
843  }
844 
845  // Copy the trie
846  typename FeatureGroup<A>::Trie store_trie(trie_);
847  ret->trie_.swap(store_trie);
848 
849  // Put the builder at error state to prevent repeated call of `Dump()`.
850  error_ = true;
851  return ret;
852 }
853 
854 template <class A>
856  int *hop) const {
857  int hop_count = 0;
858  int ret = kNoTrieNodeId;
859  for (; parent >= 0; parent = trie_[parent].back_link, ++hop_count) {
860  int next = trie_.Find(parent, label);
861  if (next != kNoTrieNodeId) {
862  ret = next;
863  break;
864  }
865  }
866  if (hop != nullptr) *hop = hop_count;
867  return ret;
868 }
869 
870 template <class A>
872  // Breadth first search from the root. In the case where we only
873  // have the input label, the immedate back-off is simply the longest
874  // suffix of the current node that is also in the trie. For a node
875  // reached from its parent with label L, we can simply walk through
876  // the parent's back-off chain to find the first state with an arc
877  // of the same label L. The uniqueness is always
878  // guanranteed. However, in the case with both input and output
879  // labels, it is possible to back off by removing first labels from
880  // either side, which in general causes non-uniqueness.
881 
882  const Topology &topology = trie_.TrieTopology();
883  std::queue<int> q; // all enqueued or visited nodes have known links
884 
885  // Note: nodes have back link initialized to -1 in their
886  // constructor.
887  q.push(trie_.Root());
888  while (!error_ && !q.empty()) {
889  int parent = q.front();
890  q.pop();
891  // Find links for every child
892  const typename Topology::NextMap &children = topology.ChildrenOf(parent);
893  for (typename Topology::NextMap::const_iterator eit = children.begin();
894  eit != children.end(); ++eit) {
895  const std::pair<InputOutputLabel, int> &edge = *eit;
896  InputOutputLabel label = edge.first;
897  int child = edge.second;
898  if (label.input == kNoLabel || label.output == kNoLabel) {
899  // Label pairs from root to here all have one and only one
900  // `kNoLabel` on the same side; equivalent to the
901  // "longest-suffix" case.
902  trie_[child].back_link =
903  FindFirstMatch(label, trie_[parent].back_link, nullptr);
904  } else {
905  // Neither side is `kNoLabel` at this point, there are
906  // three possible ways to back-off: if the parent backs
907  // off to some context with only one side non-empty, the
908  // empty side may remain empty; or else an exact match of
909  // both sides is needed. Try to find all three possible
910  // backs and look for the closest one (in terms of hops
911  // along the parent's back-off chain).
912  int only_input_hop, only_output_hop, full_hop;
913  int only_input_link =
914  FindFirstMatch(InputOutputLabel(label.input, kNoLabel), parent,
915  &only_input_hop),
916  only_output_link =
917  FindFirstMatch(InputOutputLabel(kNoLabel, label.output), parent,
918  &only_output_hop),
919  full_link =
920  FindFirstMatch(label, trie_[parent].back_link, &full_hop);
921  if (only_input_link != -1 && only_output_link != -1) {
922  error_ = true;
923  FSTERROR() << "Branching back-off chain:\n"
924  << "\tnode " << child << ": " << TriePath(child, topology)
925  << "\n"
926  << "\tcan back-off to node " << only_input_link << ": "
927  << TriePath(only_input_link, topology) << "\n"
928  << "\tcan back-off to node " << only_output_link << ": "
929  << TriePath(only_output_link, topology);
930  return;
931  } else if (full_link != -1) {
932  ++full_hop;
933  if (full_hop <= only_input_hop && full_hop <= only_output_hop) {
934  trie_[child].back_link = full_link;
935  } else {
936  error_ = true;
937  int problem_link = only_input_link != kNoTrieNodeId
938  ? only_input_link
939  : only_output_link;
940  CHECK_NE(problem_link, kNoTrieNodeId);
941  FSTERROR() << "Branching back-off chain:\n"
942  << "\tnode " << child << ": "
943  << TriePath(child, topology) << "\n"
944  << "\tcan back-off to node " << full_link << ": "
945  << TriePath(full_link, topology) << "\n"
946  << "tcan back-off to node " << problem_link << ": "
947  << TriePath(problem_link, topology);
948  return;
949  }
950  } else {
951  trie_[child].back_link =
952  only_input_link != -1 ? only_input_link : only_output_link;
953  }
954  }
955  if (error_) break;
956  // Point to empty context (root) when no back-off can be found
957  if (trie_[child].back_link == -1) trie_[child].back_link = 0;
958  q.push(child);
959  }
960  }
961 }
962 
963 template <class A>
965  std::vector<bool> visited(trie_.NumNodes(), false);
966  visited[trie_.Root()] = true;
967 
968  for (size_t i = 0; i != trie_.NumNodes(); ++i) {
969  std::stack<int> back_offs;
970  for (int j = i; !visited[j]; j = trie_[j].back_link) back_offs.push(j);
971  while (!back_offs.empty()) {
972  int j = back_offs.top();
973  back_offs.pop();
974  WeightBackLink &node = trie_[j];
975  node.weight = Times(node.weight, trie_[node.back_link].weight);
976  node.final_weight =
977  Times(node.final_weight, trie_[node.back_link].final_weight);
978  visited[j] = true;
979  }
980  }
981 }
982 
983 template <class A>
985  const Topology &topology, int cur, int target,
986  std::vector<InputOutputLabel> *path) const {
987  if (cur == target) return true;
988  const typename Topology::NextMap &children = topology.ChildrenOf(cur);
989  for (typename Topology::NextMap::const_iterator eit = children.begin();
990  eit != children.end(); ++eit) {
991  const std::pair<InputOutputLabel, int> &edge = *eit;
992  path->push_back(edge.first);
993  if (TrieDfs(topology, edge.second, target, path)) return true;
994  path->pop_back();
995  }
996  return false;
997 }
998 
999 template <class A>
1000 std::string FeatureGroupBuilder<A>::TriePath(int node,
1001  const Topology &topology) const {
1002  std::vector<InputOutputLabel> labels;
1003  TrieDfs(topology, topology.Root(), node, &labels);
1004  bool first = true;
1005  std::ostringstream strm;
1006  for (typename std::vector<InputOutputLabel>::const_iterator it =
1007  labels.begin();
1008  it != labels.end(); ++it) {
1009  InputOutputLabel i = *it;
1010  if (first)
1011  first = false;
1012  else
1013  strm << ", ";
1014  strm << "(" << TranslateLabel(i.input, fsyms_) << ", "
1015  << TranslateLabel(i.output, osyms_) << ")";
1016  }
1017  return strm.str();
1018 }
1019 
1020 inline std::string TranslateLabel(int64_t label, const SymbolTable *syms) {
1021  std::string ret;
1022  if (syms != nullptr) ret += syms->Find(label);
1023  if (ret.empty()) {
1024  std::ostringstream strm;
1025  strm << '<' << label << '>';
1026  ret = strm.str();
1027  }
1028  return ret;
1029 }
1030 
1031 template <class Iterator>
1032 std::string JoinLabels(Iterator begin, Iterator end, const SymbolTable *syms) {
1033  if (begin == end) return "<empty>";
1034  std::ostringstream strm;
1035  bool first = true;
1036  for (Iterator it = begin; it != end; ++it) {
1037  if (first)
1038  first = false;
1039  else
1040  strm << '|';
1041  strm << TranslateLabel(*it, syms);
1042  }
1043  return strm.str();
1044 }
1045 
1046 template <class Label>
1047 std::string JoinLabels(const std::vector<Label> &labels,
1048  const SymbolTable *syms) {
1049  return JoinLabels(labels.begin(), labels.end(), syms);
1050 }
1051 
1052 template <class A>
1053 typename A::Label GuessStartOrEnd(std::vector<typename A::Label> *sequence,
1054  typename A::Label boundary) {
1055  const size_t length = sequence->size();
1056  std::vector<bool> non_boundary_on_left(length, false),
1057  non_boundary_on_right(length, false);
1058  for (size_t i = 1; i < length; ++i) {
1059  non_boundary_on_left[i] =
1060  non_boundary_on_left[i - 1] || (*sequence)[i - 1] != boundary;
1061  non_boundary_on_right[length - 1 - i] = non_boundary_on_right[length - i] ||
1062  (*sequence)[length - i] != boundary;
1063  }
1064  int unresolved = 0;
1065  for (size_t i = 0; i < length; ++i) {
1066  if ((*sequence)[i] != boundary) continue;
1067  const bool left = non_boundary_on_left[i], right = non_boundary_on_right[i];
1068  if (left && right) {
1069  // Boundary in the middle
1070  LOG(WARNING) << "Boundary label in the middle of the sequence! position: "
1071  << i << "; boundary: " << boundary
1072  << "; sequence: " << JoinLabels(*sequence, nullptr);
1073  LOG(WARNING)
1074  << "This is an invalid sequence anyway so I will set it to start.";
1075  (*sequence)[i] = LinearFstData<A>::kStartOfSentence;
1076  } else if (left && !right) {
1077  // Can only be end
1078  (*sequence)[i] = LinearFstData<A>::kEndOfSentence;
1079  } else if (!left && right) {
1080  // Can only be start
1081  (*sequence)[i] = LinearFstData<A>::kStartOfSentence;
1082  } else {
1083  // !left && !right; can't really tell
1084  ++unresolved;
1085  }
1086  }
1087  return unresolved;
1088 }
1089 
1090 } // namespace fst
1091 
1092 #endif // FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_BUILDER_H_
std::string Stats() const
FeatureGroup< A > * Dump(size_t max_future_size)
constexpr int kNoLabel
Definition: fst.h:195
FeatureGroupBuilder(size_t future_size, const SymbolTable *fsyms, const SymbolTable *osyms)
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:64
bool AddWeight(size_t group, const std::vector< Label > &input, const std::vector< Label > &output, Weight weight)
#define LOG(type)
Definition: log.h:53
bool AddWord(Label word, const std::vector< Label > &features)
bool AddWeight(const std::vector< Label > &input, const std::vector< Label > &output, Weight weight)
int AddGroup(size_t future_size)
#define FSTERROR()
Definition: util.h:56
LinearClassifierFstDataBuilder(size_t num_classes, const SymbolTable *isyms=nullptr, const SymbolTable *fsyms=nullptr, const SymbolTable *osyms=nullptr)
bool AddWord(Label word, const std::vector< Label > &features)
std::string TranslateLabel(int64_t label, const SymbolTable *syms)
LinearFstDataBuilder(const SymbolTable *isyms=nullptr, const SymbolTable *fsyms=nullptr, const SymbolTable *osyms=nullptr)
#define VLOG(level)
Definition: log.h:54
#define CHECK_NE(x, y)
Definition: log.h:71
constexpr int kNoTrieNodeId
Definition: trie.h:35
std::string JoinLabels(Iterator begin, Iterator end, const SymbolTable *syms)
int Root() const
Definition: trie.h:149
size_t NumNodes() const
Definition: trie.h:150
bool AddWeight(size_t group, const std::vector< Label > &input, Label pred, Weight weight)
#define DCHECK_LE(x, y)
Definition: log.h:78
#define CHECK_EQ(x, y)
Definition: log.h:66
const NextMap & ChildrenOf(int parent) const
Definition: trie.h:153
std::unordered_map< L, int, H > NextMap
Definition: trie.h:83
const_iterator Begin() const
Definition: util.h:496
std::string Find(int64_t key) const
Definition: symbol-table.h:450
A::Label GuessStartOrEnd(std::vector< typename A::Label > *sequence, typename A::Label boundary)