FST  openfst-1.7.1
OpenFst Library
symbol-table.cc
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Classes to provide symbol-to-integer and integer-to-symbol mappings.
5 
6 #include <fst/symbol-table.h>
7 
8 #include <fst/flags.h>
9 #include <fst/log.h>
10 
11 #include <fstream>
12 #include <fst/util.h>
13 
14 DEFINE_bool(fst_compat_symbols, true,
15  "Require symbol tables to match when appropriate");
16 DEFINE_string(fst_field_separator, "\t ",
17  "Set of characters used as a separator between printed fields");
18 
19 namespace fst {
20 
22  : allow_negative_labels(allow_negative_labels),
23  fst_field_separator(FLAGS_fst_field_separator) {}
24 
25 namespace internal {
26 
// Size of the line buffer used when reading a textual symbols file; this
// bounds the maximum length of a line in such a file.
constexpr int kLineLen = 8096;
29 
30 // Identifies stream data as a symbol table (and its endianity).
31 static constexpr int32 kSymbolTableMagicNumber = 2125658996;
32 
33 DenseSymbolMap::DenseSymbolMap()
34  : empty_(-1), buckets_(1 << 4), hash_mask_(buckets_.size() - 1) {
35  std::uninitialized_fill(buckets_.begin(), buckets_.end(), empty_);
36 }
37 
39  : empty_(-1),
40  symbols_(other.symbols_),
41  buckets_(other.buckets_),
42  hash_mask_(other.hash_mask_) {}
43 
44 std::pair<int64, bool> DenseSymbolMap::InsertOrFind(const string &key) {
45  static constexpr float kMaxOccupancyRatio = 0.75; // Grows when 75% full.
46  if (Size() >= kMaxOccupancyRatio * buckets_.size()) {
47  Rehash(buckets_.size() * 2);
48  }
49  size_t idx = str_hash_(key) & hash_mask_;
50  while (buckets_[idx] != empty_) {
51  const auto stored_value = buckets_[idx];
52  if (symbols_[stored_value] == key) return {stored_value, false};
53  idx = (idx + 1) & hash_mask_;
54  }
55  const auto next = Size();
56  buckets_[idx] = next;
57  symbols_.emplace_back(key);
58  return {next, true};
59 }
60 
61 int64 DenseSymbolMap::Find(const string &key) const {
62  size_t idx = str_hash_(key) & hash_mask_;
63  while (buckets_[idx] != empty_) {
64  const auto stored_value = buckets_[idx];
65  if (symbols_[stored_value] == key) return stored_value;
66  idx = (idx + 1) & hash_mask_;
67  }
68  return buckets_[idx];
69 }
70 
71 void DenseSymbolMap::Rehash(size_t num_buckets) {
72  buckets_.resize(num_buckets);
73  hash_mask_ = buckets_.size() - 1;
74  std::uninitialized_fill(buckets_.begin(), buckets_.end(), empty_);
75  for (size_t i = 0; i < Size(); ++i) {
76  size_t idx = str_hash_(string(symbols_[i])) & hash_mask_;
77  while (buckets_[idx] != empty_) {
78  idx = (idx + 1) & hash_mask_;
79  }
80  buckets_[idx] = i;
81  }
82 }
83 
84 void DenseSymbolMap::RemoveSymbol(size_t idx) {
85  symbols_.erase(symbols_.begin() + idx);
86  Rehash(buckets_.size());
87 }
88 
90  const string &filename,
91  const SymbolTableTextOptions &opts) {
92  std::unique_ptr<SymbolTableImpl> impl(new SymbolTableImpl(filename));
93  int64 nline = 0;
94  char line[kLineLen];
95  while (!strm.getline(line, kLineLen).fail()) {
96  ++nline;
97  std::vector<char *> col;
98  const auto separator = opts.fst_field_separator + "\n";
99  SplitString(line, separator.c_str(), &col, true);
100  if (col.empty()) continue; // Empty line.
101  if (col.size() != 2) {
102  LOG(ERROR) << "SymbolTable::ReadText: Bad number of columns ("
103  << col.size() << "), "
104  << "file = " << filename << ", line = " << nline << ":<"
105  << line << ">";
106  return nullptr;
107  }
108  const char *symbol = col[0];
109  const char *value = col[1];
110  char *p;
111  const auto key = strtoll(value, &p, 10);
112  if (p < value + strlen(value) || (!opts.allow_negative_labels && key < 0) ||
113  key == kNoSymbol) {
114  LOG(ERROR) << "SymbolTable::ReadText: Bad non-negative integer \""
115  << value << "\", "
116  << "file = " << filename << ", line = " << nline;
117  return nullptr;
118  }
119  impl->AddSymbol(symbol, key);
120  }
121  return impl.release();
122 }
123 
124 void SymbolTableImpl::MaybeRecomputeCheckSum() const {
125  {
126  ReaderMutexLock check_sum_lock(&check_sum_mutex_);
127  if (check_sum_finalized_) return;
128  }
129  // We'll acquire an exclusive lock to recompute the checksums.
130  MutexLock check_sum_lock(&check_sum_mutex_);
131  if (check_sum_finalized_) { // Another thread (coming in around the same time
132  return; // might have done it already). So we recheck.
133  }
134  // Calculates the original label-agnostic checksum.
135  CheckSummer check_sum;
136  for (size_t i = 0; i < symbols_.Size(); ++i) {
137  const auto &symbol = symbols_.GetSymbol(i);
138  check_sum.Update(symbol.data(), symbol.size());
139  check_sum.Update("", 1);
140  }
141  check_sum_string_ = check_sum.Digest();
142  // Calculates the safer, label-dependent checksum.
143  CheckSummer labeled_check_sum;
144  for (int64 i = 0; i < dense_key_limit_; ++i) {
145  std::ostringstream line;
146  line << symbols_.GetSymbol(i) << '\t' << i;
147  labeled_check_sum.Update(line.str().data(), line.str().size());
148  }
149  using citer = std::map<int64, int64>::const_iterator;
150  for (citer it = key_map_.begin(); it != key_map_.end(); ++it) {
151  // TODO(tombagby, 2013-11-22) This line maintains a bug that ignores
152  // negative labels in the checksum that too many tests rely on.
153  if (it->first < dense_key_limit_) continue;
154  std::ostringstream line;
155  line << symbols_.GetSymbol(it->second) << '\t' << it->first;
156  labeled_check_sum.Update(line.str().data(), line.str().size());
157  }
158  labeled_check_sum_string_ = labeled_check_sum.Digest();
159  check_sum_finalized_ = true;
160 }
161 
162 int64 SymbolTableImpl::AddSymbol(const string &symbol, int64 key) {
163  if (key == kNoSymbol) return key;
164  const auto insert_key = symbols_.InsertOrFind(symbol);
165  if (!insert_key.second) {
166  const auto key_already = GetNthKey(insert_key.first);
167  if (key_already == key) return key;
168  VLOG(1) << "SymbolTable::AddSymbol: symbol = " << symbol
169  << " already in symbol_map_ with key = " << key_already
170  << " but supplied new key = " << key << " (ignoring new key)";
171  return key_already;
172  }
173  if (key == (symbols_.Size() - 1) && key == dense_key_limit_) {
174  ++dense_key_limit_;
175  } else {
176  idx_key_.push_back(key);
177  key_map_[key] = symbols_.Size() - 1;
178  }
179  if (key >= available_key_) available_key_ = key + 1;
180  check_sum_finalized_ = false;
181  return key;
182 }
183 
184 // TODO(rybach): Consider a more efficient implementation which re-uses holes in
185 // the dense-key range or re-arranges the dense-key range from time to time.
187  auto idx = key;
188  if (key < 0 || key >= dense_key_limit_) {
189  auto iter = key_map_.find(key);
190  if (iter == key_map_.end()) return;
191  idx = iter->second;
192  key_map_.erase(iter);
193  }
194  if (idx < 0 || idx >= symbols_.Size()) return;
195  symbols_.RemoveSymbol(idx);
196  // Removed one symbol, all indexes > idx are shifted by -1.
197  for (auto &k : key_map_) {
198  if (k.second > idx) --k.second;
199  }
200  if (key >= 0 && key < dense_key_limit_) {
201  // Removal puts a hole in the dense key range. Adjusts range to [0, key).
202  const auto new_dense_key_limit = key;
203  for (int64 i = key + 1; i < dense_key_limit_; ++i) {
204  key_map_[i] = i - 1;
205  }
206  // Moves existing values in idx_key to new place.
207  idx_key_.resize(symbols_.Size() - new_dense_key_limit);
208  for (int64 i = symbols_.Size(); i >= dense_key_limit_; --i) {
209  idx_key_[i - new_dense_key_limit - 1] = idx_key_[i - dense_key_limit_];
210  }
211  // Adds indexes for previously dense keys.
212  for (int64 i = new_dense_key_limit; i < dense_key_limit_ - 1; ++i) {
213  idx_key_[i - new_dense_key_limit] = i + 1;
214  }
215  dense_key_limit_ = new_dense_key_limit;
216  } else {
217  // Remove entry for removed index in idx_key.
218  for (int64 i = idx - dense_key_limit_; i < idx_key_.size() - 1; ++i) {
219  idx_key_[i] = idx_key_[i + 1];
220  }
221  idx_key_.pop_back();
222  }
223  if (key == available_key_ - 1) available_key_ = key;
224 }
225 
227  const SymbolTableReadOptions &opts) {
228  int32 magic_number = 0;
229  ReadType(strm, &magic_number);
230  if (strm.fail()) {
231  LOG(ERROR) << "SymbolTable::Read: Read failed";
232  return nullptr;
233  }
234  string name;
235  ReadType(strm, &name);
236  std::unique_ptr<SymbolTableImpl> impl(new SymbolTableImpl(name));
237  ReadType(strm, &impl->available_key_);
238  int64 size;
239  ReadType(strm, &size);
240  if (strm.fail()) {
241  LOG(ERROR) << "SymbolTable::Read: Read failed";
242  return nullptr;
243  }
244  string symbol;
245  int64 key;
246  impl->check_sum_finalized_ = false;
247  for (int64 i = 0; i < size; ++i) {
248  ReadType(strm, &symbol);
249  ReadType(strm, &key);
250  if (strm.fail()) {
251  LOG(ERROR) << "SymbolTable::Read: Read failed";
252  return nullptr;
253  }
254  impl->AddSymbol(symbol, key);
255  }
256  return impl.release();
257 }
258 
259 bool SymbolTableImpl::Write(std::ostream &strm) const {
260  WriteType(strm, kSymbolTableMagicNumber);
261  WriteType(strm, name_);
262  WriteType(strm, available_key_);
263  const int64 size = symbols_.Size();
264  WriteType(strm, size);
265  for (int64 i = 0; i < size; ++i) {
266  auto key = (i < dense_key_limit_) ? i : idx_key_[i - dense_key_limit_];
267  WriteType(strm, symbols_.GetSymbol(i));
268  WriteType(strm, key);
269  }
270  strm.flush();
271  if (strm.fail()) {
272  LOG(ERROR) << "SymbolTable::Write: Write failed";
273  return false;
274  }
275  return true;
276 }
277 
278 } // namespace internal
279 
280 void SymbolTable::AddTable(const SymbolTable &table) {
281  MutateCheck();
282  for (SymbolTableIterator iter(table); !iter.Done(); iter.Next()) {
283  impl_->AddSymbol(iter.Symbol());
284  }
285 }
286 
287 bool SymbolTable::WriteText(std::ostream &strm,
288  const SymbolTableTextOptions &opts) const {
289  if (opts.fst_field_separator.empty()) {
290  LOG(ERROR) << "Missing required field separator";
291  return false;
292  }
293  bool once_only = false;
294  for (SymbolTableIterator iter(*this); !iter.Done(); iter.Next()) {
295  std::ostringstream line;
296  if (iter.Value() < 0 && !opts.allow_negative_labels && !once_only) {
297  LOG(WARNING) << "Negative symbol table entry when not allowed";
298  once_only = true;
299  }
300  line << iter.Symbol() << opts.fst_field_separator[0] << iter.Value()
301  << '\n';
302  strm.write(line.str().data(), line.str().length());
303  }
304  return true;
305 }
306 
307 bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
308  bool warning) {
309  // Flag can explicitly override this check.
310  if (!FLAGS_fst_compat_symbols) return true;
311  if (syms1 && syms2 &&
312  (syms1->LabeledCheckSum() != syms2->LabeledCheckSum())) {
313  if (warning) {
314  LOG(WARNING) << "CompatSymbols: Symbol table checksums do not match. "
315  << "Table sizes are " << syms1->NumSymbols() << " and "
316  << syms2->NumSymbols();
317  }
318  return false;
319  } else {
320  return true;
321  }
322 }
323 
324 void SymbolTableToString(const SymbolTable *table, string *result) {
325  std::ostringstream ostrm;
326  table->Write(ostrm);
327  *result = ostrm.str();
328 }
329 
330 SymbolTable *StringToSymbolTable(const string &str) {
331  std::istringstream istrm(str);
332  return SymbolTable::Read(istrm, SymbolTableReadOptions());
333 }
334 
335 } // namespace fst
virtual const string & LabeledCheckSum() const
Definition: symbol-table.h:306
virtual void AddTable(const SymbolTable &table)
std::pair< int64, bool > InsertOrFind(const string &key)
Definition: symbol-table.cc:44
virtual bool Write(std::ostream &strm) const
Definition: symbol-table.h:334
bool Write(std::ostream &strm) const
DEFINE_bool(fst_compat_symbols, true,"Require symbol tables to match when appropriate")
#define LOG(type)
Definition: log.h:48
virtual bool WriteText(std::ostream &strm, const SymbolTableTextOptions &opts=SymbolTableTextOptions()) const
int64 Find(const string &key) const
Definition: symbol-table.cc:61
virtual size_t NumSymbols() const
Definition: symbol-table.h:321
int64_t int64
Definition: types.h:27
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:155
constexpr int64 kNoSymbol
Definition: symbol-table.h:28
void SplitString(char *line, const char *delim, std::vector< char * > *vec, bool omit_empty_strings)
Definition: util.cc:22
int64 AddSymbol(const string &symbol, int64 key)
static SymbolTable * Read(std::istream &strm, const SymbolTableReadOptions &opts)
Definition: symbol-table.h:236
SymbolTable * StringToSymbolTable(const string &str)
void Update(void const *data, int size)
Definition: compat.h:79
#define VLOG(level)
Definition: log.h:49
SymbolTableTextOptions(bool allow_negative_labels=false)
Definition: symbol-table.cc:21
static SymbolTableImpl * ReadText(std::istream &strm, const string &name, const SymbolTableTextOptions &opts=SymbolTableTextOptions())
Definition: symbol-table.cc:89
static SymbolTableImpl * Read(std::istream &strm, const SymbolTableReadOptions &opts)
void SymbolTableToString(const SymbolTable *table, string *result)
int32_t int32
Definition: types.h:26
string Digest()
Definition: compat.h:92
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:47
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
DEFINE_string(fst_field_separator,"\t ","Set of characters used as a separator between printed fields")
void RemoveSymbol(size_t idx)
Definition: symbol-table.cc:84