18 #ifndef FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_ 19 #define FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_ 44 std::tuple<
const std::string &,
const std::string &,
const std::string &,
45 char **, int,
const std::string &,
const std::string &,
46 const std::string &,
const std::string &>;
60 if (str == FST_FLAGS_start_symbol) {
61 return str == FST_FLAGS_end_symbol
64 }
else if (str == FST_FLAGS_end_symbol) {
75 std::vector<typename Arc::Label> *output) {
76 if (str == FST_FLAGS_empty_symbol)
return;
77 std::istringstream strm(str);
79 while (std::getline(strm, buf, delim)) {
80 output->push_back(LookUp<Arc>(buf, syms));
85 template <
class InputIterator,
class OutputIterator,
class T>
87 OutputIterator result,
const T &old_value,
90 while (first != last) {
91 if (*first == old_value) {
106 typename Arc::Label *word,
107 std::vector<typename Arc::Label> *feature_labels,
108 std::vector<typename Arc::Label> *possible_labels,
114 std::vector<typename Arc::Label> *input_labels,
115 std::vector<typename Arc::Label> *output_labels,
116 typename Arc::Weight *weight,
size_t *num_line);
127 std::ifstream in(vocab);
128 if (!in)
LOG(FATAL) <<
"Can't open file: " << vocab;
129 size_t num_line = 0, num_added = 0;
130 std::vector<std::string> fields;
131 std::vector<typename Arc::Label> feature_labels, possible_labels;
132 typename Arc::Label word;
133 while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
134 &feature_labels, &possible_labels, &num_line)) {
136 LOG(WARNING) <<
"Ignored: boundary word: " << fields[0];
139 if (possible_labels.empty()) {
140 num_added += builder->
AddWord(word, feature_labels);
142 num_added += builder->
AddWord(word, feature_labels, possible_labels);
145 VLOG(1) <<
"Read " << num_added <<
" words in " << num_line <<
" lines from " 153 std::ifstream in(vocab);
154 if (!in)
LOG(FATAL) <<
"Can't open file: " << vocab;
155 size_t num_line = 0, num_added = 0;
156 std::vector<std::string> fields;
157 std::vector<typename Arc::Label> feature_labels, possible_labels;
158 typename Arc::Label word;
159 while (GetVocabRecord<Arc>(vocab, in, isyms, fsyms, osyms, &word,
160 &feature_labels, &possible_labels, &num_line)) {
161 if (!possible_labels.empty()) {
163 <<
"Classifier vocabulary should not have possible output constraint";
166 LOG(WARNING) <<
"Ignored: boundary word: " << fields[0];
169 num_added += builder->
AddWord(word, feature_labels);
171 VLOG(1) <<
"Read " << num_added <<
" words in " << num_line <<
" lines from " 189 std::ifstream in(model);
190 if (!in)
LOG(FATAL) <<
"Can't open file: " << model;
192 std::getline(in, line);
193 if (!in)
LOG(FATAL) <<
"Empty file: " << model;
196 std::istringstream strm(line);
198 if (!strm)
LOG(FATAL) <<
"Can't read future size: " << model;
200 size_t num_line = 1, num_added = 0;
201 const int group = builder->
AddGroup(future_size);
202 VLOG(1) <<
"Group " << group <<
": from " << model <<
"; future size is " 203 << future_size <<
".";
205 std::vector<std::string> fields;
206 std::vector<typename Arc::Label> input_labels, output_labels;
207 typename Arc::Weight weight;
208 while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
209 &output_labels, &weight, &num_line)) {
210 if (output_labels.empty()) {
211 LOG(FATAL) <<
"Empty output sequence in source " << model <<
", line " 218 std::vector<typename Arc::Label> copy_input(input_labels.size()),
219 copy_output(output_labels.size());
220 for (
int i = 0; i < 2; ++i) {
221 for (
int j = 0; j < 2; ++j) {
222 size_t num_input_changes =
223 ReplaceCopy(input_labels.begin(), input_labels.end(),
224 copy_input.begin(),
kNoLabel, marks[i]);
225 size_t num_output_changes =
226 ReplaceCopy(output_labels.begin(), output_labels.end(),
227 copy_output.begin(),
kNoLabel, marks[j]);
228 if ((num_input_changes > 0 || i == 0) &&
229 (num_output_changes > 0 || j == 0)) {
231 builder->
AddWeight(group, copy_input, copy_output, weight);
236 VLOG(1) <<
"Group " << group <<
": read " << num_added <<
" weight(s) in " 237 << num_line <<
" lines.";
243 std::ifstream in(model);
244 if (!in)
LOG(FATAL) <<
"Can't open file: " << model;
246 std::getline(in, line);
247 if (!in)
LOG(FATAL) <<
"Empty file: " << model;
250 std::istringstream strm(line);
252 if (!strm)
LOG(FATAL) <<
"Can't read future size: " << model;
254 if (future_size != 0) {
255 LOG(FATAL) <<
"Classifier model must have future size = 0; got " 256 << future_size <<
" from " << model;
258 size_t num_line = 1, num_added = 0;
259 const int group = builder->
AddGroup();
260 VLOG(1) <<
"Group " << group <<
": from " << model <<
"; future size is " 261 << future_size <<
".";
263 std::vector<std::string> fields;
264 std::vector<typename Arc::Label> input_labels, output_labels;
265 typename Arc::Weight weight;
266 while (GetModelRecord<Arc>(model, in, fsyms, osyms, &input_labels,
267 &output_labels, &weight, &num_line)) {
268 if (output_labels.size() != 1) {
269 LOG(FATAL) <<
"Output not a single label in source " << model <<
", line " 276 typename Arc::Label pred = output_labels[0];
278 std::vector<typename Arc::Label> copy_input(input_labels.size());
279 for (
int i = 0; i < 2; ++i) {
280 size_t num_input_changes =
281 ReplaceCopy(input_labels.begin(), input_labels.end(),
282 copy_input.begin(),
kNoLabel, marks[i]);
283 if (num_input_changes > 0 || i == 0) {
284 num_added += builder->
AddWeight(group, copy_input, pred, weight);
288 VLOG(1) <<
"Group " << group <<
": read " << num_added <<
" weight(s) in " 289 << num_line <<
" lines.";
297 const std::string &epsilon_symbol = std::get<0>(*args);
298 const std::string &unknown_symbol = std::get<1>(*args);
299 const std::string &vocab = std::get<2>(*args);
300 char **models = std::get<3>(*args);
301 const int models_length = std::get<4>(*args);
302 const std::string &out = std::get<5>(*args);
303 const std::string &save_isymbols = std::get<6>(*args);
304 const std::string &save_fsymbols = std::get<7>(*args);
305 const std::string &save_osymbols = std::get<8>(*args);
315 VLOG(1) <<
"start-of-sentence label is " 319 if (FST_FLAGS_classifier) {
324 AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
325 for (
int i = 0; i < models_length; ++i) {
326 AddModel(models[i], &fsyms, &osyms, &builder);
332 AddVocab(vocab, &isyms, &fsyms, &osyms, &builder);
333 for (
int i = 0; i < models_length; ++i) {
334 AddModel(models[i], &fsyms, &osyms, &builder);
340 if (!save_isymbols.empty()) isyms.
WriteText(save_isymbols);
341 if (!save_fsymbols.empty()) fsyms.
WriteText(save_fsymbols);
342 if (!save_osymbols.empty()) osyms.
WriteText(save_osymbols);
346 const std::string &epsilon_symbol,
347 const std::string &unknown_symbol,
const std::string &vocab,
348 char **models,
int models_len,
const std::string &out,
349 const std::string &save_isymbols,
350 const std::string &save_fsymbols,
351 const std::string &save_osymbols);
356 typename Arc::Label *word,
357 std::vector<typename Arc::Label> *feature_labels,
358 std::vector<typename Arc::Label> *possible_labels,
361 if (!std::getline(strm, line))
return false;
364 std::vector<std::string> fields;
366 if (fields.size() != 3) {
367 LOG(FATAL) <<
"Wrong number of fields in source " << vocab <<
", line " 371 feature_labels->clear();
372 possible_labels->clear();
374 *word = LookUp<Arc>(fields[0], isyms);
376 const char delim = FST_FLAGS_delimiter[0];
377 SplitAndPush<Arc>(fields[1], delim, fsyms, feature_labels);
378 SplitAndPush<Arc>(fields[2], delim, osyms, possible_labels);
386 std::vector<typename Arc::Label> *input_labels,
387 std::vector<typename Arc::Label> *output_labels,
388 typename Arc::Weight *weight,
size_t *num_line) {
390 if (!std::getline(strm, line))
return false;
393 std::vector<std::string> fields;
395 if (fields.size() != 3) {
396 LOG(FATAL) <<
"Wrong number of fields in source " << model <<
", line " 400 input_labels->clear();
401 output_labels->clear();
403 const char delim = FST_FLAGS_delimiter[0];
404 SplitAndPush<Arc>(fields[0], delim, fsyms, input_labels);
405 SplitAndPush<Arc>(fields[1], delim, osyms, output_labels);
407 *weight = StrToWeight<typename Arc::Weight>(fields[2]);
409 GuessStartOrEnd<Arc>(input_labels,
kNoLabel);
410 GuessStartOrEnd<Arc>(output_labels,
kNoLabel);
418 #endif // FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_ int64_t AddSymbol(std::string_view symbol, int64_t key)
DECLARE_string(delimiter)
void SplitByWhitespace(const std::string &str, std::vector< std::string > *out)
std::tuple< const std::string &, const std::string &, const std::string &, char **, int, const std::string &, const std::string &, const std::string &, const std::string & > LinearCompileArgs
bool AddWeight(size_t group, const std::vector< Label > &input, const std::vector< Label > &output, Weight weight)
bool AddWord(Label word, const std::vector< Label > &features)
bool WriteText(std::ostream &strm, const SymbolTableTextOptions &opts=SymbolTableTextOptions()) const
LinearFstData< A > * Dump()
int AddGroup(size_t future_size)
bool GetModelRecord(const std::string &model, std::istream &strm, SymbolTable *fsyms, SymbolTable *osyms, std::vector< typename Arc::Label > *input_labels, std::vector< typename Arc::Label > *output_labels, typename Arc::Weight *weight, size_t *num_line)
bool AddWord(Label word, const std::vector< Label > &features)
bool ValidateEmptySymbol()
int ScanNumClasses(char **models, int models_len)
void LinearCompile(const std::string &arc_type, const std::string &epsilon_symbol, const std::string &unknown_symbol, const std::string &vocab, char **models, int models_len, const std::string &out, const std::string &save_isymbols, const std::string &save_fsymbols, const std::string &save_osymbols)
void AddVocab(const std::string &vocab, SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms, LinearFstDataBuilder< Arc > *builder)
bool GetVocabRecord(const std::string &vocab, std::istream &strm, SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms, typename Arc::Label *word, std::vector< typename Arc::Label > *feature_labels, std::vector< typename Arc::Label > *possible_labels, size_t *num_line)
bool AddWeight(size_t group, const std::vector< Label > &input, Label pred, Weight weight)
void AddModel(const std::string &model, SymbolTable *fsyms, SymbolTable *osyms, LinearFstDataBuilder< Arc > *builder)
size_t ReplaceCopy(InputIterator first, InputIterator last, OutputIterator result, const T &old_value, const T &new_value)
void SplitAndPush(const std::string &str, const char delim, SymbolTable *syms, std::vector< typename Arc::Label > *output)
Arc::Label LookUp(const std::string &str, SymbolTable *syms)
void LinearCompileTpl(LinearCompileArgs *args)
LinearFstData< A > * Dump()