FST  openfst-1.8.3
OpenFst Library
linearscript.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 
18 #ifndef FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
19 #define FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
20 
21 #include <cstddef>
22 #include <istream>
23 #include <sstream>
24 #include <string>
25 #include <tuple>
26 #include <vector>
27 
28 #include <fst/compat.h>
29 #include <fst/log.h>
33 #include <fstream>
34 #include <fst/fst.h>
35 #include <fst/symbol-table.h>
36 #include <fst/util.h>
37 #include <fst/script/arg-packs.h>
38 #include <fst/script/script-impl.h>
39 
40 DECLARE_string(delimiter);
41 DECLARE_string(empty_symbol);
42 DECLARE_string(start_symbol);
43 DECLARE_string(end_symbol);
44 DECLARE_bool(classifier);
45 
46 namespace fst {
47 namespace script {
48 
// Argument pack unpacked by `LinearCompileTpl()`, field by field.
// The three `save_*` paths are optional: an empty string skips saving the
// corresponding symbol table.
using LinearCompileArgs = std::tuple<
    const std::string &,   // 0: epsilon_symbol
    const std::string &,   // 1: unknown_symbol
    const std::string &,   // 2: vocab (vocabulary file path)
    char **,               // 3: models (array of model file paths)
    int,                   // 4: models_length (number of entries in `models`)
    const std::string &,   // 5: out (output FST path)
    const std::string &,   // 6: save_isymbols
    const std::string &,   // 7: save_fsymbols
    const std::string &>;  // 8: save_osymbols

// Flag validators, defined in linearscript.cc; presumably they check the
// `delimiter` and `empty_symbol` flags respectively.
bool ValidateDelimiter();
bool ValidateEmptySymbol();
57 
58 // Returns the proper label given the symbol. For symbols other than
59 // `FST_FLAGS_start_symbol` or `FST_FLAGS_end_symbol`, looks up the symbol
60 // table to decide the label. Depending on whether
61 // `FST_FLAGS_start_symbol` and `FST_FLAGS_end_symbol` are identical, it
62 // either returns `kNoLabel` for later processing or decides the label
63 // right away.
64 template <class Arc>
65 inline typename Arc::Label LookUp(const std::string &str, SymbolTable *syms) {
66  if (str == FST_FLAGS_start_symbol) {
67  return str == FST_FLAGS_end_symbol
68  ? kNoLabel
70  } else if (str == FST_FLAGS_end_symbol) {
72  } else {
73  return syms->AddSymbol(str);
74  }
75 }
76 
77 // Splits `str` with `delim` as the delimiter and stores the labels in
78 // `output`.
79 template <class Arc>
80 void SplitAndPush(const std::string &str, const char delim, SymbolTable *syms,
81  std::vector<typename Arc::Label> *output) {
82  if (str == FST_FLAGS_empty_symbol) return;
83  std::istringstream strm(str);
84  std::string buf;
85  while (std::getline(strm, buf, delim)) {
86  output->push_back(LookUp<Arc>(buf, syms));
87  }
88 }
89 
// Copies [first, last) to the range starting at `result`, writing `new_value`
// in place of each element that compares equal to `old_value`, exactly like
// `std::replace_copy`. Unlike `std::replace_copy`, this returns the number of
// elements that were replaced rather than the end of the output range.
template <class InputIterator, class OutputIterator, class T>
size_t ReplaceCopy(InputIterator first, InputIterator last,
                   OutputIterator result, const T &old_value,
                   const T &new_value) {
  size_t num_replaced = 0;
  for (; first != last; ++first, ++result) {
    if (!(*first == old_value)) {
      *result = *first;
      continue;
    }
    *result = new_value;
    ++num_replaced;
  }
  return num_replaced;
}
108 
109 template <class Arc>
110 bool GetVocabRecord(const std::string &vocab, std::istream &strm,
111  SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms,
112  typename Arc::Label *word,
113  std::vector<typename Arc::Label> *feature_labels,
114  std::vector<typename Arc::Label> *possible_labels,
115  size_t *num_line);
116 
117 template <class Arc>
118 bool GetModelRecord(const std::string &model, std::istream &strm,
119  SymbolTable *fsyms, SymbolTable *osyms,
120  std::vector<typename Arc::Label> *input_labels,
121  std::vector<typename Arc::Label> *output_labels,
122  typename Arc::Weight *weight, size_t *num_line);
123 
124 // Reads in vocabulary file. Each line is in the following format
125 //
126 // word <whitespace> features [ <whitespace> possible output ]
127 //
128 // where features and possible output are `FST_FLAGS_delimiter`-delimited lists of
129 // tokens
130 template <class Arc>
131 void AddVocab(const std::string &vocab, SymbolTable *isyms, SymbolTable *fsyms,
132  SymbolTable *osyms, LinearFstDataBuilder<Arc> *builder) {
133  std::ifstream in(vocab);
134  if (!in) LOG(FATAL) << "Can't open file: " << vocab;
135  size_t num_line = 0, num_added = 0;
136  std::vector<std::string> fields;
137  std::vector<typename Arc::Label> feature_labels, possible_labels;
138  typename Arc::Label word;
139  while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
140  &feature_labels, &possible_labels, &num_line)) {
141  if (word == kNoLabel) {
142  LOG(WARNING) << "Ignored: boundary word: " << fields[0];
143  continue;
144  }
145  if (possible_labels.empty()) {
146  num_added += builder->AddWord(word, feature_labels);
147  } else {
148  num_added += builder->AddWord(word, feature_labels, possible_labels);
149  }
150  }
151  VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from "
152  << vocab;
153 }
154 
155 template <class Arc>
156 void AddVocab(const std::string &vocab, SymbolTable *isyms, SymbolTable *fsyms,
157  SymbolTable *osyms,
159  std::ifstream in(vocab);
160  if (!in) LOG(FATAL) << "Can't open file: " << vocab;
161  size_t num_line = 0, num_added = 0;
162  std::vector<std::string> fields;
163  std::vector<typename Arc::Label> feature_labels, possible_labels;
164  typename Arc::Label word;
165  while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
166  &feature_labels, &possible_labels, &num_line)) {
167  if (!possible_labels.empty()) {
168  LOG(FATAL)
169  << "Classifier vocabulary should not have possible output constraint";
170  }
171  if (word == kNoLabel) {
172  LOG(WARNING) << "Ignored: boundary word: " << fields[0];
173  continue;
174  }
175  num_added += builder->AddWord(word, feature_labels);
176  }
177  VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from "
178  << vocab;
179 }
180 
181 // Reads in model file. The first line is an integer designating the
182 // size of future window in the input sequences. After this, each line
183 // is in the following format
184 //
185 // input sequence <whitespace> output sequence <whitespace> weight
186 //
187 // input sequence is a `FST_FLAGS_delimiter`-delimited sequence of feature
188 // labels (see `AddVocab()`) . output sequence is a
189 // `FST_FLAGS_delimiter`-delimited sequence of output labels where the
190 // last label is the output of the feature position before the history
191 // boundary.
192 template <class Arc>
193 void AddModel(const std::string &model, SymbolTable *fsyms, SymbolTable *osyms,
194  LinearFstDataBuilder<Arc> *builder) {
195  std::ifstream in(model);
196  if (!in) LOG(FATAL) << "Can't open file: " << model;
197  std::string line;
198  std::getline(in, line);
199  if (!in) LOG(FATAL) << "Empty file: " << model;
200  size_t future_size;
201  {
202  std::istringstream strm(line);
203  strm >> future_size;
204  if (!strm) LOG(FATAL) << "Can't read future size: " << model;
205  }
206  size_t num_line = 1, num_added = 0;
207  const int group = builder->AddGroup(future_size);
208  VLOG(1) << "Group " << group << ": from " << model << "; future size is "
209  << future_size << ".";
210  // Add the rest of lines as a single feature group
211  std::vector<std::string> fields;
212  std::vector<typename Arc::Label> input_labels, output_labels;
213  typename Arc::Weight weight;
214  while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
215  &output_labels, &weight, &num_line)) {
216  if (output_labels.empty()) {
217  LOG(FATAL) << "Empty output sequence in source " << model << ", line "
218  << num_line;
219  }
220 
221  const typename Arc::Label marks[] = {LinearFstData<Arc>::kStartOfSentence,
223 
224  std::vector<typename Arc::Label> copy_input(input_labels.size()),
225  copy_output(output_labels.size());
226  for (int i = 0; i < 2; ++i) {
227  for (int j = 0; j < 2; ++j) {
228  size_t num_input_changes =
229  ReplaceCopy(input_labels.begin(), input_labels.end(),
230  copy_input.begin(), kNoLabel, marks[i]);
231  size_t num_output_changes =
232  ReplaceCopy(output_labels.begin(), output_labels.end(),
233  copy_output.begin(), kNoLabel, marks[j]);
234  if ((num_input_changes > 0 || i == 0) &&
235  (num_output_changes > 0 || j == 0)) {
236  num_added +=
237  builder->AddWeight(group, copy_input, copy_output, weight);
238  }
239  }
240  }
241  }
242  VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in "
243  << num_line << " lines.";
244 }
245 
246 template <class Arc>
247 void AddModel(const std::string &model, SymbolTable *fsyms, SymbolTable *osyms,
249  std::ifstream in(model);
250  if (!in) LOG(FATAL) << "Can't open file: " << model;
251  std::string line;
252  std::getline(in, line);
253  if (!in) LOG(FATAL) << "Empty file: " << model;
254  size_t future_size;
255  {
256  std::istringstream strm(line);
257  strm >> future_size;
258  if (!strm) LOG(FATAL) << "Can't read future size: " << model;
259  }
260  if (future_size != 0) {
261  LOG(FATAL) << "Classifier model must have future size = 0; got "
262  << future_size << " from " << model;
263  }
264  size_t num_line = 1, num_added = 0;
265  const int group = builder->AddGroup();
266  VLOG(1) << "Group " << group << ": from " << model << "; future size is "
267  << future_size << ".";
268  // Add the rest of lines as a single feature group
269  std::vector<std::string> fields;
270  std::vector<typename Arc::Label> input_labels, output_labels;
271  typename Arc::Weight weight;
272  while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
273  &output_labels, &weight, &num_line)) {
274  if (output_labels.size() != 1) {
275  LOG(FATAL) << "Output not a single label in source " << model << ", line "
276  << num_line;
277  }
278 
279  const typename Arc::Label marks[] = {LinearFstData<Arc>::kStartOfSentence,
281 
282  typename Arc::Label pred = output_labels[0];
283 
284  std::vector<typename Arc::Label> copy_input(input_labels.size());
285  for (int i = 0; i < 2; ++i) {
286  size_t num_input_changes =
287  ReplaceCopy(input_labels.begin(), input_labels.end(),
288  copy_input.begin(), kNoLabel, marks[i]);
289  if (num_input_changes > 0 || i == 0) {
290  num_added += builder->AddWeight(group, copy_input, pred, weight);
291  }
292  }
293  }
294  VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in "
295  << num_line << " lines.";
296 }
297 
// Splits `str` into whitespace-separated fields stored in `*out`. Defined in
// linearscript.cc.
void SplitByWhitespace(const std::string &str, std::vector<std::string> *out);

// Returns the number of output classes for a classifier built from the
// `models_length` model files in `models`. Defined in linearscript.cc;
// presumably it scans the files' output labels.
int ScanNumClasses(char **models, int models_length);
300 
301 template <class Arc>
303  const std::string &epsilon_symbol = std::get<0>(*args);
304  const std::string &unknown_symbol = std::get<1>(*args);
305  const std::string &vocab = std::get<2>(*args);
306  char **models = std::get<3>(*args);
307  const int models_length = std::get<4>(*args);
308  const std::string &out = std::get<5>(*args);
309  const std::string &save_isymbols = std::get<6>(*args);
310  const std::string &save_fsymbols = std::get<7>(*args);
311  const std::string &save_osymbols = std::get<8>(*args);
312 
313  SymbolTable isyms, // input (e.g. word tokens)
314  osyms, // output (e.g. tags)
315  fsyms; // feature (e.g. word identity, suffix, etc.)
316  isyms.AddSymbol(epsilon_symbol);
317  osyms.AddSymbol(epsilon_symbol);
318  fsyms.AddSymbol(epsilon_symbol);
319  isyms.AddSymbol(unknown_symbol);
320 
321  VLOG(1) << "start-of-sentence label is "
323  VLOG(1) << "end-of-sentence label is " << LinearFstData<Arc>::kEndOfSentence;
324 
325  if (FST_FLAGS_classifier) {
326  int num_classes = ScanNumClasses(models, models_length);
327  LinearClassifierFstDataBuilder<Arc> builder(num_classes, &isyms, &fsyms,
328  &osyms);
329 
330  AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
331  for (int i = 0; i < models_length; ++i) {
332  AddModel(models[i], &fsyms, &osyms, &builder);
333  }
334  LinearClassifierFst<Arc> fst(builder.Dump(), num_classes, &isyms, &osyms);
335  fst.Write(out);
336  } else {
337  LinearFstDataBuilder<Arc> builder(&isyms, &fsyms, &osyms);
338  AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
339  for (int i = 0; i < models_length; ++i) {
340  AddModel(models[i], &fsyms, &osyms, &builder);
341  }
342  LinearTaggerFst<Arc> fst(builder.Dump(), &isyms, &osyms);
343  fst.Write(out);
344  }
345 
346  if (!save_isymbols.empty()) isyms.WriteText(save_isymbols);
347  if (!save_fsymbols.empty()) fsyms.WriteText(save_fsymbols);
348  if (!save_osymbols.empty()) osyms.WriteText(save_osymbols);
349 }
350 
// Entry point for compiling a linear model whose arcs are of type `arc_type`.
// Defined in linearscript.cc; presumably it packs the remaining arguments into
// `LinearCompileArgs` and dispatches to `LinearCompileTpl<Arc>()` for the
// named arc type. `models` holds `models_len` model file paths.
void LinearCompile(const std::string &arc_type,
                   const std::string &epsilon_symbol,
                   const std::string &unknown_symbol,
                   const std::string &vocab,
                   char **models,
                   int models_len,
                   const std::string &out,
                   const std::string &save_isymbols,
                   const std::string &save_fsymbols,
                   const std::string &save_osymbols);
358 
359 template <class Arc>
360 bool GetVocabRecord(const std::string &vocab, std::istream &strm,
361  SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms,
362  typename Arc::Label *word,
363  std::vector<typename Arc::Label> *feature_labels,
364  std::vector<typename Arc::Label> *possible_labels,
365  size_t *num_line) {
366  std::string line;
367  if (!std::getline(strm, line)) return false;
368  ++(*num_line);
369 
370  std::vector<std::string> fields;
371  SplitByWhitespace(line, &fields);
372  if (fields.size() != 3) {
373  LOG(FATAL) << "Wrong number of fields in source " << vocab << ", line "
374  << num_line;
375  }
376 
377  feature_labels->clear();
378  possible_labels->clear();
379 
380  *word = LookUp<Arc>(fields[0], isyms);
381 
382  const char delim = FST_FLAGS_delimiter[0];
383  SplitAndPush<Arc>(fields[1], delim, fsyms, feature_labels);
384  SplitAndPush<Arc>(fields[2], delim, osyms, possible_labels);
385 
386  return true;
387 }
388 
389 template <class Arc>
390 bool GetModelRecord(const std::string &model, std::istream &strm,
391  SymbolTable *fsyms, SymbolTable *osyms,
392  std::vector<typename Arc::Label> *input_labels,
393  std::vector<typename Arc::Label> *output_labels,
394  typename Arc::Weight *weight, size_t *num_line) {
395  std::string line;
396  if (!std::getline(strm, line)) return false;
397  ++(*num_line);
398 
399  std::vector<std::string> fields;
400  SplitByWhitespace(line, &fields);
401  if (fields.size() != 3) {
402  LOG(FATAL) << "Wrong number of fields in source " << model << ", line "
403  << num_line;
404  }
405 
406  input_labels->clear();
407  output_labels->clear();
408 
409  const char delim = FST_FLAGS_delimiter[0];
410  SplitAndPush<Arc>(fields[0], delim, fsyms, input_labels);
411  SplitAndPush<Arc>(fields[1], delim, osyms, output_labels);
412 
413  *weight = StrToWeight<typename Arc::Weight>(fields[2]);
414 
415  GuessStartOrEnd<Arc>(input_labels, kNoLabel);
416  GuessStartOrEnd<Arc>(output_labels, kNoLabel);
417 
418  return true;
419 }
420 
421 } // namespace script
422 } // namespace fst
423 
424 #endif // FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_
int64_t AddSymbol(std::string_view symbol, int64_t key)
Definition: symbol-table.h:425
DECLARE_string(delimiter)
void SplitByWhitespace(const std::string &str, std::vector< std::string > *out)
Definition: linearscript.cc:83
constexpr int kNoLabel
Definition: fst.h:195
std::tuple< const std::string &, const std::string &, const std::string &, char **, int, const std::string &, const std::string &, const std::string &, const std::string & > LinearCompileArgs
Definition: linearscript.h:52
bool AddWeight(size_t group, const std::vector< Label > &input, const std::vector< Label > &output, Weight weight)
#define LOG(type)
Definition: log.h:53
bool AddWord(Label word, const std::vector< Label > &features)
int AddGroup(size_t future_size)
bool GetModelRecord(const std::string &model, std::istream &strm, SymbolTable *fsyms, SymbolTable *osyms, std::vector< typename Arc::Label > *input_labels, std::vector< typename Arc::Label > *output_labels, typename Arc::Weight *weight, size_t *num_line)
Definition: linearscript.h:390
bool AddWord(Label word, const std::vector< Label > &features)
bool ValidateEmptySymbol()
Definition: linearscript.cc:59
DECLARE_bool(classifier)
#define VLOG(level)
Definition: log.h:54
int ScanNumClasses(char **models, int models_len)
Definition: linearscript.cc:90
void LinearCompile(const std::string &arc_type, const std::string &epsilon_symbol, const std::string &unknown_symbol, const std::string &vocab, char **models, int models_len, const std::string &out, const std::string &save_isymbols, const std::string &save_fsymbols, const std::string &save_osymbols)
Definition: linearscript.cc:68
void AddVocab(const std::string &vocab, SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms, LinearFstDataBuilder< Arc > *builder)
Definition: linearscript.h:131
bool GetVocabRecord(const std::string &vocab, std::istream &strm, SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms, typename Arc::Label *word, std::vector< typename Arc::Label > *feature_labels, std::vector< typename Arc::Label > *possible_labels, size_t *num_line)
Definition: linearscript.h:360
bool AddWeight(size_t group, const std::vector< Label > &input, Label pred, Weight weight)
void AddModel(const std::string &model, SymbolTable *fsyms, SymbolTable *osyms, LinearFstDataBuilder< Arc > *builder)
Definition: linearscript.h:193
size_t ReplaceCopy(InputIterator first, InputIterator last, OutputIterator result, const T &old_value, const T &new_value)
Definition: linearscript.h:92
bool WriteText(std::ostream &strm, const std::string &sep=FST_FLAGS_fst_field_separator) const
void SplitAndPush(const std::string &str, const char delim, SymbolTable *syms, std::vector< typename Arc::Label > *output)
Definition: linearscript.h:80
Arc::Label LookUp(const std::string &str, SymbolTable *syms)
Definition: linearscript.h:65
bool ValidateDelimiter()
Definition: linearscript.cc:54
void LinearCompileTpl(LinearCompileArgs *args)
Definition: linearscript.h:302