FST  openfst-1.7.2
OpenFst Library
compile-strings.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 
4 #ifndef FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_
5 #define FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_
6 
#include <libgen.h>

#include <fstream>
#include <iostream>
#include <istream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include <fst/extensions/far/far.h>
#include <fstream>
#include <fst/string.h>
17 
18 namespace fst {
19 
20 // Constructs a reader that provides FSTs from a file (stream) either on a
21 // line-by-line basis or on a per-stream basis. Note that the freshly
22 // constructed reader is already set to the first input.
23 //
24 // Sample usage:
25 //
26 // for (StringReader<Arc> reader(...); !reader.Done(); reader.Next()) {
27 // auto *fst = reader.GetVectorFst();
28 // }
29 template <class Arc>
30 class StringReader {
31  public:
32  using Label = typename Arc::Label;
33  using Weight = typename Arc::Weight;
34 
35  enum EntryType { LINE = 1, FILE = 2 };
36 
37  StringReader(std::istream &istrm, const string &source, EntryType entry_type,
38  StringTokenType token_type, bool allow_negative_labels,
39  const SymbolTable *syms = nullptr,
40  Label unknown_label = kNoStateId)
41  : nline_(0),
42  istrm_(istrm),
43  source_(source),
44  entry_type_(entry_type),
45  token_type_(token_type),
46  symbols_(syms),
47  done_(false),
48  compiler_(token_type, syms, unknown_label, allow_negative_labels) {
49  Next(); // Initialize the reader to the first input.
50  }
51 
52  bool Done() { return done_; }
53 
54  void Next() {
55  VLOG(1) << "Processing source " << source_ << " at line " << nline_;
56  if (!istrm_) { // We're done if we have no more input.
57  done_ = true;
58  return;
59  }
60  if (entry_type_ == LINE) {
61  getline(istrm_, content_);
62  ++nline_;
63  } else {
64  content_.clear();
65  string line;
66  while (getline(istrm_, line)) {
67  ++nline_;
68  content_.append(line);
69  content_.append("\n");
70  }
71  }
72  if (!istrm_ && content_.empty()) // We're also done if we read off all the
73  done_ = true; // whitespace at the end of a file.
74  }
75 
76  VectorFst<Arc> *GetVectorFst(bool keep_symbols = false) {
77  std::unique_ptr<VectorFst<Arc>> fst(new VectorFst<Arc>());
78  if (keep_symbols) {
79  fst->SetInputSymbols(symbols_);
80  fst->SetOutputSymbols(symbols_);
81  }
82  if (compiler_(content_, fst.get())) {
83  return fst.release();
84  } else {
85  return nullptr;
86  }
87  }
88 
89  CompactStringFst<Arc> *GetCompactFst(bool keep_symbols = false) {
90  std::unique_ptr<CompactStringFst<Arc>> fst;
91  if (keep_symbols) {
92  VectorFst<Arc> tmp;
93  tmp.SetInputSymbols(symbols_);
94  tmp.SetOutputSymbols(symbols_);
95  fst.reset(new CompactStringFst<Arc>(tmp));
96  } else {
97  fst.reset(new CompactStringFst<Arc>());
98  }
99  if (compiler_(content_, fst.get())) {
100  return fst.release();
101  } else {
102  return nullptr;
103  }
104  }
105 
106  private:
107  size_t nline_;
108  std::istream &istrm_;
109  string source_;
110  EntryType entry_type_;
111  StringTokenType token_type_;
112  const SymbolTable *symbols_;
113  bool done_;
114  StringCompiler<Arc> compiler_;
115  string content_; // The actual content of the input stream's next FST.
116 
117  StringReader(const StringReader &) = delete;
118  StringReader &operator=(const StringReader &) = delete;
119 };
120 
// Returns the minimal number of decimal digits that can encode every line
// number of the file `filename` (i.e. the digit count of its largest line
// number). Defined out of line in strings.cc; FarCompileStrings uses it as the
// zero-padded width of per-line keys.
extern int KeySize(const char *filename);
124 
125 template <class Arc>
126 void FarCompileStrings(const std::vector<string> &in_fnames,
127  const string &out_fname, const string &fst_type,
128  const FarType &far_type, int32 generate_keys,
129  FarEntryType fet, FarTokenType tt,
130  const string &symbols_fname,
131  const string &unknown_symbol, bool keep_symbols,
132  bool initial_symbols, bool allow_negative_labels,
133  const string &key_prefix, const string &key_suffix) {
134  typename StringReader<Arc>::EntryType entry_type;
135  if (fet == FET_LINE) {
136  entry_type = StringReader<Arc>::LINE;
137  } else if (fet == FET_FILE) {
138  entry_type = StringReader<Arc>::FILE;
139  } else {
140  FSTERROR() << "FarCompileStrings: Unknown entry type";
141  return;
142  }
143  StringTokenType token_type;
144  if (tt == FTT_SYMBOL) {
145  token_type = StringTokenType::SYMBOL;
146  } else if (tt == FTT_BYTE) {
147  token_type = StringTokenType::BYTE;
148  } else if (tt == FTT_UTF8) {
149  token_type = StringTokenType::UTF8;
150  } else {
151  FSTERROR() << "FarCompileStrings: Unknown token type";
152  return;
153  }
154  bool compact;
155  if (fst_type.empty() || (fst_type == "vector")) {
156  compact = false;
157  } else if (fst_type == "compact") {
158  compact = true;
159  } else {
160  FSTERROR() << "FarCompileStrings: Unknown FST type: " << fst_type;
161  return;
162  }
163  std::unique_ptr<const SymbolTable> syms;
164  typename Arc::Label unknown_label = kNoLabel;
165  if (!symbols_fname.empty()) {
166  const SymbolTableTextOptions opts(allow_negative_labels);
167  syms.reset(SymbolTable::ReadText(symbols_fname, opts));
168  if (!syms) {
169  LOG(ERROR) << "FarCompileStrings: Error reading symbol table: "
170  << symbols_fname;
171  return;
172  }
173  if (!unknown_symbol.empty()) {
174  unknown_label = syms->Find(unknown_symbol);
175  if (unknown_label == kNoLabel) {
176  FSTERROR() << "FarCompileStrings: Label \"" << unknown_label
177  << "\" missing from symbol table: " << symbols_fname;
178  return;
179  }
180  }
181  }
182  std::unique_ptr<FarWriter<Arc>> far_writer(
183  FarWriter<Arc>::Create(out_fname, far_type));
184  if (!far_writer) return;
185  int n = 0;
186  for (const auto &in_fname : in_fnames) {
187  if (generate_keys == 0 && in_fname.empty()) {
188  FSTERROR() << "FarCompileStrings: Read from a file instead of stdin or"
189  << " set the --generate_keys flags.";
190  return;
191  }
192  int key_size =
193  generate_keys ? generate_keys : (entry_type == StringReader<Arc>::FILE
194  ? 1 : KeySize(in_fname.c_str()));
195  std::ifstream fstrm;
196  if (!in_fname.empty()) {
197  fstrm.open(in_fname);
198  if (!fstrm) {
199  FSTERROR() << "FarCompileStrings: Can't open file: " << in_fname;
200  return;
201  }
202  }
203  std::istream &istrm = fstrm.is_open() ? fstrm : std::cin;
204  bool keep_syms = keep_symbols;
205  for (StringReader<Arc> reader(
206  istrm, in_fname.empty() ? "stdin" : in_fname, entry_type,
207  token_type, allow_negative_labels, syms.get(), unknown_label);
208  !reader.Done(); reader.Next()) {
209  ++n;
210  std::unique_ptr<const Fst<Arc>> fst;
211  if (compact) {
212  fst.reset(reader.GetCompactFst(keep_syms));
213  } else {
214  fst.reset(reader.GetVectorFst(keep_syms));
215  }
216  if (initial_symbols) keep_syms = false;
217  if (!fst) {
218  FSTERROR() << "FarCompileStrings: Compiling string number " << n
219  << " in file " << in_fname << " failed with token_type = "
220  << (tt == FTT_BYTE
221  ? "byte"
222  : (tt == FTT_UTF8
223  ? "utf8"
224  : (tt == FTT_SYMBOL ? "symbol" : "unknown")))
225  << " and entry_type = "
226  << (fet == FET_LINE
227  ? "line"
228  : (fet == FET_FILE ? "file" : "unknown"));
229  return;
230  }
231  std::ostringstream keybuf;
232  keybuf.width(key_size);
233  keybuf.fill('0');
234  keybuf << n;
235  string key;
236  if (generate_keys > 0) {
237  key = keybuf.str();
238  } else {
239  auto *filename = new char[in_fname.size() + 1];
240  strcpy(filename, in_fname.c_str());
241  key = basename(filename);
242  if (entry_type != StringReader<Arc>::FILE) {
243  key += "-";
244  key += keybuf.str();
245  }
246  delete[] filename;
247  }
248  far_writer->Add(key_prefix + key + key_suffix, *fst);
249  }
250  if (generate_keys == 0) n = 0;
251  }
252 }
253 
254 } // namespace fst
255 
256 #endif // FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_
constexpr int kNoLabel
Definition: fst.h:179
StringTokenType
Definition: string.h:27
int KeySize(const char *filename)
Definition: strings.cc:18
#define LOG(type)
Definition: log.h:48
constexpr int kNoStateId
Definition: fst.h:180
FarTokenType
Definition: far.h:23
typename Arc::Weight Weight
#define FSTERROR()
Definition: util.h:35
CompactStringFst< Arc > * GetCompactFst(bool keep_symbols=false)
void SetOutputSymbols(const SymbolTable *osyms) override
Definition: mutable-fst.h:373
FarType
Definition: far.h:71
void SetInputSymbols(const SymbolTable *isyms) override
Definition: mutable-fst.h:368
typename Arc::Label Label
void FarCompileStrings(const std::vector< string > &in_fnames, const string &out_fname, const string &fst_type, const FarType &far_type, int32 generate_keys, FarEntryType fet, FarTokenType tt, const string &symbols_fname, const string &unknown_symbol, bool keep_symbols, bool initial_symbols, bool allow_negative_labels, const string &key_prefix, const string &key_suffix)
#define VLOG(level)
Definition: log.h:49
VectorFst< Arc > * GetVectorFst(bool keep_symbols=false)
int32_t int32
Definition: types.h:26
StringReader(std::istream &istrm, const string &source, EntryType entry_type, StringTokenType token_type, bool allow_negative_labels, const SymbolTable *syms=nullptr, Label unknown_label=kNoStateId)
FarEntryType
Definition: far.h:21
static SymbolTable * ReadText(std::istream &strm, const string &name, const SymbolTableTextOptions &opts=SymbolTableTextOptions())
Definition: symbol-table.h:218