FST  openfst-1.8.2.post1
OpenFst Library
symbol-table.cc
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Classes to provide symbol-to-integer and integer-to-symbol mappings.
19 
20 #include <fst/symbol-table.h>
21 
22 #include <cstdint>
23 
24 #include <fst/flags.h>
25 #include <fst/log.h>
26 
27 #include <fstream>
28 #include <fst/util.h>
29 #include <string_view>
30 #include <fst/lock.h>
31 
// When true (the default), CompatSymbols() compares labeled checksums of the
// two tables; when false, CompatSymbols() always reports compatibility.
32 DEFINE_bool(fst_compat_symbols, true,
33  "Require symbol tables to match when appropriate");
// Default for SymbolTableTextOptions::fst_field_separator. Text reading splits
// fields on any of these characters; text writing emits only the first one.
34 DEFINE_string(fst_field_separator, "\t ",
35  "Set of characters used as a separator between printed fields");
36 
37 namespace fst {
38 
// NOTE(review): the signature line of this constructor (original line 39) is
// missing from this listing. Stores whether negative labels are permitted and
// takes the field-separator set from the --fst_field_separator flag.
40  : allow_negative_labels(allow_negative_labels),
41  fst_field_separator(FST_FLAGS_fst_field_separator) {}
42 
43 namespace internal {
44 
45 // Identifies stream data as a symbol table (and its endianity).
// SymbolTableImpl::Write() emits it as the first field of the binary format.
// NOTE(review): SymbolTableImpl::Read() reads this field but does not compare
// it against this constant.
46 static constexpr int32_t kSymbolTableMagicNumber = 2125658996;
47 
48 DenseSymbolMap::DenseSymbolMap()
49  : str_hash_(),
50  buckets_(1 << 4, kEmptyBucket),
51  hash_mask_(buckets_.size() - 1) {}
52 
53 std::pair<int64_t, bool> DenseSymbolMap::InsertOrFind(std::string_view key) {
54  static constexpr float kMaxOccupancyRatio = 0.75; // Grows when 75% full.
55  if (Size() >= kMaxOccupancyRatio * buckets_.size()) {
56  Rehash(buckets_.size() * 2);
57  }
58  size_t idx = GetHash(key);
59  while (buckets_[idx] != kEmptyBucket) {
60  const auto stored_value = buckets_[idx];
61  if (symbols_[stored_value] == key) return {stored_value, false};
62  idx = (idx + 1) & hash_mask_;
63  }
64  const auto next = Size();
65  buckets_[idx] = next;
66  symbols_.emplace_back(key);
67  return {next, true};
68 }
69 
70 int64_t DenseSymbolMap::Find(std::string_view key) const {
71  size_t idx = str_hash_(key) & hash_mask_;
72  while (buckets_[idx] != kEmptyBucket) {
73  const auto stored_value = buckets_[idx];
74  if (symbols_[stored_value] == key) return stored_value;
75  idx = (idx + 1) & hash_mask_;
76  }
77  return buckets_[idx];
78 }
79 
80 void DenseSymbolMap::Rehash(size_t num_buckets) {
81  buckets_.resize(num_buckets);
82  hash_mask_ = buckets_.size() - 1;
83  std::fill(buckets_.begin(), buckets_.end(), kEmptyBucket);
84  for (size_t i = 0; i < Size(); ++i) {
85  size_t idx = GetHash(symbols_[i]);
86  while (buckets_[idx] != kEmptyBucket) {
87  idx = (idx + 1) & hash_mask_;
88  }
89  buckets_[idx] = i;
90  }
91 }
92 
93 void DenseSymbolMap::RemoveSymbol(size_t idx) {
94  symbols_.erase(symbols_.begin() + idx);
95  Rehash(buckets_.size());
96 }
97 
98 void DenseSymbolMap::ShrinkToFit() { symbols_.shrink_to_fit(); }
99 
// NOTE(review): the signature line of this definition (original line 100) is
// missing from this listing. Appends every symbol of `table` through the
// single-argument AddSymbol(). Only item.Symbol() is passed on, so the labels
// from `table` are not carried over (presumably new keys are assigned; confirm
// against the AddSymbol(std::string_view) overload in the header).
101  for (const auto &item : table) {
102  AddSymbol(item.Symbol());
103  }
104 }
105 
106 std::unique_ptr<SymbolTableImplBase> ConstSymbolTableImpl::Copy() const {
107  LOG(FATAL) << "ConstSymbolTableImpl can't be copied";
108  return nullptr;
109 }
110 
111 int64_t ConstSymbolTableImpl::AddSymbol(std::string_view symbol, int64_t key) {
112  LOG(FATAL) << "ConstSymbolTableImpl does not support AddSymbol";
113  return kNoSymbol;
114 }
115 
116 int64_t ConstSymbolTableImpl::AddSymbol(std::string_view symbol) {
117  return AddSymbol(symbol, kNoSymbol);
118 }
119 
// NOTE(review): the signature line of this definition (original line 120) is
// missing from this listing. Removal is unsupported on the const
// implementation; this only logs at FATAL severity.
121  LOG(FATAL) << "ConstSymbolTableImpl does not support RemoveSymbol";
122 }
123 
124 void ConstSymbolTableImpl::SetName(std::string_view new_name) {
125  LOG(FATAL) << "ConstSymbolTableImpl does not support SetName";
126 }
127 
// NOTE(review): the signature line of this definition (original line 128) is
// missing from this listing. Merging tables is unsupported on the const
// implementation; this only logs at FATAL severity.
129  LOG(FATAL) << "ConstSymbolTableImpl does not support AddTable";
130 }
131 
// NOTE(review): the first signature line of this definition (original line
// 132) is missing from this listing.
//
// Parses a text symbol table from `strm`. Each non-blank line must split into
// exactly two fields, a symbol and an integer key. Fields are split on any
// character of opts.fst_field_separator, with '\n' added to that set.
// Returns nullptr, after logging, for:
//   - a line with the wrong number of fields;
//   - a key that does not parse;
//   - a negative key while opts.allow_negative_labels is false;
//   - a key equal to kNoSymbol.
// Otherwise returns a newly allocated table owned by the caller.
133  std::string_view name,
134  const SymbolTableTextOptions &opts) {
135  auto impl = std::make_unique<SymbolTableImpl>(name);
136  int64_t nline = 0;
137  char line[kLineLen];
138  const auto separator = opts.fst_field_separator + "\n";
// NOTE(review): istream::getline sets failbit when a line holds more than
// kLineLen - 1 characters before its newline. The loop stops on any failure
// and the table built so far is then returned, so an over-long line silently
// truncates the input instead of being reported. Consider std::getline into a
// std::string, or checking strm.eof() after the loop.
139  while (!strm.getline(line, kLineLen).fail()) {
140  ++nline;
141  std::vector<std::string_view> col =
142  StrSplit(line, ByAnyChar(separator), SkipEmpty());
143  if (col.empty()) continue; // Empty line.
144  if (col.size() != 2) {
145  LOG(ERROR) << "SymbolTable::ReadText: Bad number of columns ("
146  << col.size() << "), "
147  << "file = " << name << ", line = " << nline << ":<" << line
148  << ">";
149  return nullptr;
150  }
151  std::string_view symbol = col[0];
152  std::string_view value = col[1];
153  const auto maybe_key = ParseInt64(value);
// kNoSymbol is rejected because SymbolTableImpl::AddSymbol() silently ignores
// that key.
154  if (!maybe_key.has_value() ||
155  (!opts.allow_negative_labels && *maybe_key < 0) ||
156  *maybe_key == kNoSymbol) {
// NOTE(review): the message says "non-negative" even when negative labels are
// allowed and the real problem is a parse failure or kNoSymbol.
157  LOG(ERROR) << "SymbolTable::ReadText: Bad non-negative integer \""
158  << value << "\", "
159  << "file = " << name << ", line = " << nline;
160  return nullptr;
161  }
162  impl->AddSymbol(symbol, *maybe_key);
163  }
164  impl->ShrinkToFit();
165  return impl.release();
166 }
167 
168 void SymbolTableImpl::MaybeRecomputeCheckSum() const {
169  {
170  ReaderMutexLock check_sum_lock(&check_sum_mutex_);
171  if (check_sum_finalized_) return;
172  }
173  // We'll acquire an exclusive lock to recompute the checksums.
174  MutexLock check_sum_lock(&check_sum_mutex_);
175  if (check_sum_finalized_) { // Another thread (coming in around the same time
176  return; // might have done it already). So we recheck.
177  }
178  // Calculates the original label-agnostic checksum.
179  CheckSummer check_sum;
180  for (size_t i = 0; i < symbols_.Size(); ++i) {
181  check_sum.Update(symbols_.GetSymbol(i));
182  check_sum.Update(std::string_view{"\0", 1});
183  }
184  check_sum_string_ = check_sum.Digest();
185  // Calculates the safer, label-dependent checksum.
186  CheckSummer labeled_check_sum;
187  for (int64_t i = 0; i < dense_key_limit_; ++i) {
188  std::ostringstream line;
189  line << symbols_.GetSymbol(i) << '\t' << i;
190  labeled_check_sum.Update(line.str());
191  }
192  using citer = std::map<int64_t, int64_t>::const_iterator;
193  for (citer it = key_map_.begin(); it != key_map_.end(); ++it) {
194  // TODO(tombagby, 2013-11-22) This line maintains a bug that ignores
195  // negative labels in the checksum that too many tests rely on.
196  if (it->first < dense_key_limit_) continue;
197  std::ostringstream line;
198  line << symbols_.GetSymbol(it->second) << '\t' << it->first;
199  labeled_check_sum.Update(line.str());
200  }
201  labeled_check_sum_string_ = labeled_check_sum.Digest();
202  check_sum_finalized_ = true;
203 }
204 
205 std::string SymbolTableImpl::Find(int64_t key) const {
206  int64_t idx = key;
207  if (key < 0 || key >= dense_key_limit_) {
208  const auto it = key_map_.find(key);
209  if (it == key_map_.end()) return "";
210  idx = it->second;
211  }
212  if (idx < 0 || idx >= symbols_.Size()) return "";
213  return symbols_.GetSymbol(idx);
214 }
215 
216 int64_t SymbolTableImpl::AddSymbol(std::string_view symbol, int64_t key) {
217  if (key == kNoSymbol) return key;
218  if (const auto &[insert_key, inserted] = symbols_.InsertOrFind(symbol);
219  !inserted) {
220  const auto key_already = GetNthKey(insert_key);
221  if (key_already == key) return key;
222  VLOG(1) << "SymbolTable::AddSymbol: symbol = " << symbol
223  << " already in symbol_map_ with key = " << key_already
224  << " but supplied new key = " << key << " (ignoring new key)";
225  return key_already;
226  }
227  if (key + 1 == static_cast<int64_t>(symbols_.Size()) &&
228  key == dense_key_limit_) {
229  ++dense_key_limit_;
230  } else {
231  idx_key_.push_back(key);
232  key_map_[key] = symbols_.Size() - 1;
233  }
234  if (key >= available_key_) available_key_ = key + 1;
235  check_sum_finalized_ = false;
236  return key;
237 }
238 
239 // TODO(rybach): Consider a more efficient implementation which re-uses holes in
240 // the dense-key range or re-arranges the dense-key range from time to time.
241 void SymbolTableImpl::RemoveSymbol(const int64_t key) {
242  auto idx = key;
243  if (key < 0 || key >= dense_key_limit_) {
244  auto iter = key_map_.find(key);
245  if (iter == key_map_.end()) return;
246  idx = iter->second;
247  key_map_.erase(iter);
248  }
249  if (idx < 0 || idx >= static_cast<int64_t>(symbols_.Size())) return;
250  symbols_.RemoveSymbol(idx);
251  // Removed one symbol, all indexes > idx are shifted by -1.
252  for (auto &k : key_map_) {
253  if (k.second > idx) --k.second;
254  }
255  if (key >= 0 && key < dense_key_limit_) {
256  // Removal puts a hole in the dense key range. Adjusts range to [0, key).
257  const auto new_dense_key_limit = key;
258  for (int64_t i = key + 1; i < dense_key_limit_; ++i) {
259  key_map_[i] = i - 1;
260  }
261  // Moves existing values in idx_key to new place.
262  idx_key_.resize(symbols_.Size() - new_dense_key_limit);
263  for (int64_t i = symbols_.Size(); i >= dense_key_limit_; --i) {
264  idx_key_[i - new_dense_key_limit - 1] = idx_key_[i - dense_key_limit_];
265  }
266  // Adds indexes for previously dense keys.
267  for (int64_t i = new_dense_key_limit; i < dense_key_limit_ - 1; ++i) {
268  idx_key_[i - new_dense_key_limit] = i + 1;
269  }
270  dense_key_limit_ = new_dense_key_limit;
271  } else {
272  // Remove entry for removed index in idx_key.
273  for (size_t i = idx - dense_key_limit_; i + 1 < idx_key_.size(); ++i) {
274  idx_key_[i] = idx_key_[i + 1];
275  }
276  idx_key_.pop_back();
277  }
278  if (key == available_key_ - 1) available_key_ = key;
279 }
280 
// NOTE(review): the first signature line of this definition (original line
// 281) is missing from this listing.
//
// Reads the binary layout written by SymbolTableImpl::Write(), in this order:
// magic number, name, available key, entry count, then `size` (symbol, key)
// pairs. Returns nullptr, after logging `source`, if any read fails. Otherwise
// returns a newly allocated table owned by the caller.
282  std::string_view source) {
283  int32_t magic_number = 0;
// NOTE(review): magic_number is never compared with kSymbolTableMagicNumber.
// Input that is not a symbol table (or was written with the other endianness)
// is therefore not rejected here; it is parsed for as long as the reads succeed.
284  ReadType(strm, &magic_number);
285  if (strm.fail()) {
286  LOG(ERROR) << "SymbolTable::Read: Read failed: " << source;
287  return nullptr;
288  }
289  std::string name;
290  ReadType(strm, &name);
291  auto impl = std::make_unique<SymbolTableImpl>(name);
292  ReadType(strm, &impl->available_key_);
293  int64_t size;
294  ReadType(strm, &size);
295  if (strm.fail()) {
296  LOG(ERROR) << "SymbolTable::Read: Read failed: " << source;
297  return nullptr;
298  }
299  std::string symbol;
300  int64_t key;
301  impl->check_sum_finalized_ = false;
302  for (int64_t i = 0; i < size; ++i) {
303  ReadType(strm, &symbol);
304  ReadType(strm, &key);
305  if (strm.fail()) {
306  LOG(ERROR) << "SymbolTable::Read: Read failed: " << source;
307  return nullptr;
308  }
309  impl->AddSymbol(symbol, key);
310  }
311  impl->ShrinkToFit();
312  return impl.release();
313 }
314 
315 bool SymbolTableImpl::Write(std::ostream &strm) const {
316  WriteType(strm, kSymbolTableMagicNumber);
317  WriteType(strm, name_);
318  WriteType(strm, available_key_);
319  const int64_t size = symbols_.Size();
320  WriteType(strm, size);
321  for (int64_t i = 0; i < dense_key_limit_; ++i) {
322  WriteType(strm, symbols_.GetSymbol(i));
323  WriteType(strm, i);
324  }
325  for (const auto &p : key_map_) {
326  WriteType(strm, symbols_.GetSymbol(p.second));
327  WriteType(strm, p.first);
328  }
329  strm.flush();
330  if (strm.fail()) {
331  LOG(ERROR) << "SymbolTable::Write: Write failed";
332  return false;
333  }
334  return true;
335 }
336 
337 void SymbolTableImpl::ShrinkToFit() { symbols_.ShrinkToFit(); }
338 
339 } // namespace internal
340 
341 SymbolTable *SymbolTable::ReadText(const std::string &source,
342  const SymbolTableTextOptions &opts) {
343  std::ifstream strm(source, std::ios_base::in);
344  if (!strm.good()) {
345  LOG(ERROR) << "SymbolTable::ReadText: Can't open file: " << source;
346  return nullptr;
347  }
348  return ReadText(strm, source, opts);
349 }
350 
351 bool SymbolTable::Write(const std::string &source) const {
352  if (!source.empty()) {
353  std::ofstream strm(source,
354  std::ios_base::out | std::ios_base::binary);
355  if (!strm) {
356  LOG(ERROR) << "SymbolTable::Write: Can't open file: " << source;
357  return false;
358  }
359  if (!Write(strm)) {
360  LOG(ERROR) << "SymbolTable::Write: Write failed: " << source;
361  return false;
362  }
363  return true;
364  } else {
365  return Write(std::cout);
366  }
367 }
368 
369 bool SymbolTable::WriteText(std::ostream &strm,
370  const SymbolTableTextOptions &opts) const {
371  if (opts.fst_field_separator.empty()) {
372  LOG(ERROR) << "Missing required field separator";
373  return false;
374  }
375  bool once_only = false;
376  for (const auto &item : *this) {
377  std::ostringstream line;
378  if (item.Label() < 0 && !opts.allow_negative_labels && !once_only) {
379  LOG(WARNING) << "Negative symbol table entry when not allowed";
380  once_only = true;
381  }
382  line << item.Symbol() << opts.fst_field_separator[0] << item.Label()
383  << '\n';
384  strm.write(line.str().data(), line.str().length());
385  }
386  return true;
387 }
388 
389 bool SymbolTable::WriteText(const std::string &source) const {
390  if (!source.empty()) {
391  std::ofstream strm(source);
392  if (!strm) {
393  LOG(ERROR) << "SymbolTable::WriteText: Can't open file: " << source;
394  return false;
395  }
396  if (!WriteText(strm, SymbolTableTextOptions())) {
397  LOG(ERROR) << "SymbolTable::WriteText: Write failed: " << source;
398  return false;
399  }
400  return true;
401  } else {
402  return WriteText(std::cout, SymbolTableTextOptions());
403  }
404 }
405 
406 bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
407  bool warning) {
408  // Flag can explicitly override this check.
409  if (!FST_FLAGS_fst_compat_symbols) return true;
410  if (syms1 && syms2 &&
411  (syms1->LabeledCheckSum() != syms2->LabeledCheckSum())) {
412  if (warning) {
413  LOG(WARNING) << "CompatSymbols: Symbol table checksums do not match. "
414  << "Table sizes are " << syms1->NumSymbols() << " and "
415  << syms2->NumSymbols();
416  }
417  return false;
418  } else {
419  return true;
420  }
421 }
422 
423 void SymbolTableToString(const SymbolTable *table, std::string *result) {
424  std::ostringstream ostrm;
425  table->Write(ostrm);
426  *result = ostrm.str();
427 }
428 
429 SymbolTable *StringToSymbolTable(const std::string &str) {
430  std::istringstream istrm(str);
431  // TODO(jrosenstock): Change to source="string".
432  return SymbolTable::Read(istrm, /*source=*/"");
433 }
434 
435 } // namespace fst
bool Write(std::ostream &strm) const
Definition: symbol-table.h:479
void SymbolTableToString(const SymbolTable *table, std::string *result)
std::pair< int64_t, bool > InsertOrFind(std::string_view key)
Definition: symbol-table.cc:53
int64_t AddSymbol(std::string_view symbol, int64_t key) final
static SymbolTableImpl * Read(std::istream &strm, std::string_view source)
DEFINE_bool(fst_compat_symbols, true,"Require symbol tables to match when appropriate")
#define LOG(type)
Definition: log.h:49
SymbolTable * StringToSymbolTable(const std::string &str)
bool WriteText(std::ostream &strm, const SymbolTableTextOptions &opts=SymbolTableTextOptions()) const
constexpr int kLineLen
Definition: symbol-table.h:63
constexpr int64_t kNoSymbol
Definition: symbol-table.h:49
std::string Find(int64_t key) const override
internal::StringSplitter StrSplit(std::string_view full, ByAnyChar delim)
Definition: compat.cc:81
size_t NumSymbols() const
Definition: symbol-table.h:466
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:214
void SetName(std::string_view new_name) final
std::unique_ptr< SymbolTableImplBase > Copy() const final
std::optional< int64_t > ParseInt64(std::string_view s, int base=10)
Definition: util.cc:42
bool Write(std::ostream &strm) const override
static SymbolTable * ReadText(std::istream &strm, std::string_view name, const SymbolTableTextOptions &opts=SymbolTableTextOptions())
Definition: symbol-table.h:377
void AddTable(const SymbolTable &table) override
#define VLOG(level)
Definition: log.h:50
SymbolTableTextOptions(bool allow_negative_labels=false)
Definition: symbol-table.cc:39
int64_t Find(std::string_view key) const
Definition: symbol-table.cc:70
static SymbolTable * Read(std::istream &strm, const std::string &source)
Definition: symbol-table.h:391
void RemoveSymbol(int64_t key) override
int64_t AddSymbol(std::string_view symbol, int64_t key) override
void Update(std::string_view data)
Definition: compat.cc:42
static SymbolTableImpl * ReadText(std::istream &strm, std::string_view name, const SymbolTableTextOptions &opts=SymbolTableTextOptions())
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:68
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
std::string Digest()
Definition: compat.h:96
const std::string & LabeledCheckSum() const
Definition: symbol-table.h:453
void AddTable(const SymbolTable &table) final
DEFINE_string(fst_field_separator,"\t ","Set of characters used as a separator between printed fields")
void RemoveSymbol(int64_t key) final
void RemoveSymbol(size_t idx)
Definition: symbol-table.cc:93