18 #ifndef FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_ 19 #define FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_ 50 std::tuple<
// Positional argument pack. The compile routine later in this file reads each
// slot with std::get<N>(*args); the slot names below follow those reads.
const std::string &,  // [0] epsilon_symbol
const std::string &,  // [1] unknown_symbol
const std::string &,  // [2] vocab
51 char **, int,  // [3] models, [4] models_length
const std::string &,  // [5] out
const std::string &,  // [6] save_isymbols
52 const std::string &,  // [7] save_fsymbols
const std::string &>;  // [8] save_osymbols
// Boundary-symbol handling: `str` is compared with the start-symbol flag first
// and with the end-symbol flag otherwise. The start branch additionally tests
// whether `str` also equals the end symbol.
// NOTE(review): the values returned by either branch are on lines missing from
// this listing — confirm against the full source.
66 if (str == FST_FLAGS_start_symbol) {
67 return str == FST_FLAGS_end_symbol
70 }
else if (str == FST_FLAGS_end_symbol) {
81 std::vector<typename Arc::Label> *output) {
// A field equal to the empty-symbol flag contributes nothing to `output`.
82 if (str == FST_FLAGS_empty_symbol)
return;
// Otherwise cut `str` at every `delim` and append the label that
// LookUp<Arc> yields for each piece, in order.
83 std::istringstream strm(str);
85 while (std::getline(strm, buf, delim)) {
86 output->push_back(LookUp<Arc>(buf, syms));
// Iterator-based helper: walks [first, last) and compares each element with
// `old_value`. Callers in this file pass an output iterator as `result` and
// store the returned value in size_t variables named num_input_changes /
// num_output_changes.
// NOTE(review): the statements that write to `result` and maintain the return
// value are on lines missing from this listing — confirm exact semantics.
91 template <
class InputIterator,
class OutputIterator,
class T>
93 OutputIterator result,
const T &old_value,
96 while (first != last) {
97 if (*first == old_value) {
// Output parameters of the vocabulary-record reader (declaration fragment).
// Callers in this file pass the addresses of a word label and two label
// vectors here.
112 typename Arc::Label *word,
113 std::vector<typename Arc::Label> *feature_labels,
114 std::vector<typename Arc::Label> *possible_labels,
// Output parameters of the model-record reader (declaration fragment): the
// parsed input and output label sequences, the parsed weight, and a line
// counter passed by pointer.
120 std::vector<typename Arc::Label> *input_labels,
121 std::vector<typename Arc::Label> *output_labels,
122 typename Arc::Weight *weight,
size_t *num_line);
// Vocabulary loader body: opens the file named by `vocab`, pulls records with
// GetVocabRecord and registers each accepted word on `builder`.
133 std::ifstream in(vocab);
// An unopenable vocabulary file is fatal.
134 if (!in)
LOG(FATAL) <<
"Can't open file: " << vocab;
// num_line is handed to GetVocabRecord by pointer; num_added accumulates the
// values returned by builder->AddWord.
135 size_t num_line = 0, num_added = 0;
// NOTE(review): nothing visible in this listing ever fills `fields`
// (GetVocabRecord declares its own local `fields`), yet `fields[0]` is read in
// the warning below. Indexing an empty std::vector is undefined behaviour —
// confirm against the full source and log something else if so.
136 std::vector<std::string> fields;
137 std::vector<typename Arc::Label> feature_labels, possible_labels;
138 typename Arc::Label word;
// GetVocabRecord clears both label vectors before refilling them and returns
// false once no further line can be read from `in`.
139 while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
140 &feature_labels, &possible_labels, &num_line)) {
// Boundary words are reported and not added; the condition guarding this
// warning is on a line missing from this listing.
142 LOG(WARNING) <<
"Ignored: boundary word: " << fields[0];
// No output constraint: use the two-argument AddWord. Otherwise pass the
// possible output labels as well.
145 if (possible_labels.empty()) {
146 num_added += builder->
AddWord(word, feature_labels);
148 num_added += builder->
AddWord(word, feature_labels, possible_labels);
// Summary at verbose level 1 (the rest of this statement is fused with the
// first statement of the next function in this listing).
151 VLOG(1) <<
"Read " << num_added <<
" words in " << num_line <<
" lines from " 159 std::ifstream in(vocab);
// Second vocabulary loader body in this listing; its `std::ifstream in(vocab);`
// sits at the end of the preceding listing line.
// NOTE(review): presumably the classifier-builder overload of AddVocab — the
// function header is not visible here; confirm.
// An unopenable vocabulary file is fatal.
160 if (!in)
LOG(FATAL) <<
"Can't open file: " << vocab;
161 size_t num_line = 0, num_added = 0;
// NOTE(review): as in the loader above, `fields` is never visibly filled but
// `fields[0]` is read in the warning below — indexing an empty std::vector is
// undefined behaviour; confirm against the full source.
162 std::vector<std::string> fields;
163 std::vector<typename Arc::Label> feature_labels, possible_labels;
164 typename Arc::Label word;
165 while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
166 &feature_labels, &possible_labels, &num_line)) {
// A record that carries possible output labels is rejected with the message
// below; the logging macro that starts this statement is on a missing line.
167 if (!possible_labels.empty()) {
169 <<
"Classifier vocabulary should not have possible output constraint";
// Boundary words are reported and not added (guarding condition not visible).
172 LOG(WARNING) <<
"Ignored: boundary word: " << fields[0];
// Only the two-argument AddWord is used in this loader.
175 num_added += builder->
AddWord(word, feature_labels);
// Summary at verbose level 1 (fused with the first statement of the next
// function in this listing).
177 VLOG(1) <<
"Read " << num_added <<
" words in " << num_line <<
" lines from " 195 std::ifstream in(model);
// Model loader body; its `std::ifstream in(model);` sits at the end of the
// preceding listing line.
// An unopenable model file is fatal.
196 if (!in)
LOG(FATAL) <<
"Can't open file: " << model;
// The first line of the file is read separately; a failed read is reported as
// an empty file.
198 std::getline(in, line);
199 if (!in)
LOG(FATAL) <<
"Empty file: " << model;
// That first line is parsed for the future size (the extraction statement
// itself is on a line missing from this listing).
202 std::istringstream strm(line);
204 if (!strm)
LOG(FATAL) <<
"Can't read future size: " << model;
// num_line starts at 1: one line has already been consumed above.
206 size_t num_line = 1, num_added = 0;
// Register a new feature group; the returned id is used for every AddWeight
// call below.
207 const int group = builder->
AddGroup(future_size);
208 VLOG(1) <<
"Group " << group <<
": from " << model <<
"; future size is " 209 << future_size <<
".";
// NOTE(review): `fields` is not referenced anywhere in the visible part of
// this function — possibly unused; confirm against the full source.
211 std::vector<std::string> fields;
212 std::vector<typename Arc::Label> input_labels, output_labels;
213 typename Arc::Weight weight;
// GetModelRecord clears and refills both label vectors per record and returns
// false once no further line can be read.
214 while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
215 &output_labels, &weight, &num_line)) {
// A record without output labels is fatal.
216 if (output_labels.empty()) {
217 LOG(FATAL) <<
"Empty output sequence in source " << model <<
", line " 224 std::vector<typename Arc::Label> copy_input(input_labels.size()),
225 copy_output(output_labels.size());
// Try the four (i, j) combinations: kNoLabel entries are replaced by marks[i]
// in the input copy and by marks[j] in the output copy. A combination is added
// when, on each side, either index 0 is used or the replacement changed at
// least one entry — so (0, 0) is always added.
// NOTE(review): `marks` is defined on a line missing from this listing.
226 for (
int i = 0; i < 2; ++i) {
227 for (
int j = 0; j < 2; ++j) {
228 size_t num_input_changes =
229 ReplaceCopy(input_labels.begin(), input_labels.end(),
230 copy_input.begin(),
kNoLabel, marks[i]);
231 size_t num_output_changes =
232 ReplaceCopy(output_labels.begin(), output_labels.end(),
233 copy_output.begin(),
kNoLabel, marks[j]);
234 if ((num_input_changes > 0 || i == 0) &&
235 (num_output_changes > 0 || j == 0)) {
237 builder->
AddWeight(group, copy_input, copy_output, weight);
// Summary at verbose level 1.
242 VLOG(1) <<
"Group " << group <<
": read " << num_added <<
" weight(s) in " 243 << num_line <<
" lines.";
// Second model loader body in this listing. It requires a future size of 0 and
// exactly one output label per record.
// NOTE(review): presumably the classifier-builder overload of AddModel — the
// function header is not visible here; confirm.
249 std::ifstream in(model);
// An unopenable model file is fatal.
250 if (!in)
LOG(FATAL) <<
"Can't open file: " << model;
// The first line is read separately; a failed read is reported as an empty
// file.
252 std::getline(in, line);
253 if (!in)
LOG(FATAL) <<
"Empty file: " << model;
// That first line is parsed for the future size (extraction statement not
// visible in this listing).
256 std::istringstream strm(line);
258 if (!strm)
LOG(FATAL) <<
"Can't read future size: " << model;
// Any non-zero future size is fatal here.
260 if (future_size != 0) {
261 LOG(FATAL) <<
"Classifier model must have future size = 0; got " 262 << future_size <<
" from " << model;
// num_line starts at 1: one line has already been consumed above.
264 size_t num_line = 1, num_added = 0;
// AddGroup is called without arguments in this loader.
265 const int group = builder->
AddGroup();
266 VLOG(1) <<
"Group " << group <<
": from " << model <<
"; future size is " 267 << future_size <<
".";
// NOTE(review): `fields` is not referenced anywhere in the visible part of
// this function — possibly unused; confirm against the full source.
269 std::vector<std::string> fields;
270 std::vector<typename Arc::Label> input_labels, output_labels;
271 typename Arc::Weight weight;
272 while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
273 &output_labels, &weight, &num_line)) {
// Each record must name exactly one output label, which becomes `pred`.
274 if (output_labels.size() != 1) {
275 LOG(FATAL) <<
"Output not a single label in source " << model <<
", line " 282 typename Arc::Label pred = output_labels[0];
284 std::vector<typename Arc::Label> copy_input(input_labels.size());
// Two passes over the input side only: kNoLabel entries are replaced by
// marks[i]. Pass 0 is always added; pass 1 only when the replacement changed
// at least one entry.
// NOTE(review): `marks` is defined on a line missing from this listing.
285 for (
int i = 0; i < 2; ++i) {
286 size_t num_input_changes =
287 ReplaceCopy(input_labels.begin(), input_labels.end(),
288 copy_input.begin(),
kNoLabel, marks[i]);
289 if (num_input_changes > 0 || i == 0) {
290 num_added += builder->
AddWeight(group, copy_input, pred, weight);
// Summary at verbose level 1.
294 VLOG(1) <<
"Group " << group <<
": read " << num_added <<
" weight(s) in " 295 << num_line <<
" lines.";
// Unpack the positional argument tuple into named locals.
303 const std::string &epsilon_symbol = std::get<0>(*args);
304 const std::string &unknown_symbol = std::get<1>(*args);
305 const std::string &vocab = std::get<2>(*args);
306 char **models = std::get<3>(*args);
307 const int models_length = std::get<4>(*args);
308 const std::string &out = std::get<5>(*args);
309 const std::string &save_isymbols = std::get<6>(*args);
310 const std::string &save_fsymbols = std::get<7>(*args);
311 const std::string &save_osymbols = std::get<8>(*args);
// Two parallel load sequences follow, selected by FST_FLAGS_classifier: each
// loads the vocabulary once and then every model file in order.
// NOTE(review): the lines that declare `builder` in each branch, and the
// `else` separating the branches, are missing from this listing.
321 VLOG(1) <<
"start-of-sentence label is " 325 if (FST_FLAGS_classifier) {
330 AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
331 for (
int i = 0; i < models_length; ++i) {
332 AddModel(models[i], &fsyms, &osyms, &builder);
338 AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
339 for (
int i = 0; i < models_length; ++i) {
340 AddModel(models[i], &fsyms, &osyms, &builder);
// Each symbol table is written out only when its destination path is
// non-empty.
// NOTE(review): the result of WriteText is not used here — confirm whether a
// failed write should be reported.
346 if (!save_isymbols.empty()) isyms.
WriteText(save_isymbols);
347 if (!save_fsymbols.empty()) fsyms.
WriteText(save_fsymbols);
348 if (!save_osymbols.empty()) osyms.
WriteText(save_osymbols);
// Remaining parameters of the LinearCompile declaration. They appear in the
// same order as slots [0]..[8] that the templated compile routine reads from
// its argument tuple.
352 const std::string &epsilon_symbol,
353 const std::string &unknown_symbol,
const std::string &vocab,
354 char **models,
int models_len,
const std::string &out,
355 const std::string &save_isymbols,
356 const std::string &save_fsymbols,
357 const std::string &save_osymbols);
// Definition of the vocabulary-record reader (output parameters shown).
362 typename Arc::Label *word,
363 std::vector<typename Arc::Label> *feature_labels,
364 std::vector<typename Arc::Label> *possible_labels,
// No further line available: report that no record was read.
367 if (!std::getline(strm, line))
return false;
// A record must have exactly three fields; anything else is fatal.
// NOTE(review): the statement that splits `line` into `fields` is on a line
// missing from this listing.
370 std::vector<std::string> fields;
372 if (fields.size() != 3) {
373 LOG(FATAL) <<
"Wrong number of fields in source " << vocab <<
", line " 377 feature_labels->clear();
378 possible_labels->clear();
// Field 0 is the word, resolved against the input symbol table.
380 *word = LookUp<Arc>(fields[0], isyms);
// Only the first character of the delimiter flag is used as the separator.
382 const char delim = FST_FLAGS_delimiter[0];
// Field 1 -> feature labels (fsyms); field 2 -> possible output labels (osyms).
383 SplitAndPush<Arc>(fields[1], delim, fsyms, feature_labels);
384 SplitAndPush<Arc>(fields[2], delim, osyms, possible_labels);
// Definition of the model-record reader (output parameters shown).
392 std::vector<typename Arc::Label> *input_labels,
393 std::vector<typename Arc::Label> *output_labels,
394 typename Arc::Weight *weight,
size_t *num_line) {
// No further line available: report that no record was read.
396 if (!std::getline(strm, line))
return false;
// A record must have exactly three fields; anything else is fatal.
// NOTE(review): the statement that splits `line` into `fields` is on a line
// missing from this listing.
399 std::vector<std::string> fields;
401 if (fields.size() != 3) {
402 LOG(FATAL) <<
"Wrong number of fields in source " << model <<
", line " 406 input_labels->clear();
407 output_labels->clear();
// Only the first character of the delimiter flag is used as the separator.
409 const char delim = FST_FLAGS_delimiter[0];
// Field 0 -> input labels (fsyms); field 1 -> output labels (osyms).
410 SplitAndPush<Arc>(fields[0], delim, fsyms, input_labels);
411 SplitAndPush<Arc>(fields[1], delim, osyms, output_labels);
// Field 2 is parsed as the arc weight.
413 *weight = StrToWeight<typename Arc::Weight>(fields[2]);
// Post-process both sequences with GuessStartOrEnd, passing kNoLabel.
// NOTE(review): what GuessStartOrEnd rewrites is not visible in this listing.
415 GuessStartOrEnd<Arc>(input_labels,
kNoLabel);
416 GuessStartOrEnd<Arc>(output_labels,
kNoLabel);
424 #endif // FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_ bool WriteText(std::ostream &strm, std::string_view sep=FST_FLAGS_fst_field_separator) const
int64_t AddSymbol(std::string_view symbol, int64_t key)
DECLARE_string(delimiter)
void SplitByWhitespace(const std::string &str, std::vector< std::string > *out)
std::tuple< const std::string &, const std::string &, const std::string &, char **, int, const std::string &, const std::string &, const std::string &, const std::string & > LinearCompileArgs
bool AddWeight(size_t group, const std::vector< Label > &input, const std::vector< Label > &output, Weight weight)
bool AddWord(Label word, const std::vector< Label > &features)
LinearFstData< A > * Dump()
int AddGroup(size_t future_size)
bool GetModelRecord(const std::string &model, std::istream &strm, SymbolTable *fsyms, SymbolTable *osyms, std::vector< typename Arc::Label > *input_labels, std::vector< typename Arc::Label > *output_labels, typename Arc::Weight *weight, size_t *num_line)
bool AddWord(Label word, const std::vector< Label > &features)
bool ValidateEmptySymbol()
int ScanNumClasses(char **models, int models_len)
void LinearCompile(const std::string &arc_type, const std::string &epsilon_symbol, const std::string &unknown_symbol, const std::string &vocab, char **models, int models_len, const std::string &out, const std::string &save_isymbols, const std::string &save_fsymbols, const std::string &save_osymbols)
void AddVocab(const std::string &vocab, SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms, LinearFstDataBuilder< Arc > *builder)
bool GetVocabRecord(const std::string &vocab, std::istream &strm, SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms, typename Arc::Label *word, std::vector< typename Arc::Label > *feature_labels, std::vector< typename Arc::Label > *possible_labels, size_t *num_line)
bool AddWeight(size_t group, const std::vector< Label > &input, Label pred, Weight weight)
void AddModel(const std::string &model, SymbolTable *fsyms, SymbolTable *osyms, LinearFstDataBuilder< Arc > *builder)
size_t ReplaceCopy(InputIterator first, InputIterator last, OutputIterator result, const T &old_value, const T &new_value)
void SplitAndPush(const std::string &str, const char delim, SymbolTable *syms, std::vector< typename Arc::Label > *output)
Arc::Label LookUp(const std::string &str, SymbolTable *syms)
void LinearCompileTpl(LinearCompileArgs *args)
LinearFstData< A > * Dump()