FST  openfst-1.7.2
OpenFst Library
symbol-table.cc
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Classes to provide symbol-to-integer and integer-to-symbol mappings.
5 
6 #include <fst/symbol-table.h>
7 
8 #include <fst/flags.h>
9 #include <fst/log.h>
10 
11 #include <fstream>
12 #include <fst/util.h>
13 
14 DEFINE_bool(fst_compat_symbols, true,
15  "Require symbol tables to match when appropriate");
16 DEFINE_string(fst_field_separator, "\t ",
17  "Set of characters used as a separator between printed fields");
18 
19 namespace fst {
20 
22  : allow_negative_labels(allow_negative_labels),
23  fst_field_separator(FLAGS_fst_field_separator) {}
24 
25 namespace internal {
26 
27 // Maximum line length in textual symbols file.
28 const int kLineLen = 8096;
29 
30 // Identifies stream data as a symbol table (and its endianity).
31 static constexpr int32 kSymbolTableMagicNumber = 2125658996;
32 
33 
34 DenseSymbolMap::DenseSymbolMap()
35  : empty_(-1), buckets_(1 << 4), hash_mask_(buckets_.size() - 1) {
36  std::uninitialized_fill(buckets_.begin(), buckets_.end(), empty_);
37 }
38 
40  : empty_(-1),
41  symbols_(other.symbols_),
42  buckets_(other.buckets_),
43  hash_mask_(other.hash_mask_) {}
44 
45 std::pair<int64, bool> DenseSymbolMap::InsertOrFind(const string &key) {
46  static constexpr float kMaxOccupancyRatio = 0.75; // Grows when 75% full.
47  if (Size() >= kMaxOccupancyRatio * buckets_.size()) {
48  Rehash(buckets_.size() * 2);
49  }
50  size_t idx = str_hash_(key) & hash_mask_;
51  while (buckets_[idx] != empty_) {
52  const auto stored_value = buckets_[idx];
53  if (symbols_[stored_value] == key) return {stored_value, false};
54  idx = (idx + 1) & hash_mask_;
55  }
56  const auto next = Size();
57  buckets_[idx] = next;
58  symbols_.emplace_back(key);
59  return {next, true};
60 }
61 
62 int64 DenseSymbolMap::Find(const string &key) const {
63  size_t idx = str_hash_(key) & hash_mask_;
64  while (buckets_[idx] != empty_) {
65  const auto stored_value = buckets_[idx];
66  if (symbols_[stored_value] == key) return stored_value;
67  idx = (idx + 1) & hash_mask_;
68  }
69  return buckets_[idx];
70 }
71 
72 void DenseSymbolMap::Rehash(size_t num_buckets) {
73  buckets_.resize(num_buckets);
74  hash_mask_ = buckets_.size() - 1;
75  std::uninitialized_fill(buckets_.begin(), buckets_.end(), empty_);
76  for (size_t i = 0; i < Size(); ++i) {
77  size_t idx = str_hash_(string(symbols_[i])) & hash_mask_;
78  while (buckets_[idx] != empty_) {
79  idx = (idx + 1) & hash_mask_;
80  }
81  buckets_[idx] = i;
82  }
83 }
84 
85 void DenseSymbolMap::RemoveSymbol(size_t idx) {
86  symbols_.erase(symbols_.begin() + idx);
87  Rehash(buckets_.size());
88 }
89 
91  const string &filename,
92  const SymbolTableTextOptions &opts) {
93  std::unique_ptr<SymbolTableImpl> impl(new SymbolTableImpl(filename));
94  int64 nline = 0;
95  char line[kLineLen];
96  while (!strm.getline(line, kLineLen).fail()) {
97  ++nline;
98  std::vector<char *> col;
99  const auto separator = opts.fst_field_separator + "\n";
100  SplitString(line, separator.c_str(), &col, true);
101  if (col.empty()) continue; // Empty line.
102  if (col.size() != 2) {
103  LOG(ERROR) << "SymbolTable::ReadText: Bad number of columns ("
104  << col.size() << "), "
105  << "file = " << filename << ", line = " << nline << ":<"
106  << line << ">";
107  return nullptr;
108  }
109  const char *symbol = col[0];
110  const char *value = col[1];
111  char *p;
112  const auto key = strtoll(value, &p, 10);
113  if (p < value + strlen(value) || (!opts.allow_negative_labels && key < 0) ||
114  key == kNoSymbol) {
115  LOG(ERROR) << "SymbolTable::ReadText: Bad non-negative integer \""
116  << value << "\", "
117  << "file = " << filename << ", line = " << nline;
118  return nullptr;
119  }
120  impl->AddSymbol(symbol, key);
121  }
122  return impl.release();
123 }
124 
125 void SymbolTableImpl::MaybeRecomputeCheckSum() const {
126  {
127  ReaderMutexLock check_sum_lock(&check_sum_mutex_);
128  if (check_sum_finalized_) return;
129  }
130  // We'll acquire an exclusive lock to recompute the checksums.
131  MutexLock check_sum_lock(&check_sum_mutex_);
132  if (check_sum_finalized_) { // Another thread (coming in around the same time
133  return; // might have done it already). So we recheck.
134  }
135  // Calculates the original label-agnostic checksum.
136  CheckSummer check_sum;
137  for (size_t i = 0; i < symbols_.Size(); ++i) {
138  const auto &symbol = symbols_.GetSymbol(i);
139  check_sum.Update(symbol.data(), symbol.size());
140  check_sum.Update("", 1);
141  }
142  check_sum_string_ = check_sum.Digest();
143  // Calculates the safer, label-dependent checksum.
144  CheckSummer labeled_check_sum;
145  for (int64 i = 0; i < dense_key_limit_; ++i) {
146  std::ostringstream line;
147  line << symbols_.GetSymbol(i) << '\t' << i;
148  labeled_check_sum.Update(line.str().data(), line.str().size());
149  }
150  using citer = std::map<int64, int64>::const_iterator;
151  for (citer it = key_map_.begin(); it != key_map_.end(); ++it) {
152  // TODO(tombagby, 2013-11-22) This line maintains a bug that ignores
153  // negative labels in the checksum that too many tests rely on.
154  if (it->first < dense_key_limit_) continue;
155  std::ostringstream line;
156  line << symbols_.GetSymbol(it->second) << '\t' << it->first;
157  labeled_check_sum.Update(line.str().data(), line.str().size());
158  }
159  labeled_check_sum_string_ = labeled_check_sum.Digest();
160  check_sum_finalized_ = true;
161 }
162 
163 int64 SymbolTableImpl::AddSymbol(const string &symbol, int64 key) {
164  if (key == kNoSymbol) return key;
165  const auto insert_key = symbols_.InsertOrFind(symbol);
166  if (!insert_key.second) {
167  const auto key_already = GetNthKey(insert_key.first);
168  if (key_already == key) return key;
169  VLOG(1) << "SymbolTable::AddSymbol: symbol = " << symbol
170  << " already in symbol_map_ with key = " << key_already
171  << " but supplied new key = " << key << " (ignoring new key)";
172  return key_already;
173  }
174  if (key + 1 == static_cast<int64>(symbols_.Size()) &&
175  key == dense_key_limit_) {
176  ++dense_key_limit_;
177  } else {
178  idx_key_.push_back(key);
179  key_map_[key] = symbols_.Size() - 1;
180  }
181  if (key >= available_key_) available_key_ = key + 1;
182  check_sum_finalized_ = false;
183  return key;
184 }
185 
186 // TODO(rybach): Consider a more efficient implementation which re-uses holes in
187 // the dense-key range or re-arranges the dense-key range from time to time.
189  auto idx = key;
190  if (key < 0 || key >= dense_key_limit_) {
191  auto iter = key_map_.find(key);
192  if (iter == key_map_.end()) return;
193  idx = iter->second;
194  key_map_.erase(iter);
195  }
196  if (idx < 0 || idx >= static_cast<int64>(symbols_.Size())) return;
197  symbols_.RemoveSymbol(idx);
198  // Removed one symbol, all indexes > idx are shifted by -1.
199  for (auto &k : key_map_) {
200  if (k.second > idx) --k.second;
201  }
202  if (key >= 0 && key < dense_key_limit_) {
203  // Removal puts a hole in the dense key range. Adjusts range to [0, key).
204  const auto new_dense_key_limit = key;
205  for (int64 i = key + 1; i < dense_key_limit_; ++i) {
206  key_map_[i] = i - 1;
207  }
208  // Moves existing values in idx_key to new place.
209  idx_key_.resize(symbols_.Size() - new_dense_key_limit);
210  for (int64 i = symbols_.Size(); i >= dense_key_limit_; --i) {
211  idx_key_[i - new_dense_key_limit - 1] = idx_key_[i - dense_key_limit_];
212  }
213  // Adds indexes for previously dense keys.
214  for (int64 i = new_dense_key_limit; i < dense_key_limit_ - 1; ++i) {
215  idx_key_[i - new_dense_key_limit] = i + 1;
216  }
217  dense_key_limit_ = new_dense_key_limit;
218  } else {
219  // Remove entry for removed index in idx_key.
220  for (size_t i = idx - dense_key_limit_; i + 1 < idx_key_.size(); ++i) {
221  idx_key_[i] = idx_key_[i + 1];
222  }
223  idx_key_.pop_back();
224  }
225  if (key == available_key_ - 1) available_key_ = key;
226 }
227 
229  std::istream &strm, const SymbolTableReadOptions &) {
230  int32 magic_number = 0;
231  ReadType(strm, &magic_number);
232  if (strm.fail()) {
233  LOG(ERROR) << "SymbolTable::Read: Read failed";
234  return nullptr;
235  }
236  string name;
237  ReadType(strm, &name);
238  std::unique_ptr<SymbolTableImpl> impl(new SymbolTableImpl(name));
239  ReadType(strm, &impl->available_key_);
240  int64 size;
241  ReadType(strm, &size);
242  if (strm.fail()) {
243  LOG(ERROR) << "SymbolTable::Read: Read failed";
244  return nullptr;
245  }
246  string symbol;
247  int64 key;
248  impl->check_sum_finalized_ = false;
249  for (int64 i = 0; i < size; ++i) {
250  ReadType(strm, &symbol);
251  ReadType(strm, &key);
252  if (strm.fail()) {
253  LOG(ERROR) << "SymbolTable::Read: Read failed";
254  return nullptr;
255  }
256  impl->AddSymbol(symbol, key);
257  }
258  return impl.release();
259 }
260 
261 bool SymbolTableImpl::Write(std::ostream &strm) const {
262  WriteType(strm, kSymbolTableMagicNumber);
263  WriteType(strm, name_);
264  WriteType(strm, available_key_);
265  const int64 size = symbols_.Size();
266  WriteType(strm, size);
267  for (int64 i = 0; i < size; ++i) {
268  auto key = (i < dense_key_limit_) ? i : idx_key_[i - dense_key_limit_];
269  WriteType(strm, symbols_.GetSymbol(i));
270  WriteType(strm, key);
271  }
272  strm.flush();
273  if (strm.fail()) {
274  LOG(ERROR) << "SymbolTable::Write: Write failed";
275  return false;
276  }
277  return true;
278 }
279 
280 } // namespace internal
281 
282 void SymbolTable::AddTable(const SymbolTable &table) {
283  MutateCheck();
284  for (SymbolTableIterator iter(table); !iter.Done(); iter.Next()) {
285  impl_->AddSymbol(iter.Symbol());
286  }
287 }
288 
289 bool SymbolTable::WriteText(std::ostream &strm,
290  const SymbolTableTextOptions &opts) const {
291  if (opts.fst_field_separator.empty()) {
292  LOG(ERROR) << "Missing required field separator";
293  return false;
294  }
295  bool once_only = false;
296  for (SymbolTableIterator iter(*this); !iter.Done(); iter.Next()) {
297  std::ostringstream line;
298  if (iter.Value() < 0 && !opts.allow_negative_labels && !once_only) {
299  LOG(WARNING) << "Negative symbol table entry when not allowed";
300  once_only = true;
301  }
302  line << iter.Symbol() << opts.fst_field_separator[0] << iter.Value()
303  << '\n';
304  strm.write(line.str().data(), line.str().length());
305  }
306  return true;
307 }
308 
309 bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
310  bool warning) {
311  // Flag can explicitly override this check.
312  if (!FLAGS_fst_compat_symbols) return true;
313  if (syms1 && syms2 &&
314  (syms1->LabeledCheckSum() != syms2->LabeledCheckSum())) {
315  if (warning) {
316  LOG(WARNING) << "CompatSymbols: Symbol table checksums do not match. "
317  << "Table sizes are " << syms1->NumSymbols() << " and "
318  << syms2->NumSymbols();
319  }
320  return false;
321  } else {
322  return true;
323  }
324 }
325 
326 void SymbolTableToString(const SymbolTable *table, string *result) {
327  std::ostringstream ostrm;
328  table->Write(ostrm);
329  *result = ostrm.str();
330 }
331 
332 SymbolTable *StringToSymbolTable(const string &str) {
333  std::istringstream istrm(str);
334  return SymbolTable::Read(istrm, SymbolTableReadOptions());
335 }
336 
337 } // namespace fst
virtual const string & LabeledCheckSum() const
Definition: symbol-table.h:308
virtual void AddTable(const SymbolTable &table)
std::pair< int64, bool > InsertOrFind(const string &key)
Definition: symbol-table.cc:45
virtual bool Write(std::ostream &strm) const
Definition: symbol-table.h:336
bool Write(std::ostream &strm) const
DEFINE_bool(fst_compat_symbols, true,"Require symbol tables to match when appropriate")
#define LOG(type)
Definition: log.h:48
virtual bool WriteText(std::ostream &strm, const SymbolTableTextOptions &opts=SymbolTableTextOptions()) const
int64 Find(const string &key) const
Definition: symbol-table.cc:62
virtual size_t NumSymbols() const
Definition: symbol-table.h:323
int64_t int64
Definition: types.h:27
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:155
constexpr int64 kNoSymbol
Definition: symbol-table.h:28
void SplitString(char *line, const char *delim, std::vector< char * > *vec, bool omit_empty_strings)
Definition: util.cc:22
int64 AddSymbol(const string &symbol, int64 key)
static SymbolTable * Read(std::istream &strm, const SymbolTableReadOptions &opts)
Definition: symbol-table.h:238
SymbolTable * StringToSymbolTable(const string &str)
void Update(void const *data, int size)
Definition: compat.h:79
#define VLOG(level)
Definition: log.h:49
SymbolTableTextOptions(bool allow_negative_labels=false)
Definition: symbol-table.cc:21
static SymbolTableImpl * ReadText(std::istream &strm, const string &name, const SymbolTableTextOptions &opts=SymbolTableTextOptions())
Definition: symbol-table.cc:90
const int kLineLen
Definition: symbol-table.cc:28
static SymbolTableImpl * Read(std::istream &strm, const SymbolTableReadOptions &opts)
void SymbolTableToString(const SymbolTable *table, string *result)
int32_t int32
Definition: types.h:26
string Digest()
Definition: compat.h:92
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:47
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
DEFINE_string(fst_field_separator,"\t ","Set of characters used as a separator between printed fields")
void RemoveSymbol(size_t idx)
Definition: symbol-table.cc:85