FST  openfst-1.7.1
OpenFst Library
linearscript.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 
4 #ifndef FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
5 #define FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
6 
#include <fstream>
#include <istream>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>

#include <fst/compat.h>
#include <fst/extensions/linear/linear-fst-data-builder.h>
#include <fst/extensions/linear/linear-fst.h>
#include <fst/symbol-table.h>
#include <fst/script/arg-packs.h>
#include <fst/script/script-impl.h>
19 
20 DECLARE_string(delimiter);
21 DECLARE_string(empty_symbol);
22 DECLARE_string(start_symbol);
23 DECLARE_string(end_symbol);
24 DECLARE_bool(classifier);
25 
26 namespace fst {
27 namespace script {
28 typedef std::tuple<const string &, const string &, const string &, char **, int,
29  const string &, const string &, const string &,
30  const string &>
32 
33 bool ValidateDelimiter();
34 bool ValidateEmptySymbol();
35 
36 // Returns the proper label given the symbol. For symbols other than
37 // `FLAGS_start_symbol` or `FLAGS_end_symbol`, looks up the symbol
38 // table to decide the label. Depending on whether
39 // `FLAGS_start_symbol` and `FLAGS_end_symbol` are identical, it
40 // either returns `kNoLabel` for later processing or decides the label
41 // right away.
42 template <class Arc>
43 inline typename Arc::Label LookUp(const string &str, SymbolTable *syms) {
44  if (str == FLAGS_start_symbol)
45  return str == FLAGS_end_symbol ? kNoLabel
47  else if (str == FLAGS_end_symbol)
49  else
50  return syms->AddSymbol(str);
51 }
52 
53 // Splits `str` with `delim` as the delimiter and stores the labels in
54 // `output`.
55 template <class Arc>
56 void SplitAndPush(const string &str, const char delim, SymbolTable *syms,
57  std::vector<typename Arc::Label> *output) {
58  if (str == FLAGS_empty_symbol) return;
59  std::istringstream strm(str);
60  string buf;
61  while (std::getline(strm, buf, delim))
62  output->push_back(LookUp<Arc>(buf, syms));
63 }
64 
// Copies [first, last) into the range starting at `result`, writing
// `new_value` in place of every element that compares equal to `old_value`
// (the same contract as `std::replace_copy`). Returns how many elements were
// substituted.
template <class InputIterator, class OutputIterator, class T>
size_t ReplaceCopy(InputIterator first, InputIterator last,
                   OutputIterator result, const T &old_value,
                   const T &new_value) {
  size_t num_replaced = 0;
  for (; first != last; ++first) {
    if (!(*first == old_value)) {
      *result = *first;
    } else {
      *result = new_value;
      ++num_replaced;
    }
    ++result;
  }
  return num_replaced;
}
83 
84 template <class Arc>
85 bool GetVocabRecord(const string &vocab, std::istream &strm, // NOLINT
86  SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms,
87  typename Arc::Label *word,
88  std::vector<typename Arc::Label> *feature_labels,
89  std::vector<typename Arc::Label> *possible_labels,
90  size_t *num_line);
91 
92 template <class Arc>
93 bool GetModelRecord(const string &model, std::istream &strm, // NOLINT
94  SymbolTable *fsyms, SymbolTable *osyms,
95  std::vector<typename Arc::Label> *input_labels,
96  std::vector<typename Arc::Label> *output_labels,
97  typename Arc::Weight *weight, size_t *num_line);
98 
99 // Reads in vocabulary file. Each line is in the following format
100 //
101 // word <whitespace> features [ <whitespace> possible output ]
102 //
103 // where features and possible output are `FLAGS_delimiter`-delimited lists of
104 // tokens
105 template <class Arc>
106 void AddVocab(const string &vocab, SymbolTable *isyms, SymbolTable *fsyms,
107  SymbolTable *osyms, LinearFstDataBuilder<Arc> *builder) {
108  std::ifstream in(vocab);
109  if (!in) LOG(FATAL) << "Can't open file: " << vocab;
110  size_t num_line = 0, num_added = 0;
111  std::vector<string> fields;
112  std::vector<typename Arc::Label> feature_labels, possible_labels;
113  typename Arc::Label word;
114  while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
115  &feature_labels, &possible_labels, &num_line)) {
116  if (word == kNoLabel) {
117  LOG(WARNING) << "Ignored: boundary word: " << fields[0];
118  continue;
119  }
120  if (possible_labels.empty())
121  num_added += builder->AddWord(word, feature_labels);
122  else
123  num_added += builder->AddWord(word, feature_labels, possible_labels);
124  }
125  VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from "
126  << vocab;
127 }
128 
129 template <class Arc>
130 void AddVocab(const string &vocab, SymbolTable *isyms, SymbolTable *fsyms,
131  SymbolTable *osyms,
133  std::ifstream in(vocab);
134  if (!in) LOG(FATAL) << "Can't open file: " << vocab;
135  size_t num_line = 0, num_added = 0;
136  std::vector<string> fields;
137  std::vector<typename Arc::Label> feature_labels, possible_labels;
138  typename Arc::Label word;
139  while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
140  &feature_labels, &possible_labels, &num_line)) {
141  if (!possible_labels.empty())
142  LOG(FATAL)
143  << "Classifier vocabulary should not have possible output constraint";
144  if (word == kNoLabel) {
145  LOG(WARNING) << "Ignored: boundary word: " << fields[0];
146  continue;
147  }
148  num_added += builder->AddWord(word, feature_labels);
149  }
150  VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from "
151  << vocab;
152 }
153 
154 // Reads in model file. The first line is an integer designating the
155 // size of future window in the input sequences. After this, each line
156 // is in the following format
157 //
158 // input sequence <whitespace> output sequence <whitespace> weight
159 //
160 // input sequence is a `FLAGS_delimiter`-delimited sequence of feature
161 // labels (see `AddVocab()`) . output sequence is a
162 // `FLAGS_delimiter`-delimited sequence of output labels where the
163 // last label is the output of the feature position before the history
164 // boundary.
165 template <class Arc>
166 void AddModel(const string &model, SymbolTable *fsyms, SymbolTable *osyms,
167  LinearFstDataBuilder<Arc> *builder) {
168  std::ifstream in(model);
169  if (!in) LOG(FATAL) << "Can't open file: " << model;
170  string line;
171  std::getline(in, line);
172  if (!in) LOG(FATAL) << "Empty file: " << model;
173  size_t future_size;
174  {
175  std::istringstream strm(line);
176  strm >> future_size;
177  if (!strm) LOG(FATAL) << "Can't read future size: " << model;
178  }
179  size_t num_line = 1, num_added = 0;
180  const int group = builder->AddGroup(future_size);
181  VLOG(1) << "Group " << group << ": from " << model << "; future size is "
182  << future_size << ".";
183  // Add the rest of lines as a single feature group
184  std::vector<string> fields;
185  std::vector<typename Arc::Label> input_labels, output_labels;
186  typename Arc::Weight weight;
187  while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
188  &output_labels, &weight, &num_line)) {
189  if (output_labels.empty())
190  LOG(FATAL) << "Empty output sequence in source " << model << ", line "
191  << num_line;
192 
193  const typename Arc::Label marks[] = {LinearFstData<Arc>::kStartOfSentence,
195 
196  std::vector<typename Arc::Label> copy_input(input_labels.size()),
197  copy_output(output_labels.size());
198  for (int i = 0; i < 2; ++i) {
199  for (int j = 0; j < 2; ++j) {
200  size_t num_input_changes =
201  ReplaceCopy(input_labels.begin(), input_labels.end(),
202  copy_input.begin(), kNoLabel, marks[i]);
203  size_t num_output_changes =
204  ReplaceCopy(output_labels.begin(), output_labels.end(),
205  copy_output.begin(), kNoLabel, marks[j]);
206  if ((num_input_changes > 0 || i == 0) &&
207  (num_output_changes > 0 || j == 0))
208  num_added +=
209  builder->AddWeight(group, copy_input, copy_output, weight);
210  }
211  }
212  }
213  VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in "
214  << num_line << " lines.";
215 }
216 
217 template <class Arc>
218 void AddModel(const string &model, SymbolTable *fsyms, SymbolTable *osyms,
220  std::ifstream in(model);
221  if (!in) LOG(FATAL) << "Can't open file: " << model;
222  string line;
223  std::getline(in, line);
224  if (!in) LOG(FATAL) << "Empty file: " << model;
225  size_t future_size;
226  {
227  std::istringstream strm(line);
228  strm >> future_size;
229  if (!strm) LOG(FATAL) << "Can't read future size: " << model;
230  }
231  if (future_size != 0)
232  LOG(FATAL) << "Classifier model must have future size = 0; got "
233  << future_size << " from " << model;
234  size_t num_line = 1, num_added = 0;
235  const int group = builder->AddGroup();
236  VLOG(1) << "Group " << group << ": from " << model << "; future size is "
237  << future_size << ".";
238  // Add the rest of lines as a single feature group
239  std::vector<string> fields;
240  std::vector<typename Arc::Label> input_labels, output_labels;
241  typename Arc::Weight weight;
242  while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
243  &output_labels, &weight, &num_line)) {
244  if (output_labels.size() != 1)
245  LOG(FATAL) << "Output not a single label in source " << model << ", line "
246  << num_line;
247 
248  const typename Arc::Label marks[] = {LinearFstData<Arc>::kStartOfSentence,
250 
251  typename Arc::Label pred = output_labels[0];
252 
253  std::vector<typename Arc::Label> copy_input(input_labels.size());
254  for (int i = 0; i < 2; ++i) {
255  size_t num_input_changes =
256  ReplaceCopy(input_labels.begin(), input_labels.end(),
257  copy_input.begin(), kNoLabel, marks[i]);
258  if (num_input_changes > 0 || i == 0)
259  num_added += builder->AddWeight(group, copy_input, pred, weight);
260  }
261  }
262  VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in "
263  << num_line << " lines.";
264 }
265 
266 void SplitByWhitespace(const string &str, std::vector<string> *out);
267 int ScanNumClasses(char **models, int models_length);
268 
269 template <class Arc>
271  const string &epsilon_symbol = std::get<0>(*args);
272  const string &unknown_symbol = std::get<1>(*args);
273  const string &vocab = std::get<2>(*args);
274  char **models = std::get<3>(*args);
275  const int models_length = std::get<4>(*args);
276  const string &out = std::get<5>(*args);
277  const string &save_isymbols = std::get<6>(*args);
278  const string &save_fsymbols = std::get<7>(*args);
279  const string &save_osymbols = std::get<8>(*args);
280 
281  SymbolTable isyms, // input (e.g. word tokens)
282  osyms, // output (e.g. tags)
283  fsyms; // feature (e.g. word identity, suffix, etc.)
284  isyms.AddSymbol(epsilon_symbol);
285  osyms.AddSymbol(epsilon_symbol);
286  fsyms.AddSymbol(epsilon_symbol);
287  isyms.AddSymbol(unknown_symbol);
288 
289  VLOG(1) << "start-of-sentence label is "
291  VLOG(1) << "end-of-sentence label is " << LinearFstData<Arc>::kEndOfSentence;
292 
293  if (FLAGS_classifier) {
294  int num_classes = ScanNumClasses(models, models_length);
295  LinearClassifierFstDataBuilder<Arc> builder(num_classes, &isyms, &fsyms,
296  &osyms);
297 
298  AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
299  for (int i = 0; i < models_length; ++i)
300  AddModel(models[i], &fsyms, &osyms, &builder);
301 
302  LinearClassifierFst<Arc> fst(builder.Dump(), num_classes, &isyms, &osyms);
303  fst.Write(out);
304  } else {
305  LinearFstDataBuilder<Arc> builder(&isyms, &fsyms, &osyms);
306 
307  AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
308  for (int i = 0; i < models_length; ++i)
309  AddModel(models[i], &fsyms, &osyms, &builder);
310 
311  LinearTaggerFst<Arc> fst(builder.Dump(), &isyms, &osyms);
312  fst.Write(out);
313  }
314 
315  if (!save_isymbols.empty()) isyms.WriteText(save_isymbols);
316  if (!save_fsymbols.empty()) fsyms.WriteText(save_fsymbols);
317  if (!save_osymbols.empty()) osyms.WriteText(save_osymbols);
318 }
319 
320 void LinearCompile(const string &arc_type, const string &epsilon_symbol,
321  const string &unknown_symbol, const string &vocab,
322  char **models, int models_len, const string &out,
323  const string &save_isymbols, const string &save_fsymbols,
324  const string &save_osymbols);
325 
326 template <class Arc>
327 bool GetVocabRecord(const string &vocab, std::istream &strm, // NOLINT
328  SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms,
329  typename Arc::Label *word,
330  std::vector<typename Arc::Label> *feature_labels,
331  std::vector<typename Arc::Label> *possible_labels,
332  size_t *num_line) {
333  string line;
334  if (!std::getline(strm, line)) return false;
335  ++(*num_line);
336 
337  std::vector<string> fields;
338  SplitByWhitespace(line, &fields);
339  if (fields.size() != 3)
340  LOG(FATAL) << "Wrong number of fields in source " << vocab << ", line "
341  << num_line;
342 
343  feature_labels->clear();
344  possible_labels->clear();
345 
346  *word = LookUp<Arc>(fields[0], isyms);
347 
348  const char delim = FLAGS_delimiter[0];
349  SplitAndPush<Arc>(fields[1], delim, fsyms, feature_labels);
350  SplitAndPush<Arc>(fields[2], delim, osyms, possible_labels);
351 
352  return true;
353 }
354 
355 template <class Arc>
356 bool GetModelRecord(const string &model, std::istream &strm, // NOLINT
357  SymbolTable *fsyms, SymbolTable *osyms,
358  std::vector<typename Arc::Label> *input_labels,
359  std::vector<typename Arc::Label> *output_labels,
360  typename Arc::Weight *weight, size_t *num_line) {
361  string line;
362  if (!std::getline(strm, line)) return false;
363  ++(*num_line);
364 
365  std::vector<string> fields;
366  SplitByWhitespace(line, &fields);
367  if (fields.size() != 3)
368  LOG(FATAL) << "Wrong number of fields in source " << model << ", line "
369  << num_line;
370 
371  input_labels->clear();
372  output_labels->clear();
373 
374  const char delim = FLAGS_delimiter[0];
375  SplitAndPush<Arc>(fields[0], delim, fsyms, input_labels);
376  SplitAndPush<Arc>(fields[1], delim, osyms, output_labels);
377 
378  *weight = StrToWeight<typename Arc::Weight>(fields[2], model, *num_line);
379 
380  GuessStartOrEnd<Arc>(input_labels, kNoLabel);
381  GuessStartOrEnd<Arc>(output_labels, kNoLabel);
382 
383  return true;
384 }
385 } // namespace script
386 } // namespace fst
387 
// Expands to `REGISTER_FST_OPERATION` for `LinearCompileTpl` instantiated on
// `Arc`, with `LinearCompileArgs` as its argument pack. Note the expansion
// already ends in a semicolon.
#define REGISTER_FST_LINEAR_OPERATIONS(Arc) \
  REGISTER_FST_OPERATION(LinearCompileTpl, Arc, LinearCompileArgs);
390 
391 #endif // FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
DECLARE_string(delimiter)
void SplitAndPush(const string &str, const char delim, SymbolTable *syms, std::vector< typename Arc::Label > *output)
Definition: linearscript.h:56
constexpr int kNoLabel
Definition: fst.h:179
void AddModel(const string &model, SymbolTable *fsyms, SymbolTable *osyms, LinearFstDataBuilder< Arc > *builder)
Definition: linearscript.h:166
bool AddWeight(size_t group, const std::vector< Label > &input, const std::vector< Label > &output, Weight weight)
#define LOG(type)
Definition: log.h:48
bool AddWord(Label word, const std::vector< Label > &features)
virtual bool WriteText(std::ostream &strm, const SymbolTableTextOptions &opts=SymbolTableTextOptions()) const
int AddGroup(size_t future_size)
Arc::Label LookUp(const string &str, SymbolTable *syms)
Definition: linearscript.h:43
bool AddWord(Label word, const std::vector< Label > &features)
bool ValidateEmptySymbol()
Definition: linearscript.cc:36
virtual int64 AddSymbol(const string &symbol, int64 key)
Definition: symbol-table.h:268
std::tuple< const string &, const string &, const string &, char **, int, const string &, const string &, const string &, const string & > LinearCompileArgs
Definition: linearscript.h:31
DECLARE_bool(classifier)
#define VLOG(level)
Definition: log.h:49
int ScanNumClasses(char **models, int models_len)
Definition: linearscript.cc:67
bool GetVocabRecord(const string &vocab, std::istream &strm, SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms, typename Arc::Label *word, std::vector< typename Arc::Label > *feature_labels, std::vector< typename Arc::Label > *possible_labels, size_t *num_line)
Definition: linearscript.h:327
bool AddWeight(size_t group, const std::vector< Label > &input, Label pred, Weight weight)
size_t ReplaceCopy(InputIterator first, InputIterator last, OutputIterator result, const T &old_value, const T &new_value)
Definition: linearscript.h:67
bool GetModelRecord(const string &model, std::istream &strm, SymbolTable *fsyms, SymbolTable *osyms, std::vector< typename Arc::Label > *input_labels, std::vector< typename Arc::Label > *output_labels, typename Arc::Weight *weight, size_t *num_line)
Definition: linearscript.h:356
void AddVocab(const string &vocab, SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms, LinearFstDataBuilder< Arc > *builder)
Definition: linearscript.h:106
void LinearCompile(const string &arc_type, const string &epsilon_symbol, const string &unknown_symbol, const string &vocab, char **models, int models_len, const string &out, const string &save_isymbols, const string &save_fsymbols, const string &save_osymbols)
Definition: linearscript.cc:45
bool ValidateDelimiter()
Definition: linearscript.cc:30
void LinearCompileTpl(LinearCompileArgs *args)
Definition: linearscript.h:270
void SplitByWhitespace(const string &str, std::vector< string > *out)
Definition: linearscript.cc:60