FST  openfst-1.8.2
OpenFst Library
linearscript.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 
18 #ifndef FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
19 #define FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
20 
21 #include <istream>
22 #include <sstream>
23 #include <string>
24 #include <vector>
25 
26 #include <fst/compat.h>
29 #include <fstream>
30 #include <fst/symbol-table.h>
31 #include <fst/script/arg-packs.h>
32 #include <fst/script/script-impl.h>
33 
// Command-line flags read by the templates in this header (as
// `FST_FLAGS_<name>`); they are only declared here.
//   delimiter    - its first character separates tokens inside a field.
//   empty_symbol - a field equal to this string contributes no labels.
//   start_symbol / end_symbol - strings recognized as sentence boundaries
//                  when symbols are looked up.
//   classifier   - selects the classifier build path in `LinearCompileTpl()`.
34 DECLARE_string(delimiter);
35 DECLARE_string(empty_symbol);
36 DECLARE_string(start_symbol);
37 DECLARE_string(end_symbol);
38 DECLARE_bool(classifier);
39 
40 namespace fst {
41 namespace script {
42 
// Argument pack unpacked by `LinearCompileTpl()`. Fields, in order:
//   0: epsilon symbol      1: unknown symbol      2: vocabulary file
//   3: model file names    4: number of model files
//   5: output FST file     6: file to save input symbols to
//   7: file to save feature symbols to
//   8: file to save output symbols to
43 using LinearCompileArgs =
44  std::tuple<const std::string &, const std::string &, const std::string &,
45  char **, int, const std::string &, const std::string &,
46  const std::string &, const std::string &>;
47 
// Declared here; the definition is not in this header.
// NOTE(review): presumably checks that `FST_FLAGS_delimiter` is usable (the
// readers below take its first character) - confirm against the definition.
48 bool ValidateDelimiter();
49 
// Declared here; the definition is not in this header.
// NOTE(review): presumably checks `FST_FLAGS_empty_symbol` - confirm against
// the definition.
50 bool ValidateEmptySymbol();
51 
52 // Returns the proper label given the symbol. For symbols other than
53 // `FST_FLAGS_start_symbol` or `FST_FLAGS_end_symbol`, looks up the symbol
54 // table to decide the label. Depending on whether
55 // `FST_FLAGS_start_symbol` and `FST_FLAGS_end_symbol` are identical, it
56 // either returns `kNoLabel` for later processing or decides the label
57 // right away.
58 template <class Arc>
59 inline typename Arc::Label LookUp(const std::string &str, SymbolTable *syms) {
60  if (str == FST_FLAGS_start_symbol) {
61  return str == FST_FLAGS_end_symbol
62  ? kNoLabel
64  } else if (str == FST_FLAGS_end_symbol) {
66  } else {
67  return syms->AddSymbol(str);
68  }
69 }
70 
71 // Splits `str` with `delim` as the delimiter and stores the labels in
72 // `output`.
73 template <class Arc>
74 void SplitAndPush(const std::string &str, const char delim, SymbolTable *syms,
75  std::vector<typename Arc::Label> *output) {
76  if (str == FST_FLAGS_empty_symbol) return;
77  std::istringstream strm(str);
78  std::string buf;
79  while (std::getline(strm, buf, delim)) {
80  output->push_back(LookUp<Arc>(buf, syms));
81  }
82 }
83 
// Copies [`first`, `last`) to the range starting at `result`, writing
// `new_value` in place of every element equal to `old_value`. Unlike
// `std::replace_copy`, the return value is the number of replaced elements.
template <class InputIterator, class OutputIterator, class T>
size_t ReplaceCopy(InputIterator first, InputIterator last,
                   OutputIterator result, const T &old_value,
                   const T &new_value) {
  size_t num_replaced = 0;
  for (; first != last; ++first) {
    if (*first == old_value) {
      *result = new_value;
      ++num_replaced;
    } else {
      *result = *first;
    }
    ++result;
  }
  return num_replaced;
}
102 
// Forward declaration: `AddVocab()` below calls this before its definition
// near the end of this header.
103 template <class Arc>
104 bool GetVocabRecord(const std::string &vocab, std::istream &strm,
105  SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms,
106  typename Arc::Label *word,
107  std::vector<typename Arc::Label> *feature_labels,
108  std::vector<typename Arc::Label> *possible_labels,
109  size_t *num_line);
110 
// Forward declaration: `AddModel()` below calls this before its definition
// near the end of this header.
111 template <class Arc>
112 bool GetModelRecord(const std::string &model, std::istream &strm,
113  SymbolTable *fsyms, SymbolTable *osyms,
114  std::vector<typename Arc::Label> *input_labels,
115  std::vector<typename Arc::Label> *output_labels,
116  typename Arc::Weight *weight, size_t *num_line);
117 
118 // Reads in vocabulary file. Each line is in the following format
119 //
120 // word <whitespace> features [ <whitespace> possible output ]
121 //
122 // where features and possible output are `FST_FLAGS_delimiter`-delimited lists of
123 // tokens
124 template <class Arc>
125 void AddVocab(const std::string &vocab, SymbolTable *isyms, SymbolTable *fsyms,
126  SymbolTable *osyms, LinearFstDataBuilder<Arc> *builder) {
127  std::ifstream in(vocab);
128  if (!in) LOG(FATAL) << "Can't open file: " << vocab;
129  size_t num_line = 0, num_added = 0;
130  std::vector<std::string> fields;
131  std::vector<typename Arc::Label> feature_labels, possible_labels;
132  typename Arc::Label word;
133  while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
134  &feature_labels, &possible_labels, &num_line)) {
135  if (word == kNoLabel) {
136  LOG(WARNING) << "Ignored: boundary word: " << fields[0];
137  continue;
138  }
139  if (possible_labels.empty()) {
140  num_added += builder->AddWord(word, feature_labels);
141  } else {
142  num_added += builder->AddWord(word, feature_labels, possible_labels);
143  }
144  }
145  VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from "
146  << vocab;
147 }
148 
149 template <class Arc>
150 void AddVocab(const std::string &vocab, SymbolTable *isyms, SymbolTable *fsyms,
151  SymbolTable *osyms,
153  std::ifstream in(vocab);
154  if (!in) LOG(FATAL) << "Can't open file: " << vocab;
155  size_t num_line = 0, num_added = 0;
156  std::vector<std::string> fields;
157  std::vector<typename Arc::Label> feature_labels, possible_labels;
158  typename Arc::Label word;
159  while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
160  &feature_labels, &possible_labels, &num_line)) {
161  if (!possible_labels.empty()) {
162  LOG(FATAL)
163  << "Classifier vocabulary should not have possible output constraint";
164  }
165  if (word == kNoLabel) {
166  LOG(WARNING) << "Ignored: boundary word: " << fields[0];
167  continue;
168  }
169  num_added += builder->AddWord(word, feature_labels);
170  }
171  VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from "
172  << vocab;
173 }
174 
175 // Reads in model file. The first line is an integer designating the
176 // size of future window in the input sequences. After this, each line
177 // is in the following format
178 //
179 // input sequence <whitespace> output sequence <whitespace> weight
180 //
181 // input sequence is a `FST_FLAGS_delimiter`-delimited sequence of feature
182 // labels (see `AddVocab()`) . output sequence is a
183 // `FST_FLAGS_delimiter`-delimited sequence of output labels where the
184 // last label is the output of the feature position before the history
185 // boundary.
186 template <class Arc>
187 void AddModel(const std::string &model, SymbolTable *fsyms, SymbolTable *osyms,
188  LinearFstDataBuilder<Arc> *builder) {
189  std::ifstream in(model);
190  if (!in) LOG(FATAL) << "Can't open file: " << model;
191  std::string line;
192  std::getline(in, line);
193  if (!in) LOG(FATAL) << "Empty file: " << model;
194  size_t future_size;
195  {
196  std::istringstream strm(line);
197  strm >> future_size;
198  if (!strm) LOG(FATAL) << "Can't read future size: " << model;
199  }
200  size_t num_line = 1, num_added = 0;
201  const int group = builder->AddGroup(future_size);
202  VLOG(1) << "Group " << group << ": from " << model << "; future size is "
203  << future_size << ".";
204  // Add the rest of lines as a single feature group
205  std::vector<std::string> fields;
206  std::vector<typename Arc::Label> input_labels, output_labels;
207  typename Arc::Weight weight;
208  while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
209  &output_labels, &weight, &num_line)) {
210  if (output_labels.empty()) {
211  LOG(FATAL) << "Empty output sequence in source " << model << ", line "
212  << num_line;
213  }
214 
215  const typename Arc::Label marks[] = {LinearFstData<Arc>::kStartOfSentence,
217 
218  std::vector<typename Arc::Label> copy_input(input_labels.size()),
219  copy_output(output_labels.size());
220  for (int i = 0; i < 2; ++i) {
221  for (int j = 0; j < 2; ++j) {
222  size_t num_input_changes =
223  ReplaceCopy(input_labels.begin(), input_labels.end(),
224  copy_input.begin(), kNoLabel, marks[i]);
225  size_t num_output_changes =
226  ReplaceCopy(output_labels.begin(), output_labels.end(),
227  copy_output.begin(), kNoLabel, marks[j]);
228  if ((num_input_changes > 0 || i == 0) &&
229  (num_output_changes > 0 || j == 0)) {
230  num_added +=
231  builder->AddWeight(group, copy_input, copy_output, weight);
232  }
233  }
234  }
235  }
236  VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in "
237  << num_line << " lines.";
238 }
239 
240 template <class Arc>
241 void AddModel(const std::string &model, SymbolTable *fsyms, SymbolTable *osyms,
243  std::ifstream in(model);
244  if (!in) LOG(FATAL) << "Can't open file: " << model;
245  std::string line;
246  std::getline(in, line);
247  if (!in) LOG(FATAL) << "Empty file: " << model;
248  size_t future_size;
249  {
250  std::istringstream strm(line);
251  strm >> future_size;
252  if (!strm) LOG(FATAL) << "Can't read future size: " << model;
253  }
254  if (future_size != 0) {
255  LOG(FATAL) << "Classifier model must have future size = 0; got "
256  << future_size << " from " << model;
257  }
258  size_t num_line = 1, num_added = 0;
259  const int group = builder->AddGroup();
260  VLOG(1) << "Group " << group << ": from " << model << "; future size is "
261  << future_size << ".";
262  // Add the rest of lines as a single feature group
263  std::vector<std::string> fields;
264  std::vector<typename Arc::Label> input_labels, output_labels;
265  typename Arc::Weight weight;
266  while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
267  &output_labels, &weight, &num_line)) {
268  if (output_labels.size() != 1) {
269  LOG(FATAL) << "Output not a single label in source " << model << ", line "
270  << num_line;
271  }
272 
273  const typename Arc::Label marks[] = {LinearFstData<Arc>::kStartOfSentence,
275 
276  typename Arc::Label pred = output_labels[0];
277 
278  std::vector<typename Arc::Label> copy_input(input_labels.size());
279  for (int i = 0; i < 2; ++i) {
280  size_t num_input_changes =
281  ReplaceCopy(input_labels.begin(), input_labels.end(),
282  copy_input.begin(), kNoLabel, marks[i]);
283  if (num_input_changes > 0 || i == 0) {
284  num_added += builder->AddWeight(group, copy_input, pred, weight);
285  }
286  }
287  }
288  VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in "
289  << num_line << " lines.";
290 }
291 
// Helpers declared here and defined outside this header.
// `SplitByWhitespace()` is used by the record readers below to break a line
// into its fields. The result of `ScanNumClasses()` is used by
// `LinearCompileTpl()` as the number of classes of the classifier.
// NOTE(review): how the number of classes is derived from the model files is
// not visible here - confirm against the definition.
292 void SplitByWhitespace(const std::string &str, std::vector<std::string> *out);
293 int ScanNumClasses(char **models, int models_length);
294 
295 template <class Arc>
297  const std::string &epsilon_symbol = std::get<0>(*args);
298  const std::string &unknown_symbol = std::get<1>(*args);
299  const std::string &vocab = std::get<2>(*args);
300  char **models = std::get<3>(*args);
301  const int models_length = std::get<4>(*args);
302  const std::string &out = std::get<5>(*args);
303  const std::string &save_isymbols = std::get<6>(*args);
304  const std::string &save_fsymbols = std::get<7>(*args);
305  const std::string &save_osymbols = std::get<8>(*args);
306 
307  SymbolTable isyms, // input (e.g. word tokens)
308  osyms, // output (e.g. tags)
309  fsyms; // feature (e.g. word identity, suffix, etc.)
310  isyms.AddSymbol(epsilon_symbol);
311  osyms.AddSymbol(epsilon_symbol);
312  fsyms.AddSymbol(epsilon_symbol);
313  isyms.AddSymbol(unknown_symbol);
314 
315  VLOG(1) << "start-of-sentence label is "
317  VLOG(1) << "end-of-sentence label is " << LinearFstData<Arc>::kEndOfSentence;
318 
319  if (FST_FLAGS_classifier) {
320  int num_classes = ScanNumClasses(models, models_length);
321  LinearClassifierFstDataBuilder<Arc> builder(num_classes, &isyms, &fsyms,
322  &osyms);
323 
324  AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
325  for (int i = 0; i < models_length; ++i) {
326  AddModel(models[i], &fsyms, &osyms, &builder);
327  }
328  LinearClassifierFst<Arc> fst(builder.Dump(), num_classes, &isyms, &osyms);
329  fst.Write(out);
330  } else {
331  LinearFstDataBuilder<Arc> builder(&isyms, &fsyms, &osyms);
332  AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
333  for (int i = 0; i < models_length; ++i) {
334  AddModel(models[i], &fsyms, &osyms, &builder);
335  }
336  LinearTaggerFst<Arc> fst(builder.Dump(), &isyms, &osyms);
337  fst.Write(out);
338  }
339 
340  if (!save_isymbols.empty()) isyms.WriteText(save_isymbols);
341  if (!save_fsymbols.empty()) fsyms.WriteText(save_fsymbols);
342  if (!save_osymbols.empty()) osyms.WriteText(save_osymbols);
343 }
344 
// Non-template entry point, declared here and defined outside this header.
// Apart from the leading `arc_type`, the parameters match the fields of
// `LinearCompileArgs` in order.
// NOTE(review): presumably selects the `LinearCompileTpl<Arc>()` instantiation
// named by `arc_type` - confirm against the definition.
345 void LinearCompile(const std::string &arc_type,
346  const std::string &epsilon_symbol,
347  const std::string &unknown_symbol, const std::string &vocab,
348  char **models, int models_len, const std::string &out,
349  const std::string &save_isymbols,
350  const std::string &save_fsymbols,
351  const std::string &save_osymbols);
352 
353 template <class Arc>
354 bool GetVocabRecord(const std::string &vocab, std::istream &strm,
355  SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms,
356  typename Arc::Label *word,
357  std::vector<typename Arc::Label> *feature_labels,
358  std::vector<typename Arc::Label> *possible_labels,
359  size_t *num_line) {
360  std::string line;
361  if (!std::getline(strm, line)) return false;
362  ++(*num_line);
363 
364  std::vector<std::string> fields;
365  SplitByWhitespace(line, &fields);
366  if (fields.size() != 3) {
367  LOG(FATAL) << "Wrong number of fields in source " << vocab << ", line "
368  << num_line;
369  }
370 
371  feature_labels->clear();
372  possible_labels->clear();
373 
374  *word = LookUp<Arc>(fields[0], isyms);
375 
376  const char delim = FST_FLAGS_delimiter[0];
377  SplitAndPush<Arc>(fields[1], delim, fsyms, feature_labels);
378  SplitAndPush<Arc>(fields[2], delim, osyms, possible_labels);
379 
380  return true;
381 }
382 
383 template <class Arc>
384 bool GetModelRecord(const std::string &model, std::istream &strm,
385  SymbolTable *fsyms, SymbolTable *osyms,
386  std::vector<typename Arc::Label> *input_labels,
387  std::vector<typename Arc::Label> *output_labels,
388  typename Arc::Weight *weight, size_t *num_line) {
389  std::string line;
390  if (!std::getline(strm, line)) return false;
391  ++(*num_line);
392 
393  std::vector<std::string> fields;
394  SplitByWhitespace(line, &fields);
395  if (fields.size() != 3) {
396  LOG(FATAL) << "Wrong number of fields in source " << model << ", line "
397  << num_line;
398  }
399 
400  input_labels->clear();
401  output_labels->clear();
402 
403  const char delim = FST_FLAGS_delimiter[0];
404  SplitAndPush<Arc>(fields[0], delim, fsyms, input_labels);
405  SplitAndPush<Arc>(fields[1], delim, osyms, output_labels);
406 
407  *weight = StrToWeight<typename Arc::Weight>(fields[2]);
408 
409  GuessStartOrEnd<Arc>(input_labels, kNoLabel);
410  GuessStartOrEnd<Arc>(output_labels, kNoLabel);
411 
412  return true;
413 }
414 
415 } // namespace script
416 } // namespace fst
417 
418 #endif // FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
int64_t AddSymbol(std::string_view symbol, int64_t key)
Definition: symbol-table.h:421
DECLARE_string(delimiter)
void SplitByWhitespace(const std::string &str, std::vector< std::string > *out)
Definition: linearscript.cc:75
constexpr int kNoLabel
Definition: fst.h:201
std::tuple< const std::string &, const std::string &, const std::string &, char **, int, const std::string &, const std::string &, const std::string &, const std::string & > LinearCompileArgs
Definition: linearscript.h:46
bool AddWeight(size_t group, const std::vector< Label > &input, const std::vector< Label > &output, Weight weight)
#define LOG(type)
Definition: log.h:49
bool AddWord(Label word, const std::vector< Label > &features)
bool WriteText(std::ostream &strm, const SymbolTableTextOptions &opts=SymbolTableTextOptions()) const
int AddGroup(size_t future_size)
bool GetModelRecord(const std::string &model, std::istream &strm, SymbolTable *fsyms, SymbolTable *osyms, std::vector< typename Arc::Label > *input_labels, std::vector< typename Arc::Label > *output_labels, typename Arc::Weight *weight, size_t *num_line)
Definition: linearscript.h:384
bool AddWord(Label word, const std::vector< Label > &features)
bool ValidateEmptySymbol()
Definition: linearscript.cc:51
DECLARE_bool(classifier)
#define VLOG(level)
Definition: log.h:50
int ScanNumClasses(char **models, int models_len)
Definition: linearscript.cc:82
void LinearCompile(const std::string &arc_type, const std::string &epsilon_symbol, const std::string &unknown_symbol, const std::string &vocab, char **models, int models_len, const std::string &out, const std::string &save_isymbols, const std::string &save_fsymbols, const std::string &save_osymbols)
Definition: linearscript.cc:60
void AddVocab(const std::string &vocab, SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms, LinearFstDataBuilder< Arc > *builder)
Definition: linearscript.h:125
bool GetVocabRecord(const std::string &vocab, std::istream &strm, SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms, typename Arc::Label *word, std::vector< typename Arc::Label > *feature_labels, std::vector< typename Arc::Label > *possible_labels, size_t *num_line)
Definition: linearscript.h:354
bool AddWeight(size_t group, const std::vector< Label > &input, Label pred, Weight weight)
void AddModel(const std::string &model, SymbolTable *fsyms, SymbolTable *osyms, LinearFstDataBuilder< Arc > *builder)
Definition: linearscript.h:187
size_t ReplaceCopy(InputIterator first, InputIterator last, OutputIterator result, const T &old_value, const T &new_value)
Definition: linearscript.h:86
void SplitAndPush(const std::string &str, const char delim, SymbolTable *syms, std::vector< typename Arc::Label > *output)
Definition: linearscript.h:74
Arc::Label LookUp(const std::string &str, SymbolTable *syms)
Definition: linearscript.h:59
bool ValidateDelimiter()
Definition: linearscript.cc:46
void LinearCompileTpl(LinearCompileArgs *args)
Definition: linearscript.h:296