FST  openfst-1.8.2
OpenFst Library
compile-strings.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 
18 #ifndef FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_
19 #define FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_
20 
#include <libgen.h>

#include <cstdint>
#include <cstring>
#include <fstream>
#include <iostream>
#include <istream>
#include <memory>
#include <sstream>
#include <string>
#include <string_view>
#include <vector>

#include <fst/extensions/far/far.h>
#include <fstream>
#include <fst/string.h>
32 
33 namespace fst {
34 namespace internal {
35 
36 // Constructs a reader that provides FSTs from a file (stream) either on a
37 // line-by-line basis or on a per-stream basis. Note that the freshly
38 // constructed reader is already set to the first input.
39 //
40 // Sample usage:
41 //
42 // for (StringReader<Arc> reader(...); !reader.Done(); reader.Next()) {
43 // auto *fst = reader.GetVectorFst();
44 // }
45 template <class Arc>
46 class StringReader {
47  public:
48  using Label = typename Arc::Label;
49  using Weight = typename Arc::Weight;
50 
51  StringReader(std::istream &istrm, const std::string &source,
52  FarEntryType entry_type, TokenType token_type,
53  bool allow_negative_labels, const SymbolTable *syms = nullptr,
54  Label unknown_label = kNoStateId)
55  : nline_(0),
56  istrm_(istrm),
57  source_(source),
58  entry_type_(entry_type),
59  token_type_(token_type),
60  symbols_(syms),
61  done_(false),
62  compiler_(token_type, syms, unknown_label, allow_negative_labels) {
63  Next(); // Initialize the reader to the first input.
64  }
65 
66  bool Done() { return done_; }
67 
68  void Next() {
69  VLOG(1) << "Processing source " << source_ << " at line " << nline_;
70  if (!istrm_) { // We're done if we have no more input.
71  done_ = true;
72  return;
73  }
74  if (entry_type_ == FarEntryType::LINE) {
75  std::getline(istrm_, content_);
76  ++nline_;
77  } else {
78  content_.clear();
79  std::string line;
80  while (std::getline(istrm_, line)) {
81  ++nline_;
82  content_.append(line);
83  content_.append("\n");
84  }
85  }
86  if (!istrm_ && content_.empty()) // We're also done if we read off all the
87  done_ = true; // whitespace at the end of a file.
88  }
89 
90  VectorFst<Arc> *GetVectorFst(bool keep_symbols = false) {
91  std::unique_ptr<VectorFst<Arc>> fst(new VectorFst<Arc>());
92  if (keep_symbols) {
93  fst->SetInputSymbols(symbols_);
94  fst->SetOutputSymbols(symbols_);
95  }
96  if (compiler_(content_, fst.get())) {
97  return fst.release();
98  } else {
99  return nullptr;
100  }
101  }
102 
103  CompactStringFst<Arc> *GetCompactFst(bool keep_symbols = false) {
104  std::unique_ptr<CompactStringFst<Arc>> fst;
105  if (keep_symbols) {
106  VectorFst<Arc> tmp;
107  tmp.SetInputSymbols(symbols_);
108  tmp.SetOutputSymbols(symbols_);
109  fst.reset(new CompactStringFst<Arc>(tmp));
110  } else {
111  fst.reset(new CompactStringFst<Arc>());
112  }
113  if (compiler_(content_, fst.get())) {
114  return fst.release();
115  } else {
116  return nullptr;
117  }
118  }
119 
120  private:
121  size_t nline_;
122  std::istream &istrm_;
123  std::string source_;
124  FarEntryType entry_type_;
125  TokenType token_type_;
126  const SymbolTable *symbols_;
127  bool done_;
128  StringCompiler<Arc> compiler_;
129  std::string content_; // The actual content of the input stream's next FST.
130 
131  StringReader(const StringReader &) = delete;
132  StringReader &operator=(const StringReader &) = delete;
133 };
134 
// Returns how many decimal digits are needed to write the largest line number
// of `source`. CompileStrings uses this as the zero-padded width of line-based
// keys. Returns 0 if `source` is not seekable.
auto KeySize(const std::string &source) -> int;
138 
139 } // namespace internal
140 
141 template <class Arc>
142 void CompileStrings(const std::vector<std::string> &sources,
143  FarWriter<Arc> &writer, std::string_view fst_type,
144  int32_t generate_keys, FarEntryType entry_type,
145  TokenType token_type, const std::string &symbols_source,
146  const std::string &unknown_symbol, bool keep_symbols,
147  bool initial_symbols, bool allow_negative_labels,
148  const std::string &key_prefix,
149  const std::string &key_suffix) {
150  bool compact;
151  if (fst_type.empty() || (fst_type == "vector")) {
152  compact = false;
153  } else if (fst_type == "compact") {
154  compact = true;
155  } else {
156  FSTERROR() << "CompileStrings: Unknown FST type: " << fst_type;
157  return;
158  }
159  std::unique_ptr<const SymbolTable> syms;
160  typename Arc::Label unknown_label = kNoLabel;
161  if (!symbols_source.empty()) {
162  const SymbolTableTextOptions opts(allow_negative_labels);
163  syms.reset(SymbolTable::ReadText(symbols_source, opts));
164  if (!syms) {
165  LOG(ERROR) << "CompileStrings: Error reading symbol table: "
166  << symbols_source;
167  return;
168  }
169  if (!unknown_symbol.empty()) {
170  unknown_label = syms->Find(unknown_symbol);
171  if (unknown_label == kNoLabel) {
172  FSTERROR() << "CompileStrings: Label \"" << unknown_label
173  << "\" missing from symbol table: " << symbols_source;
174  return;
175  }
176  }
177  }
178  int n = 0;
179  for (const auto &in_source : sources) {
180  // Don't try to call KeySize("").
181  if (generate_keys == 0 && in_source.empty()) {
182  FSTERROR() << "CompileStrings: Read from a file instead of stdin or"
183  << " set the --generate_keys flag.";
184  return;
185  }
186  const int key_size = generate_keys ? generate_keys
187  : (entry_type == FarEntryType::FILE
188  ? 1
189  : internal::KeySize(in_source));
190  if (key_size == 0) {
191  FSTERROR() << "CompileStrings: " << in_source << " is not seekable. "
192  << "Read from a file instead or set the --generate_keys flag.";
193  return;
194  }
195  std::ifstream fstrm;
196  if (!in_source.empty()) {
197  fstrm.open(in_source);
198  if (!fstrm) {
199  FSTERROR() << "CompileStrings: Can't open file: " << in_source;
200  return;
201  }
202  }
203  std::istream &istrm = fstrm.is_open() ? fstrm : std::cin;
204  bool keep_syms = keep_symbols;
205  for (internal::StringReader<Arc> reader(
206  istrm, in_source.empty() ? "stdin" : in_source, entry_type,
207  token_type, allow_negative_labels, syms.get(), unknown_label);
208  !reader.Done(); reader.Next()) {
209  ++n;
210  std::unique_ptr<const Fst<Arc>> fst;
211  if (compact) {
212  fst.reset(reader.GetCompactFst(keep_syms));
213  } else {
214  fst.reset(reader.GetVectorFst(keep_syms));
215  }
216  if (initial_symbols) keep_syms = false;
217  if (!fst) {
218  FSTERROR() << "CompileStrings: Compiling string number " << n
219  << " in file " << in_source
220  << " failed with token_type = " << token_type
221  << " and entry_type = "
222  << (entry_type == FarEntryType::LINE
223  ? "line"
224  : (entry_type == FarEntryType::FILE ? "file"
225  : "unknown"));
226  return;
227  }
228  std::ostringstream keybuf;
229  keybuf.width(key_size);
230  keybuf.fill('0');
231  keybuf << n;
232  std::string key;
233  if (generate_keys > 0) {
234  key = keybuf.str();
235  } else {
236  auto source =
237  fst::make_unique_for_overwrite<char[]>(in_source.size() + 1);
238  strcpy(source.get(), in_source.c_str()); // NOLINT(runtime/printf)
239  key = basename(source.get());
240  if (entry_type != FarEntryType::FILE) {
241  key += "-";
242  key += keybuf.str();
243  }
244  }
245  writer.Add(key_prefix + key + key_suffix, *fst);
246  }
247  if (generate_keys == 0) n = 0;
248  }
249 }
250 
251 } // namespace fst
252 
253 #endif // FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_
constexpr int kNoLabel
Definition: fst.h:201
#define LOG(type)
Definition: log.h:49
constexpr int kNoStateId
Definition: fst.h:202
typename Arc::Label Label
VectorFst< Arc > * GetVectorFst(bool keep_symbols=false)
virtual void Add(std::string_view key, const Fst< Arc > &fst)=0
#define FSTERROR()
Definition: util.h:53
void SetOutputSymbols(const SymbolTable *osyms) override
Definition: mutable-fst.h:396
static SymbolTable * ReadText(std::istream &strm, std::string_view name, const SymbolTableTextOptions &opts=SymbolTableTextOptions())
Definition: symbol-table.h:377
void SetInputSymbols(const SymbolTable *isyms) override
Definition: mutable-fst.h:391
#define VLOG(level)
Definition: log.h:50
TokenType
Definition: string.h:47
int KeySize(const std::string &source)
CompactStringFst< Arc > * GetCompactFst(bool keep_symbols=false)
typename Arc::Weight Weight
void CompileStrings(const std::vector< std::string > &sources, FarWriter< Arc > &writer, std::string_view fst_type, int32_t generate_keys, FarEntryType entry_type, TokenType token_type, const std::string &symbols_source, const std::string &unknown_symbol, bool keep_symbols, bool initial_symbols, bool allow_negative_labels, const std::string &key_prefix, const std::string &key_suffix)
FarEntryType
Definition: far.h:41
StringReader(std::istream &istrm, const std::string &source, FarEntryType entry_type, TokenType token_type, bool allow_negative_labels, const SymbolTable *syms=nullptr, Label unknown_label=kNoStateId)