FST  openfst-1.8.4
OpenFst Library
symbol-table.cc
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Classes to provide symbol-to-integer and integer-to-symbol mappings.
19 
20 #include <fst/symbol-table.h>
21 
22 #include <algorithm>
23 #include <cerrno>
24 #include <cstddef>
25 #include <cstdint>
26 #include <cstring>
27 #include <ios>
28 #include <iostream>
29 #include <istream>
30 #include <memory>
31 #include <ostream>
32 #include <sstream>
33 #include <string>
34 #include <utility>
35 #include <vector>
36 
37 #include <fst/flags.h>
38 #include <fst/log.h>
39 
40 #include <fstream>
41 #include <fst/util.h>
42 #include <map>
43 #include <functional>
44 #include <fst/compat.h>
45 #include <string_view>
46 #include <fst/lock.h>
47 
// --fst_compat_symbols: read by CompatSymbols(); when false, the symbol-table
// checksum comparison is skipped and tables are always treated as compatible.
48 DEFINE_bool(fst_compat_symbols, true,
49  "Require symbol tables to match when appropriate");
// --fst_field_separator: default `sep` for the text reader/writer. The text
// reader splits fields on ANY of these characters; the text writer emits only
// the first one.
50 DEFINE_string(fst_field_separator, "\t ",
51  "Set of characters used as a separator between printed fields");
52 
53 namespace fst {
54 namespace internal {
55 
56 // Identifies stream data as a symbol table (and its endianity).
// It is the first field written by SymbolTableImpl::Write and the first field
// read by SymbolTableImpl::Read.
57 static constexpr int32_t kSymbolTableMagicNumber = 2125658996;
58 
60  : str_hash_(),
61  buckets_(1 << 4, kEmptyBucket),
62  hash_mask_(buckets_.size() - 1) {}
63 
64 std::pair<int64_t, bool> DenseSymbolMap::InsertOrFind(std::string_view key) {
65  static constexpr float kMaxOccupancyRatio = 0.75; // Grows when 75% full.
66  if (Size() >= kMaxOccupancyRatio * buckets_.size()) {
67  Rehash(buckets_.size() * 2);
68  }
69  size_t idx = GetHash(key);
70  while (buckets_[idx] != kEmptyBucket) {
71  const auto stored_value = buckets_[idx];
72  if (symbols_[stored_value] == key) return {stored_value, false};
73  idx = (idx + 1) & hash_mask_;
74  }
75  const auto next = Size();
76  buckets_[idx] = next;
77  symbols_.emplace_back(key);
78  return {next, true};
79 }
80 
81 int64_t DenseSymbolMap::Find(std::string_view key) const {
82  size_t idx = str_hash_(key) & hash_mask_;
83  while (buckets_[idx] != kEmptyBucket) {
84  const auto stored_value = buckets_[idx];
85  if (symbols_[stored_value] == key) return stored_value;
86  idx = (idx + 1) & hash_mask_;
87  }
88  return buckets_[idx];
89 }
90 
91 void DenseSymbolMap::Rehash(size_t num_buckets) {
92  buckets_.resize(num_buckets);
93  hash_mask_ = buckets_.size() - 1;
94  std::fill(buckets_.begin(), buckets_.end(), kEmptyBucket);
95  for (size_t i = 0; i < Size(); ++i) {
96  size_t idx = GetHash(symbols_[i]);
97  while (buckets_[idx] != kEmptyBucket) {
98  idx = (idx + 1) & hash_mask_;
99  }
100  buckets_[idx] = i;
101  }
102 }
103 
105  symbols_.erase(symbols_.begin() + idx);
106  Rehash(buckets_.size());
107 }
108 
109 void DenseSymbolMap::ShrinkToFit() { symbols_.shrink_to_fit(); }
110 
112  for (const auto &item : table) {
113  AddSymbol(item.Symbol());
114  }
115 }
116 
117 std::unique_ptr<SymbolTableImplBase>
119  LOG(FATAL) << "ConstSymbolTableImpl can't be copied";
120  return nullptr;
121 }
122 
123 int64_t ConstSymbolTableImpl::AddSymbol(std::string_view symbol, int64_t key) {
124  LOG(FATAL) << "ConstSymbolTableImpl does not support AddSymbol";
125  return kNoSymbol;
126 }
127 
128 int64_t ConstSymbolTableImpl::AddSymbol(std::string_view symbol) {
129  return AddSymbol(symbol, kNoSymbol);
130 }
131 
133  LOG(FATAL) << "ConstSymbolTableImpl does not support RemoveSymbol";
134 }
135 
136 void ConstSymbolTableImpl::SetName(std::string_view new_name) {
137  LOG(FATAL) << "ConstSymbolTableImpl does not support SetName";
138 }
139 
141  LOG(FATAL) << "ConstSymbolTableImpl does not support AddTable";
142 }
143 
145  std::istream &strm, std::string_view name, std::string_view sep) {
146  auto impl = std::make_unique<SymbolTableImpl>(name);
147  int64_t nline = 0;
148  char line[kLineLen];
149  const std::string separator = fst::StrCat(sep, "\n");
150  while (!strm.getline(line, kLineLen).fail()) {
151  ++nline;
152  const std::vector<std::string_view> col =
153  StrSplit(line, ByAnyChar(separator), SkipEmpty());
154  if (col.empty()) continue; // Empty line.
155  if (col.size() != 2) {
156  LOG(ERROR) << "SymbolTable::ReadText: Bad number of columns ("
157  << col.size() << "), "
158  << "file = " << name << ", line = " << nline << ": " << line;
159  return nullptr;
160  }
161  std::string_view symbol = col[0];
162  std::string_view value = col[1];
163  const auto maybe_key = ParseInt64(value);
164  if (!maybe_key.has_value() || *maybe_key == kNoSymbol) {
165  LOG(ERROR) << "SymbolTable::ReadText: Bad label (" << value << "), "
166  << "file = " << name << ", line = " << nline;
167  return nullptr;
168  }
169  impl->AddSymbol(symbol, *maybe_key);
170  }
171  impl->ShrinkToFit();
172  return impl.release();
173 }
174 
175 void SymbolTableImpl::MaybeRecomputeCheckSum() const {
176  {
177  ReaderMutexLock check_sum_lock(&check_sum_mutex_);
178  if (check_sum_finalized_) return;
179  }
180  // We'll acquire an exclusive lock to recompute the checksums.
181  MutexLock check_sum_lock(&check_sum_mutex_);
182  if (check_sum_finalized_) { // Another thread (coming in around the same time
183  return; // might have done it already). So we recheck.
184  }
185  // Calculates the original label-agnostic checksum.
186  CheckSummer check_sum;
187  for (size_t i = 0; i < symbols_.Size(); ++i) {
188  check_sum.Update(symbols_.GetSymbol(i));
189  check_sum.Update(std::string_view{"\0", 1});
190  }
191  check_sum_string_ = check_sum.Digest();
192  // Calculates the safer, label-dependent checksum.
193  CheckSummer labeled_check_sum;
194  for (int64_t i = 0; i < dense_key_limit_; ++i) {
195  std::ostringstream line;
196  line << symbols_.GetSymbol(i) << '\t' << i;
197  labeled_check_sum.Update(line.str());
198  }
199  using citer = std::map<int64_t, int64_t>::const_iterator;
200  for (citer it = key_map_.begin(); it != key_map_.end(); ++it) {
201  // TODO(tombagby, 2013-11-22) This line maintains a bug that ignores
202  // negative labels in the checksum that too many tests rely on.
203  if (it->first < dense_key_limit_) continue;
204  std::ostringstream line;
205  line << symbols_.GetSymbol(it->second) << '\t' << it->first;
206  labeled_check_sum.Update(line.str());
207  }
208  labeled_check_sum_string_ = labeled_check_sum.Digest();
209  check_sum_finalized_ = true;
210 }
211 
212 std::string SymbolTableImpl::Find(int64_t key) const {
213  int64_t idx = key;
214  if (key < 0 || key >= dense_key_limit_) {
215  const auto it = key_map_.find(key);
216  if (it == key_map_.end()) return "";
217  idx = it->second;
218  }
219  if (idx < 0 || idx >= symbols_.Size()) return "";
220  return symbols_.GetSymbol(idx);
221 }
222 
223 int64_t SymbolTableImpl::AddSymbol(std::string_view symbol, int64_t key) {
224  if (key == kNoSymbol) return key;
225  if (const auto &[insert_key, inserted] = symbols_.InsertOrFind(symbol);
226  !inserted) {
227  const auto key_already = GetNthKey(insert_key);
228  if (key_already == key) return key;
229  VLOG(1) << "SymbolTable::AddSymbol: symbol = " << symbol
230  << " already in symbol_map_ with key = " << key_already
231  << " but supplied new key = " << key << " (ignoring new key)";
232  return key_already;
233  }
234  if (key + 1 == static_cast<int64_t>(symbols_.Size()) &&
235  key == dense_key_limit_) {
236  ++dense_key_limit_;
237  } else {
238  idx_key_.push_back(key);
239  key_map_[key] = symbols_.Size() - 1;
240  }
241  if (key >= available_key_) available_key_ = key + 1;
242  check_sum_finalized_ = false;
243  return key;
244 }
245 
// Removes the symbol stored under `key`; unknown keys are ignored.
// A key in [0, dense_key_limit_) is its own index into symbols_. Any other key
// is resolved (and erased) through key_map_. Removing a dense key truncates
// the dense range to [0, key) and moves the former dense keys above it into
// key_map_ / idx_key_.
246 // TODO(rybach): Consider a more efficient implementation which re-uses holes in
247 // the dense-key range or re-arranges the dense-key range from time to time.
248 void SymbolTableImpl::RemoveSymbol(const int64_t key) {
249  auto idx = key;
// Non-dense key: translate it to a symbol index and drop its key_map_ entry.
250  if (key < 0 || key >= dense_key_limit_) {
251  auto iter = key_map_.find(key);
252  if (iter == key_map_.end()) return;
253  idx = iter->second;
254  key_map_.erase(iter);
255  }
256  if (idx < 0 || idx >= static_cast<int64_t>(symbols_.Size())) return;
257  symbols_.RemoveSymbol(idx);
258  // Removed one symbol, all indexes > idx are shifted by -1.
259  for (auto &k : key_map_) {
260  if (k.second > idx) --k.second;
261  }
262  if (key >= 0 && key < dense_key_limit_) {
263  // Removal puts a hole in the dense key range. Adjusts range to [0, key).
264  const auto new_dense_key_limit = key;
// Former dense keys above `key` now live at index (key - 1), via key_map_.
265  for (int64_t i = key + 1; i < dense_key_limit_; ++i) {
266  key_map_[i] = i - 1;
267  }
268  // Moves existing values in idx_key to new place.
// idx_key_ appears to hold, at position (index - dense limit), the key of
// each symbol index at or beyond the dense limit. It is grown for the new,
// lower limit. Iterating downward shifts entries toward the back without
// overwriting entries that have not been moved yet.
269  idx_key_.resize(symbols_.Size() - new_dense_key_limit);
270  for (int64_t i = symbols_.Size(); i >= dense_key_limit_; --i) {
271  idx_key_[i - new_dense_key_limit - 1] = idx_key_[i - dense_key_limit_];
272  }
273  // Adds indexes for previously dense keys.
274  for (int64_t i = new_dense_key_limit; i < dense_key_limit_ - 1; ++i) {
275  idx_key_[i - new_dense_key_limit] = i + 1;
276  }
277  dense_key_limit_ = new_dense_key_limit;
278  } else {
279  // Remove entry for removed index in idx_key.
// Shift the following entries left by one, then drop the last slot.
280  for (size_t i = idx - dense_key_limit_; i + 1 < idx_key_.size(); ++i) {
281  idx_key_[i] = idx_key_[i + 1];
282  }
283  idx_key_.pop_back();
284  }
// If the removed key was the last one before available_key_, that key becomes
// the next available key again.
285  if (key == available_key_ - 1) available_key_ = key;
286 }
287 
289  std::istream &strm, std::string_view source) {
290  int32_t magic_number = 0;
291  ReadType(strm, &magic_number);
292  if (strm.fail()) {
293  LOG(ERROR) << "SymbolTable::Read: Read failed: " << source;
294  return nullptr;
295  }
296  std::string name;
297  ReadType(strm, &name);
298  auto impl = std::make_unique<SymbolTableImpl>(name);
299  ReadType(strm, &impl->available_key_);
300  int64_t size;
301  ReadType(strm, &size);
302  if (strm.fail()) {
303  LOG(ERROR) << "SymbolTable::Read: Read failed: " << source;
304  return nullptr;
305  }
306  std::string symbol;
307  int64_t key;
308  impl->check_sum_finalized_ = false;
309  for (int64_t i = 0; i < size; ++i) {
310  ReadType(strm, &symbol);
311  ReadType(strm, &key);
312  if (strm.fail()) {
313  LOG(ERROR) << "SymbolTable::Read: Read failed: " << source;
314  return nullptr;
315  }
316  impl->AddSymbol(symbol, key);
317  }
318  impl->ShrinkToFit();
319  return impl.release();
320 }
321 
322 bool SymbolTableImpl::Write(std::ostream &strm) const {
323  WriteType(strm, kSymbolTableMagicNumber);
324  WriteType(strm, name_);
325  WriteType(strm, available_key_);
326  const int64_t size = symbols_.Size();
327  WriteType(strm, size);
328  for (int64_t i = 0; i < dense_key_limit_; ++i) {
329  WriteType(strm, symbols_.GetSymbol(i));
330  WriteType(strm, i);
331  }
332  for (const auto &p : key_map_) {
333  WriteType(strm, symbols_.GetSymbol(p.second));
334  WriteType(strm, p.first);
335  }
336  strm.flush();
337  if (strm.fail()) {
338  LOG(ERROR) << "SymbolTable::Write: Write failed";
339  return false;
340  }
341  return true;
342 }
343 
344 void SymbolTableImpl::ShrinkToFit() { symbols_.ShrinkToFit(); }
345 
346 } // namespace internal
347 
348 SymbolTable * SymbolTable::ReadText(const std::string &source,
349  std::string_view sep) {
350  std::ifstream strm(source, std::ios_base::in);
351  if (!strm.good()) {
352  LOG(ERROR) << "SymbolTable::ReadText: Can't open file: " << source;
353  return nullptr;
354  }
355  return ReadText(strm, source, sep);
356 }
357 
358 bool SymbolTable::Write(const std::string &source) const {
359  if (!source.empty()) {
360  std::ofstream strm(source,
361  std::ios_base::out | std::ios_base::binary);
362  if (!strm) {
363  LOG(ERROR) << "SymbolTable::Write: Can't open file: " << source;
364  return false;
365  }
366  if (!Write(strm)) {
367  LOG(ERROR) << "SymbolTable::Write: Write failed: " << source;
368  return false;
369  }
370  return true;
371  } else {
372  return Write(std::cout);
373  }
374 }
375 
376 bool SymbolTable::WriteText(std::ostream &strm, std::string_view sep) const {
377  for (const auto &item : *this) {
378  std::ostringstream line;
379  line << item.Symbol() << sep[0] << item.Label() << '\n';
380  strm.write(line.str().data(), line.str().length());
381  }
382  return true;
383 }
384 
385 bool SymbolTable::WriteText(const std::string &sink,
386  std::string_view sep) const {
387  if (!sink.empty()) {
388  std::ofstream strm(sink);
389  if (!strm) {
390  LOG(ERROR) << "SymbolTable::WriteText: Can't open file: " << sink;
391  return false;
392  }
393  if (!WriteText(strm, sep)) {
394  LOG(ERROR) << "SymbolTable::WriteText: Write failed: " << sink;
395  return false;
396  }
397  return true;
398  } else {
399  return WriteText(std::cout, sep);
400  }
401 }
402 
403 bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
404  bool warning) {
405  // Flag can explicitly override this check.
406  if (!FST_FLAGS_fst_compat_symbols) return true;
407  if (syms1 && syms2 &&
408  (syms1->LabeledCheckSum() != syms2->LabeledCheckSum())) {
409  if (warning) {
410  LOG(WARNING) << "CompatSymbols: Symbol table checksums do not match. "
411  << "Table sizes are " << syms1->NumSymbols() << " and "
412  << syms2->NumSymbols();
413  }
414  return false;
415  } else {
416  return true;
417  }
418 }
419 
420 void SymbolTableToString(const SymbolTable *table, std::string *result) {
421  std::ostringstream ostrm;
422  table->Write(ostrm);
423  *result = ostrm.str();
424 }
425 
426 SymbolTable *StringToSymbolTable(const std::string &str) {
427  std::istringstream istrm(str);
428  // TODO(jrosenstock): Change to source="string".
429  return SymbolTable::Read(istrm, /*source=*/"");
430 }
431 
432 } // namespace fst
bool Write(std::ostream &strm) const
Definition: symbol-table.h:479
bool WriteText(std::ostream &strm, std::string_view sep=FST_FLAGS_fst_field_separator) const
void SymbolTableToString(const SymbolTable *table, std::string *result)
std::pair< int64_t, bool > InsertOrFind(std::string_view key)
Definition: symbol-table.cc:64
int64_t AddSymbol(std::string_view symbol, int64_t key) final
static SymbolTableImpl * Read(std::istream &strm, std::string_view source)
DEFINE_bool(fst_compat_symbols, true,"Require symbol tables to match when appropriate")
#define LOG(type)
Definition: log.h:53
SymbolTable * StringToSymbolTable(const std::string &str)
constexpr int kLineLen
Definition: symbol-table.h:58
constexpr int64_t kNoSymbol
Definition: symbol-table.h:51
std::string Find(int64_t key) const override
internal::StringSplitter StrSplit(std::string_view full, ByAnyChar delim)
Definition: compat.cc:75
size_t NumSymbols() const
Definition: symbol-table.h:466
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:229
void SetName(std::string_view new_name) final
std::unique_ptr< SymbolTableImplBase > Copy() const final
std::optional< int64_t > ParseInt64(std::string_view s, int base=10)
Definition: util.cc:46
bool Write(std::ostream &strm) const override
void AddTable(const SymbolTable &table) override
#define VLOG(level)
Definition: log.h:54
int64_t Find(std::string_view key) const
Definition: symbol-table.cc:81
void RemoveSymbol(int64_t key) override
int64_t AddSymbol(std::string_view symbol, int64_t key) override
void Update(std::string_view data)
Definition: compat.cc:36
static SymbolTableImpl * ReadText(std::istream &strm, std::string_view name, std::string_view sep=FST_FLAGS_fst_field_separator)
std::string StrCat(const StringOrInt &s1, const StringOrInt &s2)
Definition: compat.h:295
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:81
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
static SymbolTable * ReadText(std::istream &strm, std::string_view name, std::string_view sep=FST_FLAGS_fst_field_separator)
Definition: symbol-table.h:377
std::string Digest()
Definition: compat.h:105
const std::string & LabeledCheckSum() const
Definition: symbol-table.h:453
void AddTable(const SymbolTable &table) final
static SymbolTable * Read(std::istream &strm, std::string_view source)
Definition: symbol-table.h:391
DEFINE_string(fst_field_separator,"\t ","Set of characters used as a separator between printed fields")
void RemoveSymbol(int64_t key) final