FST  openfst-1.7.3
OpenFst Library
symbol-table.cc
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Classes to provide symbol-to-integer and integer-to-symbol mappings.
5 
6 #include <fst/symbol-table.h>
7 
8 #include <fst/flags.h>
9 #include <fst/log.h>
10 
11 #include <fstream>
12 #include <fst/util.h>
13 
// When false, CompatSymbols() skips the checksum comparison and always
// reports the two symbol tables as compatible.
DEFINE_bool(fst_compat_symbols, true,
            "Require symbol tables to match when appropriate");
// Characters accepted as field separators in textual symbol tables (default:
// tab and space). SymbolTableTextOptions takes its default from this flag.
DEFINE_string(fst_field_separator, "\t ",
              "Set of characters used as a separator between printed fields");
18 
19 namespace fst {
20 
    // Whether negative labels are accepted comes from the constructor
    // argument; the field-separator set is copied from --fst_field_separator.
    : allow_negative_labels(allow_negative_labels),
      fst_field_separator(FLAGS_fst_field_separator) {}
24 
25 namespace internal {
26 
// Maximum line length in textual symbols file. ReadText reads each line into a
// fixed char buffer of this size, which includes the terminating NUL.
const int kLineLen = 8096;

// Identifies stream data as a symbol table (and its endianity). Write() emits
// it as the first field of the binary representation.
static constexpr int32 kSymbolTableMagicNumber = 2125658996;
32 
33 DenseSymbolMap::DenseSymbolMap()
34  : empty_(-1), buckets_(1 << 4), hash_mask_(buckets_.size() - 1) {
35  std::uninitialized_fill(buckets_.begin(), buckets_.end(), empty_);
36 }
37 
    // Copies the symbol list, the bucket array and the hash mask of `other`;
    // empty_ is set to the same -1 sentinel the default constructor uses.
    : empty_(-1),
      symbols_(other.symbols_),
      buckets_(other.buckets_),
      hash_mask_(other.hash_mask_) {}
43 
44 std::pair<int64, bool> DenseSymbolMap::InsertOrFind(KeyType key) {
45  static constexpr float kMaxOccupancyRatio = 0.75; // Grows when 75% full.
46  if (Size() >= kMaxOccupancyRatio * buckets_.size()) {
47  Rehash(buckets_.size() * 2);
48  }
49  size_t idx = str_hash_(key) & hash_mask_;
50  while (buckets_[idx] != empty_) {
51  const auto stored_value = buckets_[idx];
52  if (symbols_[stored_value] == key) return {stored_value, false};
53  idx = (idx + 1) & hash_mask_;
54  }
55  const auto next = Size();
56  buckets_[idx] = next;
57  symbols_.emplace_back(key);
58  return {next, true};
59 }
60 
62  size_t idx = str_hash_(key) & hash_mask_;
63  while (buckets_[idx] != empty_) {
64  const auto stored_value = buckets_[idx];
65  if (symbols_[stored_value] == key) return stored_value;
66  idx = (idx + 1) & hash_mask_;
67  }
68  return buckets_[idx];
69 }
70 
71 void DenseSymbolMap::Rehash(size_t num_buckets) {
72  buckets_.resize(num_buckets);
73  hash_mask_ = buckets_.size() - 1;
74  std::uninitialized_fill(buckets_.begin(), buckets_.end(), empty_);
75  for (size_t i = 0; i < Size(); ++i) {
76  size_t idx = str_hash_(symbols_[i]) & hash_mask_;
77  while (buckets_[idx] != empty_) {
78  idx = (idx + 1) & hash_mask_;
79  }
80  buckets_[idx] = i;
81  }
82 }
83 
84 void DenseSymbolMap::RemoveSymbol(size_t idx) {
85  symbols_.erase(symbols_.begin() + idx);
86  Rehash(buckets_.size());
87 }
88 
90  const std::string &filename,
91  const SymbolTableTextOptions &opts) {
92  std::unique_ptr<SymbolTableImpl> impl(new SymbolTableImpl(filename));
93  int64 nline = 0;
94  char line[kLineLen];
95  while (!strm.getline(line, kLineLen).fail()) {
96  ++nline;
97  std::vector<char *> col;
98  const auto separator = opts.fst_field_separator + "\n";
99  SplitString(line, separator.c_str(), &col, true);
100  if (col.empty()) continue; // Empty line.
101  if (col.size() != 2) {
102  LOG(ERROR) << "SymbolTable::ReadText: Bad number of columns ("
103  << col.size() << "), "
104  << "file = " << filename << ", line = " << nline << ":<"
105  << line << ">";
106  return nullptr;
107  }
108  const char *symbol = col[0];
109  const char *value = col[1];
110  char *p;
111  const auto key = strtoll(value, &p, 10);
112  if (p < value + strlen(value) || (!opts.allow_negative_labels && key < 0) ||
113  key == kNoSymbol) {
114  LOG(ERROR) << "SymbolTable::ReadText: Bad non-negative integer \""
115  << value << "\", "
116  << "file = " << filename << ", line = " << nline;
117  return nullptr;
118  }
119  impl->AddSymbol(symbol, key);
120  }
121  return impl.release();
122 }
123 
124 void SymbolTableImpl::MaybeRecomputeCheckSum() const {
125  {
126  ReaderMutexLock check_sum_lock(&check_sum_mutex_);
127  if (check_sum_finalized_) return;
128  }
129  // We'll acquire an exclusive lock to recompute the checksums.
130  MutexLock check_sum_lock(&check_sum_mutex_);
131  if (check_sum_finalized_) { // Another thread (coming in around the same time
132  return; // might have done it already). So we recheck.
133  }
134  // Calculates the original label-agnostic checksum.
135  CheckSummer check_sum;
136  for (size_t i = 0; i < symbols_.Size(); ++i) {
137  const auto &symbol = symbols_.GetSymbol(i);
138  check_sum.Update(symbol.data(), symbol.size());
139  check_sum.Update("", 1);
140  }
141  check_sum_string_ = check_sum.Digest();
142  // Calculates the safer, label-dependent checksum.
143  CheckSummer labeled_check_sum;
144  for (int64 i = 0; i < dense_key_limit_; ++i) {
145  std::ostringstream line;
146  line << symbols_.GetSymbol(i) << '\t' << i;
147  labeled_check_sum.Update(line.str().data(), line.str().size());
148  }
149  using citer = std::map<int64, int64>::const_iterator;
150  for (citer it = key_map_.begin(); it != key_map_.end(); ++it) {
151  // TODO(tombagby, 2013-11-22) This line maintains a bug that ignores
152  // negative labels in the checksum that too many tests rely on.
153  if (it->first < dense_key_limit_) continue;
154  std::ostringstream line;
155  line << symbols_.GetSymbol(it->second) << '\t' << it->first;
156  labeled_check_sum.Update(line.str().data(), line.str().size());
157  }
158  labeled_check_sum_string_ = labeled_check_sum.Digest();
159  check_sum_finalized_ = true;
160 }
161 
163  if (key == kNoSymbol) return key;
164  const auto insert_key = symbols_.InsertOrFind(symbol);
165  if (!insert_key.second) {
166  const auto key_already = GetNthKey(insert_key.first);
167  if (key_already == key) return key;
168  VLOG(1) << "SymbolTable::AddSymbol: symbol = " << symbol
169  << " already in symbol_map_ with key = " << key_already
170  << " but supplied new key = " << key << " (ignoring new key)";
171  return key_already;
172  }
173  if (key + 1 == static_cast<int64>(symbols_.Size()) &&
174  key == dense_key_limit_) {
175  ++dense_key_limit_;
176  } else {
177  idx_key_.push_back(key);
178  key_map_[key] = symbols_.Size() - 1;
179  }
180  if (key >= available_key_) available_key_ = key + 1;
181  check_sum_finalized_ = false;
182  return key;
183 }
184 
185 // TODO(rybach): Consider a more efficient implementation which re-uses holes in
186 // the dense-key range or re-arranges the dense-key range from time to time.
188  auto idx = key;
189  if (key < 0 || key >= dense_key_limit_) {
190  auto iter = key_map_.find(key);
191  if (iter == key_map_.end()) return;
192  idx = iter->second;
193  key_map_.erase(iter);
194  }
195  if (idx < 0 || idx >= static_cast<int64>(symbols_.Size())) return;
196  symbols_.RemoveSymbol(idx);
197  // Removed one symbol, all indexes > idx are shifted by -1.
198  for (auto &k : key_map_) {
199  if (k.second > idx) --k.second;
200  }
201  if (key >= 0 && key < dense_key_limit_) {
202  // Removal puts a hole in the dense key range. Adjusts range to [0, key).
203  const auto new_dense_key_limit = key;
204  for (int64 i = key + 1; i < dense_key_limit_; ++i) {
205  key_map_[i] = i - 1;
206  }
207  // Moves existing values in idx_key to new place.
208  idx_key_.resize(symbols_.Size() - new_dense_key_limit);
209  for (int64 i = symbols_.Size(); i >= dense_key_limit_; --i) {
210  idx_key_[i - new_dense_key_limit - 1] = idx_key_[i - dense_key_limit_];
211  }
212  // Adds indexes for previously dense keys.
213  for (int64 i = new_dense_key_limit; i < dense_key_limit_ - 1; ++i) {
214  idx_key_[i - new_dense_key_limit] = i + 1;
215  }
216  dense_key_limit_ = new_dense_key_limit;
217  } else {
218  // Remove entry for removed index in idx_key.
219  for (size_t i = idx - dense_key_limit_; i + 1 < idx_key_.size(); ++i) {
220  idx_key_[i] = idx_key_[i + 1];
221  }
222  idx_key_.pop_back();
223  }
224  if (key == available_key_ - 1) available_key_ = key;
225 }
226 
228  std::istream &strm, const SymbolTableReadOptions &) {
229  int32 magic_number = 0;
230  ReadType(strm, &magic_number);
231  if (strm.fail()) {
232  LOG(ERROR) << "SymbolTable::Read: Read failed";
233  return nullptr;
234  }
235  std::string name;
236  ReadType(strm, &name);
237  std::unique_ptr<SymbolTableImpl> impl(new SymbolTableImpl(name));
238  ReadType(strm, &impl->available_key_);
239  int64 size;
240  ReadType(strm, &size);
241  if (strm.fail()) {
242  LOG(ERROR) << "SymbolTable::Read: Read failed";
243  return nullptr;
244  }
245  std::string symbol;
246  int64 key;
247  impl->check_sum_finalized_ = false;
248  for (int64 i = 0; i < size; ++i) {
249  ReadType(strm, &symbol);
250  ReadType(strm, &key);
251  if (strm.fail()) {
252  LOG(ERROR) << "SymbolTable::Read: Read failed";
253  return nullptr;
254  }
255  impl->AddSymbol(symbol, key);
256  }
257  return impl.release();
258 }
259 
260 bool SymbolTableImpl::Write(std::ostream &strm) const {
261  WriteType(strm, kSymbolTableMagicNumber);
262  WriteType(strm, name_);
263  WriteType(strm, available_key_);
264  const int64 size = symbols_.Size();
265  WriteType(strm, size);
266  for (int64 i = 0; i < size; ++i) {
267  auto key = (i < dense_key_limit_) ? i : idx_key_[i - dense_key_limit_];
268  WriteType(strm, symbols_.GetSymbol(i));
269  WriteType(strm, key);
270  }
271  strm.flush();
272  if (strm.fail()) {
273  LOG(ERROR) << "SymbolTable::Write: Write failed";
274  return false;
275  }
276  return true;
277 }
278 
279 } // namespace internal
280 
281 void SymbolTable::AddTable(const SymbolTable &table) {
282  MutateCheck();
283  for (SymbolTableIterator iter(table); !iter.Done(); iter.Next()) {
284  impl_->AddSymbol(iter.Symbol());
285  }
286 }
287 
288 bool SymbolTable::WriteText(std::ostream &strm,
289  const SymbolTableTextOptions &opts) const {
290  if (opts.fst_field_separator.empty()) {
291  LOG(ERROR) << "Missing required field separator";
292  return false;
293  }
294  bool once_only = false;
295  for (SymbolTableIterator iter(*this); !iter.Done(); iter.Next()) {
296  std::ostringstream line;
297  if (iter.Value() < 0 && !opts.allow_negative_labels && !once_only) {
298  LOG(WARNING) << "Negative symbol table entry when not allowed";
299  once_only = true;
300  }
301  line << iter.Symbol() << opts.fst_field_separator[0] << iter.Value()
302  << '\n';
303  strm.write(line.str().data(), line.str().length());
304  }
305  return true;
306 }
307 
308 bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
309  bool warning) {
310  // Flag can explicitly override this check.
311  if (!FLAGS_fst_compat_symbols) return true;
312  if (syms1 && syms2 &&
313  (syms1->LabeledCheckSum() != syms2->LabeledCheckSum())) {
314  if (warning) {
315  LOG(WARNING) << "CompatSymbols: Symbol table checksums do not match. "
316  << "Table sizes are " << syms1->NumSymbols() << " and "
317  << syms2->NumSymbols();
318  }
319  return false;
320  } else {
321  return true;
322  }
323 }
324 
325 void SymbolTableToString(const SymbolTable *table, std::string *result) {
326  std::ostringstream ostrm;
327  table->Write(ostrm);
328  *result = ostrm.str();
329 }
330 
331 SymbolTable *StringToSymbolTable(const std::string &str) {
332  std::istringstream istrm(str);
333  return SymbolTable::Read(istrm, SymbolTableReadOptions());
334 }
335 
336 } // namespace fst
void SymbolTableToString(const SymbolTable *table, std::string *result)
std::pair< int64, bool > InsertOrFind(KeyType key)
Definition: symbol-table.cc:44
virtual void AddTable(const SymbolTable &table)
virtual bool Write(std::ostream &strm) const
Definition: symbol-table.h:364
virtual const std::string & LabeledCheckSum() const
Definition: symbol-table.h:338
bool Write(std::ostream &strm) const
DEFINE_bool(fst_compat_symbols, true,"Require symbol tables to match when appropriate")
#define LOG(type)
Definition: log.h:46
SymbolTable * StringToSymbolTable(const std::string &str)
virtual bool WriteText(std::ostream &strm, const SymbolTableTextOptions &opts=SymbolTableTextOptions()) const
virtual size_t NumSymbols() const
Definition: symbol-table.h:351
int64 Find(KeyType key) const
Definition: symbol-table.cc:61
int64_t int64
Definition: types.h:27
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:155
constexpr int64 kNoSymbol
Definition: symbol-table.h:44
void SplitString(char *line, const char *delim, std::vector< char * > *vec, bool omit_empty_strings)
Definition: util.cc:24
static SymbolTable * Read(std::istream &strm, const SymbolTableReadOptions &opts)
Definition: symbol-table.h:269
void Update(void const *data, int size)
Definition: compat.h:71
#define VLOG(level)
Definition: log.h:47
SymbolTableTextOptions(bool allow_negative_labels=false)
Definition: symbol-table.cc:21
int64 AddSymbol(SymbolType symbol, int64 key)
const std::string & KeyType
Definition: symbol-table.h:81
const int kLineLen
Definition: symbol-table.cc:28
static SymbolTableImpl * Read(std::istream &strm, const SymbolTableReadOptions &opts)
int32_t int32
Definition: types.h:26
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:47
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
std::string Digest()
Definition: compat.h:84
static SymbolTableImpl * ReadText(std::istream &strm, const std::string &name, const SymbolTableTextOptions &opts=SymbolTableTextOptions())
Definition: symbol-table.cc:89
DEFINE_string(fst_field_separator,"\t ","Set of characters used as a separator between printed fields")
DenseSymbolMap::KeyType SymbolType
Definition: symbol-table.h:113
void RemoveSymbol(size_t idx)
Definition: symbol-table.cc:84