29 #include <string_view> 33 "Require symbol tables to match when appropriate");
35 "Set of characters used as a separator between printed fields");
40 : allow_negative_labels(allow_negative_labels),
41 fst_field_separator(FST_FLAGS_fst_field_separator) {}
46 static constexpr int32_t kSymbolTableMagicNumber = 2125658996;
48 DenseSymbolMap::DenseSymbolMap()
50 buckets_(1 << 4, kEmptyBucket),
51 hash_mask_(buckets_.size() - 1) {}
54 static constexpr
float kMaxOccupancyRatio = 0.75;
55 if (
Size() >= kMaxOccupancyRatio * buckets_.size()) {
56 Rehash(buckets_.size() * 2);
58 size_t idx = GetHash(key);
59 while (buckets_[idx] != kEmptyBucket) {
60 const auto stored_value = buckets_[idx];
61 if (symbols_[stored_value] == key)
return {stored_value,
false};
62 idx = (idx + 1) & hash_mask_;
64 const auto next =
Size();
66 symbols_.emplace_back(key);
71 size_t idx = str_hash_(key) & hash_mask_;
72 while (buckets_[idx] != kEmptyBucket) {
73 const auto stored_value = buckets_[idx];
74 if (symbols_[stored_value] == key)
return stored_value;
75 idx = (idx + 1) & hash_mask_;
80 void DenseSymbolMap::Rehash(
size_t num_buckets) {
81 buckets_.resize(num_buckets);
82 hash_mask_ = buckets_.size() - 1;
83 std::fill(buckets_.begin(), buckets_.end(), kEmptyBucket);
84 for (
size_t i = 0; i <
Size(); ++i) {
85 size_t idx = GetHash(symbols_[i]);
86 while (buckets_[idx] != kEmptyBucket) {
87 idx = (idx + 1) & hash_mask_;
94 symbols_.erase(symbols_.begin() + idx);
95 Rehash(buckets_.size());
101 for (
const auto &item : table) {
102 AddSymbol(item.Symbol());
107 LOG(FATAL) <<
"ConstSymbolTableImpl can't be copied";
112 LOG(FATAL) <<
"ConstSymbolTableImpl does not support AddSymbol";
121 LOG(FATAL) <<
"ConstSymbolTableImpl does not support RemoveSymbol";
125 LOG(FATAL) <<
"ConstSymbolTableImpl does not support SetName";
129 LOG(FATAL) <<
"ConstSymbolTableImpl does not support AddTable";
133 std::string_view name,
135 auto impl = std::make_unique<SymbolTableImpl>(name);
139 while (!strm.getline(line,
kLineLen).fail()) {
141 std::vector<std::string_view> col =
143 if (col.empty())
continue;
144 if (col.size() != 2) {
145 LOG(ERROR) <<
"SymbolTable::ReadText: Bad number of columns (" 146 << col.size() <<
"), " 147 <<
"file = " << name <<
", line = " << nline <<
":<" << line
151 std::string_view symbol = col[0];
152 std::string_view value = col[1];
154 if (!maybe_key.has_value() ||
157 LOG(ERROR) <<
"SymbolTable::ReadText: Bad non-negative integer \"" 159 <<
"file = " << name <<
", line = " << nline;
162 impl->AddSymbol(symbol, *maybe_key);
165 return impl.release();
168 void SymbolTableImpl::MaybeRecomputeCheckSum()
const {
171 if (check_sum_finalized_)
return;
174 MutexLock check_sum_lock(&check_sum_mutex_);
175 if (check_sum_finalized_) {
180 for (
size_t i = 0; i < symbols_.Size(); ++i) {
181 check_sum.
Update(symbols_.GetSymbol(i));
182 check_sum.
Update(std::string_view{
"\0", 1});
184 check_sum_string_ = check_sum.
Digest();
187 for (int64_t i = 0; i < dense_key_limit_; ++i) {
188 std::ostringstream line;
189 line << symbols_.GetSymbol(i) <<
'\t' << i;
190 labeled_check_sum.
Update(line.str());
192 using citer = std::map<int64_t, int64_t>::const_iterator;
193 for (citer it = key_map_.begin(); it != key_map_.end(); ++it) {
196 if (it->first < dense_key_limit_)
continue;
197 std::ostringstream line;
198 line << symbols_.GetSymbol(it->second) <<
'\t' << it->first;
199 labeled_check_sum.
Update(line.str());
201 labeled_check_sum_string_ = labeled_check_sum.
Digest();
202 check_sum_finalized_ =
true;
207 if (key < 0 || key >= dense_key_limit_) {
208 const auto it = key_map_.find(key);
209 if (it == key_map_.end())
return "";
212 if (idx < 0 || idx >= symbols_.Size())
return "";
213 return symbols_.GetSymbol(idx);
218 if (
const auto &[insert_key, inserted] = symbols_.InsertOrFind(symbol);
220 const auto key_already = GetNthKey(insert_key);
221 if (key_already == key)
return key;
222 VLOG(1) <<
"SymbolTable::AddSymbol: symbol = " << symbol
223 <<
" already in symbol_map_ with key = " << key_already
224 <<
" but supplied new key = " << key <<
" (ignoring new key)";
227 if (key + 1 == static_cast<int64_t>(symbols_.Size()) &&
228 key == dense_key_limit_) {
231 idx_key_.push_back(key);
232 key_map_[key] = symbols_.Size() - 1;
234 if (key >= available_key_) available_key_ = key + 1;
235 check_sum_finalized_ =
false;
243 if (key < 0 || key >= dense_key_limit_) {
244 auto iter = key_map_.find(key);
245 if (iter == key_map_.end())
return;
247 key_map_.erase(iter);
249 if (idx < 0 || idx >= static_cast<int64_t>(symbols_.Size()))
return;
250 symbols_.RemoveSymbol(idx);
252 for (
auto &k : key_map_) {
253 if (k.second > idx) --k.second;
255 if (key >= 0 && key < dense_key_limit_) {
257 const auto new_dense_key_limit = key;
258 for (int64_t i = key + 1; i < dense_key_limit_; ++i) {
262 idx_key_.resize(symbols_.Size() - new_dense_key_limit);
263 for (int64_t i = symbols_.Size(); i >= dense_key_limit_; --i) {
264 idx_key_[i - new_dense_key_limit - 1] = idx_key_[i - dense_key_limit_];
267 for (int64_t i = new_dense_key_limit; i < dense_key_limit_ - 1; ++i) {
268 idx_key_[i - new_dense_key_limit] = i + 1;
270 dense_key_limit_ = new_dense_key_limit;
273 for (
size_t i = idx - dense_key_limit_; i + 1 < idx_key_.size(); ++i) {
274 idx_key_[i] = idx_key_[i + 1];
278 if (key == available_key_ - 1) available_key_ = key;
282 std::string_view source) {
283 int32_t magic_number = 0;
286 LOG(ERROR) <<
"SymbolTable::Read: Read failed: " << source;
291 auto impl = std::make_unique<SymbolTableImpl>(name);
292 ReadType(strm, &impl->available_key_);
296 LOG(ERROR) <<
"SymbolTable::Read: Read failed: " << source;
301 impl->check_sum_finalized_ =
false;
302 for (int64_t i = 0; i < size; ++i) {
306 LOG(ERROR) <<
"SymbolTable::Read: Read failed: " << source;
309 impl->AddSymbol(symbol, key);
312 return impl.release();
316 WriteType(strm, kSymbolTableMagicNumber);
319 const int64_t size = symbols_.Size();
321 for (int64_t i = 0; i < dense_key_limit_; ++i) {
325 for (
const auto &p : key_map_) {
326 WriteType(strm, symbols_.GetSymbol(p.second));
331 LOG(ERROR) <<
"SymbolTable::Write: Write failed";
343 std::ifstream strm(source, std::ios_base::in);
345 LOG(ERROR) <<
"SymbolTable::ReadText: Can't open file: " << source;
348 return ReadText(strm, source, opts);
352 if (!source.empty()) {
353 std::ofstream strm(source,
354 std::ios_base::out | std::ios_base::binary);
356 LOG(ERROR) <<
"SymbolTable::Write: Can't open file: " << source;
360 LOG(ERROR) <<
"SymbolTable::Write: Write failed: " << source;
365 return Write(std::cout);
372 LOG(ERROR) <<
"Missing required field separator";
375 bool once_only =
false;
376 for (
const auto &item : *
this) {
377 std::ostringstream line;
379 LOG(WARNING) <<
"Negative symbol table entry when not allowed";
384 strm.write(line.str().data(), line.str().length());
390 if (!source.empty()) {
391 std::ofstream strm(source);
393 LOG(ERROR) <<
"SymbolTable::WriteText: Can't open file: " << source;
397 LOG(ERROR) <<
"SymbolTable::WriteText: Write failed: " << source;
409 if (!FST_FLAGS_fst_compat_symbols)
return true;
410 if (syms1 && syms2 &&
413 LOG(WARNING) <<
"CompatSymbols: Symbol table checksums do not match. " 414 <<
"Table sizes are " << syms1->
NumSymbols() <<
" and " 424 std::ostringstream ostrm;
426 *result = ostrm.str();
430 std::istringstream istrm(str);
bool Write(std::ostream &strm) const
void SymbolTableToString(const SymbolTable *table, std::string *result)
std::pair< int64_t, bool > InsertOrFind(std::string_view key)
int64_t AddSymbol(std::string_view symbol, int64_t key) final
static SymbolTableImpl * Read(std::istream &strm, std::string_view source)
DEFINE_bool(fst_compat_symbols, true,"Require symbol tables to match when appropriate")
SymbolTable * StringToSymbolTable(const std::string &str)
bool WriteText(std::ostream &strm, const SymbolTableTextOptions &opts=SymbolTableTextOptions()) const
constexpr int64_t kNoSymbol
std::string Find(int64_t key) const override
internal::StringSplitter StrSplit(std::string_view full, ByAnyChar delim)
size_t NumSymbols() const
std::ostream & WriteType(std::ostream &strm, const T t)
void SetName(std::string_view new_name) final
std::unique_ptr< SymbolTableImplBase > Copy() const final
std::optional< int64_t > ParseInt64(std::string_view s, int base=10)
bool Write(std::ostream &strm) const override
static SymbolTable * ReadText(std::istream &strm, std::string_view name, const SymbolTableTextOptions &opts=SymbolTableTextOptions())
void AddTable(const SymbolTable &table) override
SymbolTableTextOptions(bool allow_negative_labels=false)
int64_t Find(std::string_view key) const
static SymbolTable * Read(std::istream &strm, const std::string &source)
void RemoveSymbol(int64_t key) override
int64_t AddSymbol(std::string_view symbol, int64_t key) override
void Update(std::string_view data)
std::string fst_field_separator
bool allow_negative_labels
static SymbolTableImpl * ReadText(std::istream &strm, std::string_view name, const SymbolTableTextOptions &opts=SymbolTableTextOptions())
std::istream & ReadType(std::istream &strm, T *t)
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
const std::string & LabeledCheckSum() const
void AddTable(const SymbolTable &table) final
DEFINE_string(fst_field_separator,"\t ","Set of characters used as a separator between printed fields")
void RemoveSymbol(int64_t key) final
void RemoveSymbol(size_t idx)