FST  openfst-1.8.3
OpenFst Library
symbol-table.cc
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Classes to provide symbol-to-integer and integer-to-symbol mappings.
19 
20 #include <fst/symbol-table.h>
21 
22 #include <algorithm>
23 #include <cstddef>
24 #include <cstdint>
25 #include <ios>
26 #include <iostream>
27 #include <istream>
28 #include <memory>
29 #include <optional>
30 #include <ostream>
31 #include <sstream>
32 #include <string>
33 #include <utility>
34 #include <vector>
35 
36 #include <fst/flags.h>
37 #include <fst/log.h>
38 
39 #include <fstream>
40 #include <fst/util.h>
41 #include <map>
42 #include <functional>
43 #include <string_view>
44 #include <fst/lock.h>
45 
46 DEFINE_bool(fst_compat_symbols, true,
47  "Require symbol tables to match when appropriate");
48 DEFINE_string(fst_field_separator, "\t ",
49  "Set of characters used as a separator between printed fields");
50 
51 namespace fst {
52 namespace internal {
53 
// Magic number written at the start of a serialized symbol table; it
// identifies stream data as a symbol table (and its endianness).
static constexpr int32_t kSymbolTableMagicNumber{2125658996};
56 
58  : str_hash_(),
59  buckets_(1 << 4, kEmptyBucket),
60  hash_mask_(buckets_.size() - 1) {}
61 
62 std::pair<int64_t, bool> DenseSymbolMap::InsertOrFind(std::string_view key) {
63  static constexpr float kMaxOccupancyRatio = 0.75; // Grows when 75% full.
64  if (Size() >= kMaxOccupancyRatio * buckets_.size()) {
65  Rehash(buckets_.size() * 2);
66  }
67  size_t idx = GetHash(key);
68  while (buckets_[idx] != kEmptyBucket) {
69  const auto stored_value = buckets_[idx];
70  if (symbols_[stored_value] == key) return {stored_value, false};
71  idx = (idx + 1) & hash_mask_;
72  }
73  const auto next = Size();
74  buckets_[idx] = next;
75  symbols_.emplace_back(key);
76  return {next, true};
77 }
78 
79 int64_t DenseSymbolMap::Find(std::string_view key) const {
80  size_t idx = str_hash_(key) & hash_mask_;
81  while (buckets_[idx] != kEmptyBucket) {
82  const auto stored_value = buckets_[idx];
83  if (symbols_[stored_value] == key) return stored_value;
84  idx = (idx + 1) & hash_mask_;
85  }
86  return buckets_[idx];
87 }
88 
89 void DenseSymbolMap::Rehash(size_t num_buckets) {
90  buckets_.resize(num_buckets);
91  hash_mask_ = buckets_.size() - 1;
92  std::fill(buckets_.begin(), buckets_.end(), kEmptyBucket);
93  for (size_t i = 0; i < Size(); ++i) {
94  size_t idx = GetHash(symbols_[i]);
95  while (buckets_[idx] != kEmptyBucket) {
96  idx = (idx + 1) & hash_mask_;
97  }
98  buckets_[idx] = i;
99  }
100 }
101 
103  symbols_.erase(symbols_.begin() + idx);
104  Rehash(buckets_.size());
105 }
106 
107 void DenseSymbolMap::ShrinkToFit() { symbols_.shrink_to_fit(); }
108 
110  for (const auto &item : table) {
111  AddSymbol(item.Symbol());
112  }
113 }
114 
115 std::unique_ptr<SymbolTableImplBase>
117  LOG(FATAL) << "ConstSymbolTableImpl can't be copied";
118  return nullptr;
119 }
120 
121 int64_t ConstSymbolTableImpl::AddSymbol(std::string_view symbol, int64_t key) {
122  LOG(FATAL) << "ConstSymbolTableImpl does not support AddSymbol";
123  return kNoSymbol;
124 }
125 
126 int64_t ConstSymbolTableImpl::AddSymbol(std::string_view symbol) {
127  return AddSymbol(symbol, kNoSymbol);
128 }
129 
131  LOG(FATAL) << "ConstSymbolTableImpl does not support RemoveSymbol";
132 }
133 
134 void ConstSymbolTableImpl::SetName(std::string_view new_name) {
135  LOG(FATAL) << "ConstSymbolTableImpl does not support SetName";
136 }
137 
139  LOG(FATAL) << "ConstSymbolTableImpl does not support AddTable";
140 }
141 
143  std::istream &strm, std::string_view name, const std::string &sep) {
144  auto impl = std::make_unique<SymbolTableImpl>(name);
145  int64_t nline = 0;
146  char line[kLineLen];
147  const auto separator = sep + "\n";
148  while (!strm.getline(line, kLineLen).fail()) {
149  ++nline;
150  const std::vector<std::string_view> col =
151  StrSplit(line, ByAnyChar(separator), SkipEmpty());
152  if (col.empty()) continue; // Empty line.
153  if (col.size() != 2) {
154  LOG(ERROR) << "SymbolTable::ReadText: Bad number of columns ("
155  << col.size() << "), "
156  << "file = " << name << ", line = " << nline << ": " << line;
157  return nullptr;
158  }
159  std::string_view symbol = col[0];
160  std::string_view value = col[1];
161  const auto maybe_key = ParseInt64(value);
162  if (!maybe_key.has_value() || *maybe_key == kNoSymbol) {
163  LOG(ERROR) << "SymbolTable::ReadText: Bad label (" << value << "), "
164  << "file = " << name << ", line = " << nline;
165  return nullptr;
166  }
167  impl->AddSymbol(symbol, *maybe_key);
168  }
169  impl->ShrinkToFit();
170  return impl.release();
171 }
172 
173 void SymbolTableImpl::MaybeRecomputeCheckSum() const {
174  {
175  ReaderMutexLock check_sum_lock(&check_sum_mutex_);
176  if (check_sum_finalized_) return;
177  }
178  // We'll acquire an exclusive lock to recompute the checksums.
179  MutexLock check_sum_lock(&check_sum_mutex_);
180  if (check_sum_finalized_) { // Another thread (coming in around the same time
181  return; // might have done it already). So we recheck.
182  }
183  // Calculates the original label-agnostic checksum.
184  CheckSummer check_sum;
185  for (size_t i = 0; i < symbols_.Size(); ++i) {
186  check_sum.Update(symbols_.GetSymbol(i));
187  check_sum.Update(std::string_view{"\0", 1});
188  }
189  check_sum_string_ = check_sum.Digest();
190  // Calculates the safer, label-dependent checksum.
191  CheckSummer labeled_check_sum;
192  for (int64_t i = 0; i < dense_key_limit_; ++i) {
193  std::ostringstream line;
194  line << symbols_.GetSymbol(i) << '\t' << i;
195  labeled_check_sum.Update(line.str());
196  }
197  using citer = std::map<int64_t, int64_t>::const_iterator;
198  for (citer it = key_map_.begin(); it != key_map_.end(); ++it) {
199  // TODO(tombagby, 2013-11-22) This line maintains a bug that ignores
200  // negative labels in the checksum that too many tests rely on.
201  if (it->first < dense_key_limit_) continue;
202  std::ostringstream line;
203  line << symbols_.GetSymbol(it->second) << '\t' << it->first;
204  labeled_check_sum.Update(line.str());
205  }
206  labeled_check_sum_string_ = labeled_check_sum.Digest();
207  check_sum_finalized_ = true;
208 }
209 
210 std::string SymbolTableImpl::Find(int64_t key) const {
211  int64_t idx = key;
212  if (key < 0 || key >= dense_key_limit_) {
213  const auto it = key_map_.find(key);
214  if (it == key_map_.end()) return "";
215  idx = it->second;
216  }
217  if (idx < 0 || idx >= symbols_.Size()) return "";
218  return symbols_.GetSymbol(idx);
219 }
220 
221 int64_t SymbolTableImpl::AddSymbol(std::string_view symbol, int64_t key) {
222  if (key == kNoSymbol) return key;
223  if (const auto &[insert_key, inserted] = symbols_.InsertOrFind(symbol);
224  !inserted) {
225  const auto key_already = GetNthKey(insert_key);
226  if (key_already == key) return key;
227  VLOG(1) << "SymbolTable::AddSymbol: symbol = " << symbol
228  << " already in symbol_map_ with key = " << key_already
229  << " but supplied new key = " << key << " (ignoring new key)";
230  return key_already;
231  }
232  if (key + 1 == static_cast<int64_t>(symbols_.Size()) &&
233  key == dense_key_limit_) {
234  ++dense_key_limit_;
235  } else {
236  idx_key_.push_back(key);
237  key_map_[key] = symbols_.Size() - 1;
238  }
239  if (key >= available_key_) available_key_ = key + 1;
240  check_sum_finalized_ = false;
241  return key;
242 }
243 
244 // TODO(rybach): Consider a more efficient implementation which re-uses holes in
245 // the dense-key range or re-arranges the dense-key range from time to time.
246 void SymbolTableImpl::RemoveSymbol(const int64_t key) {
247  auto idx = key;
248  if (key < 0 || key >= dense_key_limit_) {
249  auto iter = key_map_.find(key);
250  if (iter == key_map_.end()) return;
251  idx = iter->second;
252  key_map_.erase(iter);
253  }
254  if (idx < 0 || idx >= static_cast<int64_t>(symbols_.Size())) return;
255  symbols_.RemoveSymbol(idx);
256  // Removed one symbol, all indexes > idx are shifted by -1.
257  for (auto &k : key_map_) {
258  if (k.second > idx) --k.second;
259  }
260  if (key >= 0 && key < dense_key_limit_) {
261  // Removal puts a hole in the dense key range. Adjusts range to [0, key).
262  const auto new_dense_key_limit = key;
263  for (int64_t i = key + 1; i < dense_key_limit_; ++i) {
264  key_map_[i] = i - 1;
265  }
266  // Moves existing values in idx_key to new place.
267  idx_key_.resize(symbols_.Size() - new_dense_key_limit);
268  for (int64_t i = symbols_.Size(); i >= dense_key_limit_; --i) {
269  idx_key_[i - new_dense_key_limit - 1] = idx_key_[i - dense_key_limit_];
270  }
271  // Adds indexes for previously dense keys.
272  for (int64_t i = new_dense_key_limit; i < dense_key_limit_ - 1; ++i) {
273  idx_key_[i - new_dense_key_limit] = i + 1;
274  }
275  dense_key_limit_ = new_dense_key_limit;
276  } else {
277  // Remove entry for removed index in idx_key.
278  for (size_t i = idx - dense_key_limit_; i + 1 < idx_key_.size(); ++i) {
279  idx_key_[i] = idx_key_[i + 1];
280  }
281  idx_key_.pop_back();
282  }
283  if (key == available_key_ - 1) available_key_ = key;
284 }
285 
287  std::istream &strm, std::string_view source) {
288  int32_t magic_number = 0;
289  ReadType(strm, &magic_number);
290  if (strm.fail()) {
291  LOG(ERROR) << "SymbolTable::Read: Read failed: " << source;
292  return nullptr;
293  }
294  std::string name;
295  ReadType(strm, &name);
296  auto impl = std::make_unique<SymbolTableImpl>(name);
297  ReadType(strm, &impl->available_key_);
298  int64_t size;
299  ReadType(strm, &size);
300  if (strm.fail()) {
301  LOG(ERROR) << "SymbolTable::Read: Read failed: " << source;
302  return nullptr;
303  }
304  std::string symbol;
305  int64_t key;
306  impl->check_sum_finalized_ = false;
307  for (int64_t i = 0; i < size; ++i) {
308  ReadType(strm, &symbol);
309  ReadType(strm, &key);
310  if (strm.fail()) {
311  LOG(ERROR) << "SymbolTable::Read: Read failed: " << source;
312  return nullptr;
313  }
314  impl->AddSymbol(symbol, key);
315  }
316  impl->ShrinkToFit();
317  return impl.release();
318 }
319 
320 bool SymbolTableImpl::Write(std::ostream &strm) const {
321  WriteType(strm, kSymbolTableMagicNumber);
322  WriteType(strm, name_);
323  WriteType(strm, available_key_);
324  const int64_t size = symbols_.Size();
325  WriteType(strm, size);
326  for (int64_t i = 0; i < dense_key_limit_; ++i) {
327  WriteType(strm, symbols_.GetSymbol(i));
328  WriteType(strm, i);
329  }
330  for (const auto &p : key_map_) {
331  WriteType(strm, symbols_.GetSymbol(p.second));
332  WriteType(strm, p.first);
333  }
334  strm.flush();
335  if (strm.fail()) {
336  LOG(ERROR) << "SymbolTable::Write: Write failed";
337  return false;
338  }
339  return true;
340 }
341 
342 void SymbolTableImpl::ShrinkToFit() { symbols_.ShrinkToFit(); }
343 
344 } // namespace internal
345 
346 SymbolTable * SymbolTable::ReadText(const std::string &source,
347  const std::string &sep) {
348  std::ifstream strm(source, std::ios_base::in);
349  if (!strm.good()) {
350  LOG(ERROR) << "SymbolTable::ReadText: Can't open file: " << source;
351  return nullptr;
352  }
353  return ReadText(strm, source, sep);
354 }
355 
356 bool SymbolTable::Write(const std::string &source) const {
357  if (!source.empty()) {
358  std::ofstream strm(source,
359  std::ios_base::out | std::ios_base::binary);
360  if (!strm) {
361  LOG(ERROR) << "SymbolTable::Write: Can't open file: " << source;
362  return false;
363  }
364  if (!Write(strm)) {
365  LOG(ERROR) << "SymbolTable::Write: Write failed: " << source;
366  return false;
367  }
368  return true;
369  } else {
370  return Write(std::cout);
371  }
372 }
373 
374 bool SymbolTable::WriteText(std::ostream &strm, const std::string &sep) const {
375  for (const auto &item : *this) {
376  std::ostringstream line;
377  line << item.Symbol() << sep[0] << item.Label() << '\n';
378  strm.write(line.str().data(), line.str().length());
379  }
380  return true;
381 }
382 
383 bool SymbolTable::WriteText(const std::string &sink,
384  const std::string &sep) const {
385  if (!sink.empty()) {
386  std::ofstream strm(sink);
387  if (!strm) {
388  LOG(ERROR) << "SymbolTable::WriteText: Can't open file: " << sink;
389  return false;
390  }
391  if (!WriteText(strm, sep)) {
392  LOG(ERROR) << "SymbolTable::WriteText: Write failed: " << sink;
393  return false;
394  }
395  return true;
396  } else {
397  return WriteText(std::cout, sep);
398  }
399 }
400 
401 bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
402  bool warning) {
403  // Flag can explicitly override this check.
404  if (!FST_FLAGS_fst_compat_symbols) return true;
405  if (syms1 && syms2 &&
406  (syms1->LabeledCheckSum() != syms2->LabeledCheckSum())) {
407  if (warning) {
408  LOG(WARNING) << "CompatSymbols: Symbol table checksums do not match. "
409  << "Table sizes are " << syms1->NumSymbols() << " and "
410  << syms2->NumSymbols();
411  }
412  return false;
413  } else {
414  return true;
415  }
416 }
417 
418 void SymbolTableToString(const SymbolTable *table, std::string *result) {
419  std::ostringstream ostrm;
420  table->Write(ostrm);
421  *result = ostrm.str();
422 }
423 
424 SymbolTable *StringToSymbolTable(const std::string &str) {
425  std::istringstream istrm(str);
426  // TODO(jrosenstock): Change to source="string".
427  return SymbolTable::Read(istrm, /*source=*/"");
428 }
429 
430 } // namespace fst
bool Write(std::ostream &strm) const
Definition: symbol-table.h:483
void SymbolTableToString(const SymbolTable *table, std::string *result)
std::pair< int64_t, bool > InsertOrFind(std::string_view key)
Definition: symbol-table.cc:62
int64_t AddSymbol(std::string_view symbol, int64_t key) final
static SymbolTableImpl * Read(std::istream &strm, std::string_view source)
DEFINE_bool(fst_compat_symbols, true,"Require symbol tables to match when appropriate")
#define LOG(type)
Definition: log.h:53
SymbolTable * StringToSymbolTable(const std::string &str)
constexpr int kLineLen
Definition: symbol-table.h:62
constexpr int64_t kNoSymbol
Definition: symbol-table.h:55
std::string Find(int64_t key) const override
internal::StringSplitter StrSplit(std::string_view full, ByAnyChar delim)
Definition: compat.cc:77
size_t NumSymbols() const
Definition: symbol-table.h:470
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:228
void SetName(std::string_view new_name) final
std::unique_ptr< SymbolTableImplBase > Copy() const final
std::optional< int64_t > ParseInt64(std::string_view s, int base=10)
Definition: util.cc:46
bool Write(std::ostream &strm) const override
void AddTable(const SymbolTable &table) override
#define VLOG(level)
Definition: log.h:54
int64_t Find(std::string_view key) const
Definition: symbol-table.cc:79
void RemoveSymbol(int64_t key) override
int64_t AddSymbol(std::string_view symbol, int64_t key) override
void Update(std::string_view data)
Definition: compat.cc:38
static SymbolTable * ReadText(std::istream &strm, std::string_view name, const std::string &sep=FST_FLAGS_fst_field_separator)
Definition: symbol-table.h:381
static SymbolTableImpl * ReadText(std::istream &strm, std::string_view name, const std::string &sep=FST_FLAGS_fst_field_separator)
bool WriteText(std::ostream &strm, const std::string &sep=FST_FLAGS_fst_field_separator) const
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:80
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
std::string Digest()
Definition: compat.h:103
const std::string & LabeledCheckSum() const
Definition: symbol-table.h:457
void AddTable(const SymbolTable &table) final
static SymbolTable * Read(std::istream &strm, std::string_view source)
Definition: symbol-table.h:395
DEFINE_string(fst_field_separator,"\t ","Set of characters used as a separator between printed fields")
void RemoveSymbol(int64_t key) final