FST  openfst-1.7.5
OpenFst Library
symbol-table.cc
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Classes to provide symbol-to-integer and integer-to-symbol mappings.
5 
6 #include <fst/symbol-table.h>
7 
8 #include <fst/flags.h>
9 #include <fst/types.h>
10 #include <fst/log.h>
11 
12 #include <fstream>
13 #include <fst/util.h>
14 
15 DEFINE_bool(fst_compat_symbols, true,
16  "Require symbol tables to match when appropriate");
17 DEFINE_string(fst_field_separator, "\t ",
18  "Set of characters used as a separator between printed fields");
19 
20 namespace fst {
21 
23  : allow_negative_labels(allow_negative_labels),
24  fst_field_separator(FLAGS_fst_field_separator) {}
25 
26 namespace internal {
27 
28 // Maximum line length in textual symbols file.
29 const int kLineLen = 8096;
30 
31 // Identifies stream data as a symbol table (and its endianity).
32 static constexpr int32 kSymbolTableMagicNumber = 2125658996;
33 
34 constexpr int64 DenseSymbolMap::kEmptyBucket;
35 
36 DenseSymbolMap::DenseSymbolMap()
37  : str_hash_(),
38  buckets_(1 << 4, kEmptyBucket),
39  hash_mask_(buckets_.size() - 1) {}
40 
41 std::pair<int64, bool> DenseSymbolMap::InsertOrFind(KeyType key) {
42  static constexpr float kMaxOccupancyRatio = 0.75; // Grows when 75% full.
43  if (Size() >= kMaxOccupancyRatio * buckets_.size()) {
44  Rehash(buckets_.size() * 2);
45  }
46  size_t idx = GetHash(key);
47  while (buckets_[idx] != kEmptyBucket) {
48  const auto stored_value = buckets_[idx];
49  if (symbols_[stored_value] == key) return {stored_value, false};
50  idx = (idx + 1) & hash_mask_;
51  }
52  const auto next = Size();
53  buckets_[idx] = next;
54  symbols_.emplace_back(key);
55  return {next, true};
56 }
57 
59  size_t idx = str_hash_(key) & hash_mask_;
60  while (buckets_[idx] != kEmptyBucket) {
61  const auto stored_value = buckets_[idx];
62  if (symbols_[stored_value] == key) return stored_value;
63  idx = (idx + 1) & hash_mask_;
64  }
65  return buckets_[idx];
66 }
67 
68 void DenseSymbolMap::Rehash(size_t num_buckets) {
69  buckets_.resize(num_buckets);
70  hash_mask_ = buckets_.size() - 1;
71  std::fill(buckets_.begin(), buckets_.end(), kEmptyBucket);
72  for (size_t i = 0; i < Size(); ++i) {
73  size_t idx = GetHash(symbols_[i]);
74  while (buckets_[idx] != kEmptyBucket) {
75  idx = (idx + 1) & hash_mask_;
76  }
77  buckets_[idx] = i;
78  }
79 }
80 
81 void DenseSymbolMap::RemoveSymbol(size_t idx) {
82  symbols_.erase(symbols_.begin() + idx);
83  Rehash(buckets_.size());
84 }
85 
87  symbols_.shrink_to_fit();
88 }
89 
91  for (SymbolTableIterator iter(table); !iter.Done(); iter.Next()) {
92  AddSymbol(iter.Symbol());
93  }
94 }
95 
96 std::unique_ptr<SymbolTableImplBase> ConstSymbolTableImpl::Copy() const {
97  LOG(FATAL) << "ConstSymbolTableImpl can't be copied";
98  return nullptr;
99 }
100 
102  LOG(FATAL) << "ConstSymbolTableImpl does not support AddSymbol";
103  return kNoSymbol;
104 }
105 
107  return AddSymbol(symbol, kNoSymbol);
108 }
109 
111  LOG(FATAL) << "ConstSymbolTableImpl does not support RemoveSymbol";
112 }
113 
114 void ConstSymbolTableImpl::SetName(const std::string &new_name) {
115  LOG(FATAL) << "ConstSymbolTableImpl does not support SetName";
116 }
117 
119  LOG(FATAL) << "ConstSymbolTableImpl does not support AddTable";
120 }
121 
123  const std::string &source,
124  const SymbolTableTextOptions &opts) {
125  std::unique_ptr<SymbolTableImpl> impl(new SymbolTableImpl(source));
126  int64 nline = 0;
127  char line[kLineLen];
128  while (!strm.getline(line, kLineLen).fail()) {
129  ++nline;
130  std::vector<char *> col;
131  const auto separator = opts.fst_field_separator + "\n";
132  SplitString(line, separator.c_str(), &col, true);
133  if (col.empty()) continue; // Empty line.
134  if (col.size() != 2) {
135  LOG(ERROR) << "SymbolTable::ReadText: Bad number of columns ("
136  << col.size() << "), "
137  << "file = " << source << ", line = " << nline << ":<" << line
138  << ">";
139  return nullptr;
140  }
141  const char *symbol = col[0];
142  const char *value = col[1];
143  char *p;
144  const auto key = strtoll(value, &p, 10);
145  if (p < value + strlen(value) || (!opts.allow_negative_labels && key < 0) ||
146  key == kNoSymbol) {
147  LOG(ERROR) << "SymbolTable::ReadText: Bad non-negative integer \""
148  << value << "\", "
149  << "file = " << source << ", line = " << nline;
150  return nullptr;
151  }
152  impl->AddSymbol(symbol, key);
153  }
154  impl->ShrinkToFit();
155  return impl.release();
156 }
157 
158 void SymbolTableImpl::MaybeRecomputeCheckSum() const {
159  {
160  ReaderMutexLock check_sum_lock(&check_sum_mutex_);
161  if (check_sum_finalized_) return;
162  }
163  // We'll acquire an exclusive lock to recompute the checksums.
164  MutexLock check_sum_lock(&check_sum_mutex_);
165  if (check_sum_finalized_) { // Another thread (coming in around the same time
166  return; // might have done it already). So we recheck.
167  }
168  // Calculates the original label-agnostic checksum.
169  CheckSummer check_sum;
170  for (size_t i = 0; i < symbols_.Size(); ++i) {
171  const auto &symbol = symbols_.GetSymbol(i);
172  check_sum.Update(symbol.data(), symbol.size());
173  check_sum.Update("", 1);
174  }
175  check_sum_string_ = check_sum.Digest();
176  // Calculates the safer, label-dependent checksum.
177  CheckSummer labeled_check_sum;
178  for (int64 i = 0; i < dense_key_limit_; ++i) {
179  std::ostringstream line;
180  line << symbols_.GetSymbol(i) << '\t' << i;
181  labeled_check_sum.Update(line.str().data(), line.str().size());
182  }
183  using citer = std::map<int64, int64>::const_iterator;
184  for (citer it = key_map_.begin(); it != key_map_.end(); ++it) {
185  // TODO(tombagby, 2013-11-22) This line maintains a bug that ignores
186  // negative labels in the checksum that too many tests rely on.
187  if (it->first < dense_key_limit_) continue;
188  std::ostringstream line;
189  line << symbols_.GetSymbol(it->second) << '\t' << it->first;
190  labeled_check_sum.Update(line.str().data(), line.str().size());
191  }
192  labeled_check_sum_string_ = labeled_check_sum.Digest();
193  check_sum_finalized_ = true;
194 }
195 
196 std::string SymbolTableImpl::Find(int64 key) const {
197  int64 idx = key;
198  if (key < 0 || key >= dense_key_limit_) {
199  const auto it = key_map_.find(key);
200  if (it == key_map_.end()) return "";
201  idx = it->second;
202  }
203  if (idx < 0 || idx >= symbols_.Size()) return "";
204  return symbols_.GetSymbol(idx);
205 }
206 
208  if (key == kNoSymbol) return key;
209  const auto insert_key = symbols_.InsertOrFind(symbol);
210  if (!insert_key.second) {
211  const auto key_already = GetNthKey(insert_key.first);
212  if (key_already == key) return key;
213  VLOG(1) << "SymbolTable::AddSymbol: symbol = " << symbol
214  << " already in symbol_map_ with key = " << key_already
215  << " but supplied new key = " << key << " (ignoring new key)";
216  return key_already;
217  }
218  if (key + 1 == static_cast<int64>(symbols_.Size()) &&
219  key == dense_key_limit_) {
220  ++dense_key_limit_;
221  } else {
222  idx_key_.push_back(key);
223  key_map_[key] = symbols_.Size() - 1;
224  }
225  if (key >= available_key_) available_key_ = key + 1;
226  check_sum_finalized_ = false;
227  return key;
228 }
229 
230 // TODO(rybach): Consider a more efficient implementation which re-uses holes in
231 // the dense-key range or re-arranges the dense-key range from time to time.
233  auto idx = key;
234  if (key < 0 || key >= dense_key_limit_) {
235  auto iter = key_map_.find(key);
236  if (iter == key_map_.end()) return;
237  idx = iter->second;
238  key_map_.erase(iter);
239  }
240  if (idx < 0 || idx >= static_cast<int64>(symbols_.Size())) return;
241  symbols_.RemoveSymbol(idx);
242  // Removed one symbol, all indexes > idx are shifted by -1.
243  for (auto &k : key_map_) {
244  if (k.second > idx) --k.second;
245  }
246  if (key >= 0 && key < dense_key_limit_) {
247  // Removal puts a hole in the dense key range. Adjusts range to [0, key).
248  const auto new_dense_key_limit = key;
249  for (int64 i = key + 1; i < dense_key_limit_; ++i) {
250  key_map_[i] = i - 1;
251  }
252  // Moves existing values in idx_key to new place.
253  idx_key_.resize(symbols_.Size() - new_dense_key_limit);
254  for (int64 i = symbols_.Size(); i >= dense_key_limit_; --i) {
255  idx_key_[i - new_dense_key_limit - 1] = idx_key_[i - dense_key_limit_];
256  }
257  // Adds indexes for previously dense keys.
258  for (int64 i = new_dense_key_limit; i < dense_key_limit_ - 1; ++i) {
259  idx_key_[i - new_dense_key_limit] = i + 1;
260  }
261  dense_key_limit_ = new_dense_key_limit;
262  } else {
263  // Remove entry for removed index in idx_key.
264  for (size_t i = idx - dense_key_limit_; i + 1 < idx_key_.size(); ++i) {
265  idx_key_[i] = idx_key_[i + 1];
266  }
267  idx_key_.pop_back();
268  }
269  if (key == available_key_ - 1) available_key_ = key;
270 }
271 
273  const SymbolTableReadOptions &) {
274  int32 magic_number = 0;
275  ReadType(strm, &magic_number);
276  if (strm.fail()) {
277  LOG(ERROR) << "SymbolTable::Read: Read failed";
278  return nullptr;
279  }
280  std::string name;
281  ReadType(strm, &name);
282  std::unique_ptr<SymbolTableImpl> impl(new SymbolTableImpl(name));
283  ReadType(strm, &impl->available_key_);
284  int64 size;
285  ReadType(strm, &size);
286  if (strm.fail()) {
287  LOG(ERROR) << "SymbolTable::Read: Read failed";
288  return nullptr;
289  }
290  std::string symbol;
291  int64 key;
292  impl->check_sum_finalized_ = false;
293  for (int64 i = 0; i < size; ++i) {
294  ReadType(strm, &symbol);
295  ReadType(strm, &key);
296  if (strm.fail()) {
297  LOG(ERROR) << "SymbolTable::Read: Read failed";
298  return nullptr;
299  }
300  impl->AddSymbol(symbol, key);
301  }
302  impl->ShrinkToFit();
303  return impl.release();
304 }
305 
306 bool SymbolTableImpl::Write(std::ostream &strm) const {
307  WriteType(strm, kSymbolTableMagicNumber);
308  WriteType(strm, name_);
309  WriteType(strm, available_key_);
310  const int64 size = symbols_.Size();
311  WriteType(strm, size);
312  for (int64 i = 0; i < dense_key_limit_; ++i) {
313  WriteType(strm, symbols_.GetSymbol(i));
314  WriteType(strm, i);
315  }
316  for (const auto &p : key_map_) {
317  WriteType(strm, symbols_.GetSymbol(p.second));
318  WriteType(strm, p.first);
319  }
320  strm.flush();
321  if (strm.fail()) {
322  LOG(ERROR) << "SymbolTable::Write: Write failed";
323  return false;
324  }
325  return true;
326 }
327 
329  symbols_.ShrinkToFit();
330 }
331 
332 } // namespace internal
333 
334 SymbolTable *SymbolTable::ReadText(const std::string &source,
335  const SymbolTableTextOptions &opts) {
336  std::ifstream strm(source, std::ios_base::in);
337  if (!strm.good()) {
338  LOG(ERROR) << "SymbolTable::ReadText: Can't open file: " << source;
339  return nullptr;
340  }
341  return ReadText(strm, source, opts);
342 }
343 
344 bool SymbolTable::Write(const std::string &source) const {
345  if (!source.empty()) {
346  std::ofstream strm(source,
347  std::ios_base::out | std::ios_base::binary);
348  if (!strm) {
349  LOG(ERROR) << "SymbolTable::Write: Can't open file: " << source;
350  return false;
351  }
352  if (!Write(strm)) {
353  LOG(ERROR) << "SymbolTable::Write: Write failed: " << source;
354  return false;
355  }
356  return true;
357  } else {
358  return Write(std::cout);
359  }
360 }
361 
362 bool SymbolTable::WriteText(std::ostream &strm,
363  const SymbolTableTextOptions &opts) const {
364  if (opts.fst_field_separator.empty()) {
365  LOG(ERROR) << "Missing required field separator";
366  return false;
367  }
368  bool once_only = false;
369  for (SymbolTableIterator iter(*this); !iter.Done(); iter.Next()) {
370  std::ostringstream line;
371  if (iter.Value() < 0 && !opts.allow_negative_labels && !once_only) {
372  LOG(WARNING) << "Negative symbol table entry when not allowed";
373  once_only = true;
374  }
375  line << iter.Symbol() << opts.fst_field_separator[0] << iter.Value()
376  << '\n';
377  strm.write(line.str().data(), line.str().length());
378  }
379  return true;
380 }
381 
382 bool SymbolTable::WriteText(const std::string &source) const {
383  if (!source.empty()) {
384  std::ofstream strm(source);
385  if (!strm) {
386  LOG(ERROR) << "SymbolTable::WriteText: Can't open file: " << source;
387  return false;
388  }
389  if (!WriteText(strm, SymbolTableTextOptions())) {
390  LOG(ERROR) << "SymbolTable::WriteText: Write failed: " << source;
391  return false;
392  }
393  return true;
394  } else {
395  return WriteText(std::cout, SymbolTableTextOptions());
396  }
397 }
398 
400  return SymbolTable::const_iterator(*this, 0);
401 }
402 
404  return SymbolTable::const_iterator(*this, this->NumSymbols());
405 }
406 
408 
410 
411 bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
412  bool warning) {
413  // Flag can explicitly override this check.
414  if (!FLAGS_fst_compat_symbols) return true;
415  if (syms1 && syms2 &&
416  (syms1->LabeledCheckSum() != syms2->LabeledCheckSum())) {
417  if (warning) {
418  LOG(WARNING) << "CompatSymbols: Symbol table checksums do not match. "
419  << "Table sizes are " << syms1->NumSymbols() << " and "
420  << syms2->NumSymbols();
421  }
422  return false;
423  } else {
424  return true;
425  }
426 }
427 
428 void SymbolTableToString(const SymbolTable *table, std::string *result) {
429  std::ostringstream ostrm;
430  table->Write(ostrm);
431  *result = ostrm.str();
432 }
433 
434 SymbolTable *StringToSymbolTable(const std::string &str) {
435  std::istringstream istrm(str);
436  return SymbolTable::Read(istrm, SymbolTableReadOptions());
437 }
438 
439 } // namespace fst
bool Write(std::ostream &strm) const
Definition: symbol-table.h:432
const_iterator end() const
DenseSymbolMap::KeyType SymbolType
Definition: symbol-table.h:123
void SymbolTableToString(const SymbolTable *table, std::string *result)
iterator const_iterator
Definition: symbol-table.h:311
std::pair< int64, bool > InsertOrFind(KeyType key)
Definition: symbol-table.cc:41
DEFINE_bool(fst_compat_symbols, true,"Require symbol tables to match when appropriate")
#define LOG(type)
Definition: log.h:46
SymbolTable * StringToSymbolTable(const std::string &str)
bool WriteText(std::ostream &strm, const SymbolTableTextOptions &opts=SymbolTableTextOptions()) const
int64 AddSymbol(SymbolType symbol, int64 key) final
void SetName(const std::string &new_name) final
int64 Find(KeyType key) const
Definition: symbol-table.cc:58
size_t NumSymbols() const
Definition: symbol-table.h:419
int64_t int64
Definition: types.h:27
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:162
void RemoveSymbol(int64 key) override
constexpr int64 kNoSymbol
Definition: symbol-table.h:46
void SplitString(char *line, const char *delim, std::vector< char * > *vec, bool omit_empty_strings)
Definition: util.cc:25
std::unique_ptr< SymbolTableImplBase > Copy() const final
Definition: symbol-table.cc:96
bool Write(std::ostream &strm) const override
const_iterator cbegin() const
static SymbolTable * Read(std::istream &strm, const SymbolTableReadOptions &opts)
Definition: symbol-table.h:336
void AddTable(const SymbolTable &table) override
Definition: symbol-table.cc:90
void Update(void const *data, int size)
Definition: compat.cc:38
#define VLOG(level)
Definition: log.h:47
SymbolTableTextOptions(bool allow_negative_labels=false)
Definition: symbol-table.cc:22
const std::string & KeyType
Definition: symbol-table.h:85
const int kLineLen
Definition: symbol-table.cc:29
std::string Find(int64 key) const override
static SymbolTableImpl * Read(std::istream &strm, const SymbolTableReadOptions &opts)
const_iterator cend() const
int32_t int32
Definition: types.h:26
static SymbolTable * ReadText(std::istream &strm, const std::string &name, const SymbolTableTextOptions &opts=SymbolTableTextOptions())
Definition: symbol-table.h:321
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:48
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
std::string Digest()
Definition: compat.h:75
const std::string & LabeledCheckSum() const
Definition: symbol-table.h:406
void AddTable(const SymbolTable &table) final
const_iterator begin() const
int64 AddSymbol(SymbolType symbol, int64 key) override
void RemoveSymbol(int64 key) final
static SymbolTableImpl * ReadText(std::istream &strm, const std::string &name, const SymbolTableTextOptions &opts=SymbolTableTextOptions())
DEFINE_string(fst_field_separator,"\t ","Set of characters used as a separator between printed fields")
void RemoveSymbol(size_t idx)
Definition: symbol-table.cc:81