20 #ifndef FST_RELABEL_H_ 21 #define FST_RELABEL_H_ 33 #include <unordered_map> 44 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
46 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
48 using Label =
typename Arc::Label;
51 const std::unordered_map<Label, Label> input_map(
52 ipairs.begin(), ipairs.end());
53 const std::unordered_map<Label, Label> output_map(
54 opairs.begin(), opairs.end());
58 !aiter.Done(); aiter.Next()) {
59 auto arc = aiter.Value();
65 auto it = input_map.find(arc.ilabel);
66 if (it != input_map.end()) {
68 FSTERROR() <<
"Input symbol ID " << arc.ilabel
69 <<
" missing from target vocabulary";
73 arc.ilabel = it->second;
76 it = output_map.find(arc.olabel);
77 if (it != output_map.end()) {
79 FSTERROR() <<
"Output symbol id " << arc.olabel
80 <<
" missing from target vocabulary";
84 arc.olabel = it->second;
100 const std::string &unknown_isymbol,
bool attach_new_isymbols,
102 const std::string &unknown_osymbol,
bool attach_new_osymbols) {
103 using Label =
typename Arc::Label;
105 std::vector<std::pair<Label, Label>> ipairs;
106 if (old_isymbols && new_isymbols) {
107 size_t num_missing_syms = 0;
109 if (!unknown_isymbol.empty()) {
110 unknown_ilabel = new_isymbols->
Find(unknown_isymbol);
112 VLOG(1) <<
"Input symbol '" << unknown_isymbol
113 <<
"' missing from target symbol table";
118 for (
const auto &sitem : *old_isymbols) {
119 const auto old_index = sitem.Label();
120 const auto symbol = sitem.Symbol();
121 auto new_index = new_isymbols->
Find(symbol);
124 new_index = unknown_ilabel;
126 VLOG(1) <<
"Input symbol ID " << old_index <<
" symbol '" << symbol
127 <<
"' missing from target symbol table";
131 ipairs.emplace_back(old_index, new_index);
133 if (num_missing_syms > 0) {
134 LOG(WARNING) <<
"Target symbol table missing: " << num_missing_syms
140 std::vector<std::pair<Label, Label>> opairs;
141 if (old_osymbols && new_osymbols) {
142 size_t num_missing_syms = 0;
144 if (!unknown_osymbol.empty()) {
145 unknown_olabel = new_osymbols->
Find(unknown_osymbol);
147 VLOG(1) <<
"Output symbol '" << unknown_osymbol
148 <<
"' missing from target symbol table";
152 for (
const auto &sitem : *old_osymbols) {
153 const auto old_index = sitem.Label();
154 const auto symbol = sitem.Symbol();
155 auto new_index = new_osymbols->
Find(symbol);
158 new_index = unknown_olabel;
160 VLOG(1) <<
"Output symbol ID " << old_index <<
" symbol '" << symbol
161 <<
"' missing from target symbol table";
165 opairs.emplace_back(old_index, new_index);
167 if (num_missing_syms > 0) {
168 LOG(WARNING) <<
"Target symbol table missing: " << num_missing_syms
169 <<
" output symbols";
181 const SymbolTable *new_isymbols,
bool attach_new_isymbols,
183 bool attach_new_osymbols) {
184 Relabel(fst, old_isymbols, new_isymbols,
"" ,
185 attach_new_isymbols, old_osymbols, new_osymbols,
186 "" , attach_new_osymbols);
237 const std::vector<std::pair<Label, Label>> &ipairs,
238 const std::vector<std::pair<Label, Label>> &opairs,
242 input_map_(ipairs.begin(), ipairs.end()),
243 output_map_(opairs.begin(), opairs.end()),
244 relabel_input_(!ipairs.empty()),
245 relabel_output_(!opairs.empty()) {
256 relabel_input_(false),
257 relabel_output_(false) {
262 if (old_isymbols && new_isymbols &&
264 for (
const auto &sitem : *old_isymbols) {
265 input_map_[sitem.Label()] = new_isymbols->
Find(sitem.Symbol());
268 relabel_input_ =
true;
270 if (old_osymbols && new_osymbols &&
272 for (
const auto &sitem : *old_osymbols) {
273 output_map_[sitem.Label()] = new_osymbols->
Find(sitem.Symbol());
276 relabel_output_ =
true;
282 fst_(impl.fst_->Copy(true)),
283 input_map_(impl.input_map_),
284 output_map_(impl.output_map_),
285 relabel_input_(impl.relabel_input_),
286 relabel_output_(impl.relabel_output_) {
322 if ((mask &
kError) && fst_->Properties(kError,
false)) {
335 auto arc = aiter.Value();
336 if (relabel_input_) {
337 auto it = input_map_.find(arc.ilabel);
338 if (it != input_map_.end()) arc.ilabel = it->second;
340 if (relabel_output_) {
341 auto it = output_map_.find(arc.olabel);
342 if (it != output_map_.end()) {
343 arc.olabel = it->second;
352 std::unique_ptr<const Fst<Arc>> fst_;
354 std::unordered_map<Label, Label> input_map_;
355 std::unordered_map<Label, Label> output_map_;
357 bool relabel_output_;
380 const std::vector<std::pair<Label, Label>> &ipairs,
381 const std::vector<std::pair<Label, Label>> &opairs,
397 old_osymbols, new_osymbols,
412 return GetMutableImpl()->InitArcIterator(s, data);
429 : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {}
431 bool Done() const final {
return siter_.Done(); }
436 if (!siter_.Done()) {
471 data->
base = std::make_unique<StateIterator<RelabelFst<Arc>>>(*this);
479 #endif // FST_RELABEL_H_ void InitStateIterator(StateIteratorData< Arc > *data) const override
void SetProperties(uint64_t props)
size_t NumInputEpsilons(StateId s) const
virtual uint64_t Properties(uint64_t mask, bool test) const =0
bool HasFinal(StateId s) const
StateIterator(const RelabelFst< Arc > &fst)
void SetFinal(StateId s, Weight weight=Weight::One())
const SymbolTable * OutputSymbols() const
const SymbolTable * InputSymbols() const override=0
constexpr uint64_t kError
virtual void SetInputSymbols(const SymbolTable *isyms)=0
typename Arc::StateId StateId
void SetOutputSymbols(const SymbolTable *osyms)
typename Arc::StateId StateId
typename Arc::Weight Weight
uint64_t RelabelProperties(uint64_t inprops)
size_t NumArcs(StateId s) const
StateId Value() const final
void Relabel(MutableFst< Arc > *fst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &ipairs, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &opairs)
uint64_t Properties(uint64_t mask) const override
const SymbolTable * OutputSymbols() const override=0
RelabelFstImpl(const RelabelFstImpl< Arc > &impl)
virtual uint64_t Properties() const
constexpr uint64_t kCopyProperties
virtual void SetProperties(uint64_t props, uint64_t mask)=0
RelabelFst * Copy(bool safe=false) const override
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
size_t NumInputEpsilons(StateId s)
std::unique_ptr< StateIteratorBase< Arc > > base
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const
void SetInputSymbols(const SymbolTable *isyms)
bool HasArcs(StateId s) const
RelabelFstImpl(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &ipairs, const std::vector< std::pair< Label, Label >> &opairs, const RelabelFstOptions &opts)
size_t NumOutputEpsilons(StateId s)
constexpr uint64_t kFstProperties
void PushArc(StateId s, const Arc &arc)
ArcIterator(const RelabelFst< Arc > &fst, StateId s)
size_t NumArcs(StateId s)
RelabelFst(const Fst< Arc > &fst, const SymbolTable *old_isymbols, const SymbolTable *new_isymbols, const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, const RelabelFstOptions &opts=RelabelFstOptions())
typename RelabelFst< Arc >::Arc Arc
typename Arc::StateId StateId
const SymbolTable * InputSymbols() const
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
void SetType(std::string_view type)
Weight Final(StateId s) const
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
typename Arc::Label Label
typename Arc::Label Label
typename Store::State State
RelabelFst(const Fst< Arc > &fst, const SymbolTable *new_isymbols, const SymbolTable *new_osymbols, const RelabelFstOptions &opts=RelabelFstOptions())
RelabelFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &ipairs, const std::vector< std::pair< Label, Label >> &opairs, const RelabelFstOptions &opts=RelabelFstOptions())
std::string Find(int64_t key) const
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
const std::string & LabeledCheckSum() const
RelabelFstImpl(const Fst< Arc > &fst, const SymbolTable *old_isymbols, const SymbolTable *new_isymbols, const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, const RelabelFstOptions &opts)
typename CacheState< Arc >::Arc Arc
internal::RelabelFstImpl< A > * GetMutableImpl() const
typename Arc::StateId StateId
RelabelFst(const RelabelFst &fst, bool safe=false)
size_t NumOutputEpsilons(StateId s) const
CacheOptions RelabelFstOptions
const internal::RelabelFstImpl< A > * GetImpl() const
uint64_t Properties() const override
typename Arc::Weight Weight