20 #ifndef FST_RELABEL_H_ 21 #define FST_RELABEL_H_ 40 #include <unordered_map> 51 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
53 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
55 using Label =
typename Arc::Label;
58 const std::unordered_map<Label, Label> input_map(
59 ipairs.begin(), ipairs.end());
60 const std::unordered_map<Label, Label> output_map(
61 opairs.begin(), opairs.end());
65 !aiter.Done(); aiter.Next()) {
66 auto arc = aiter.Value();
72 if (
auto it = input_map.find(arc.ilabel); it != input_map.end()) {
74 FSTERROR() <<
"Input symbol ID " << arc.ilabel
75 <<
" missing from target vocabulary";
79 arc.ilabel = it->second;
82 if (
auto it = output_map.find(arc.olabel); it != output_map.end()) {
84 FSTERROR() <<
"Output symbol id " << arc.olabel
85 <<
" missing from target vocabulary";
89 arc.olabel = it->second;
105 const std::string &unknown_isymbol,
bool attach_new_isymbols,
107 const std::string &unknown_osymbol,
bool attach_new_osymbols) {
108 using Label =
typename Arc::Label;
110 std::vector<std::pair<Label, Label>> ipairs;
111 if (old_isymbols && new_isymbols) {
112 size_t num_missing_syms = 0;
114 if (!unknown_isymbol.empty()) {
115 unknown_ilabel = new_isymbols->
Find(unknown_isymbol);
117 VLOG(1) <<
"Input symbol '" << unknown_isymbol
118 <<
"' missing from target symbol table";
123 for (
const auto &sitem : *old_isymbols) {
124 const auto old_index = sitem.Label();
125 const auto symbol = sitem.Symbol();
126 auto new_index = new_isymbols->
Find(symbol);
129 new_index = unknown_ilabel;
131 VLOG(1) <<
"Input symbol ID " << old_index <<
" symbol '" << symbol
132 <<
"' missing from target symbol table";
136 ipairs.emplace_back(old_index, new_index);
138 if (num_missing_syms > 0) {
139 LOG(WARNING) <<
"Target symbol table missing: " << num_missing_syms
145 std::vector<std::pair<Label, Label>> opairs;
146 if (old_osymbols && new_osymbols) {
147 size_t num_missing_syms = 0;
149 if (!unknown_osymbol.empty()) {
150 unknown_olabel = new_osymbols->
Find(unknown_osymbol);
152 VLOG(1) <<
"Output symbol '" << unknown_osymbol
153 <<
"' missing from target symbol table";
157 for (
const auto &sitem : *old_osymbols) {
158 const auto old_index = sitem.Label();
159 const auto symbol = sitem.Symbol();
160 auto new_index = new_osymbols->
Find(symbol);
163 new_index = unknown_olabel;
165 VLOG(1) <<
"Output symbol ID " << old_index <<
" symbol '" << symbol
166 <<
"' missing from target symbol table";
170 opairs.emplace_back(old_index, new_index);
172 if (num_missing_syms > 0) {
173 LOG(WARNING) <<
"Target symbol table missing: " << num_missing_syms
174 <<
" output symbols";
186 const SymbolTable *new_isymbols,
bool attach_new_isymbols,
188 bool attach_new_osymbols) {
189 Relabel(fst, old_isymbols, new_isymbols,
"" ,
190 attach_new_isymbols, old_osymbols, new_osymbols,
191 "" , attach_new_osymbols);
242 const std::vector<std::pair<Label, Label>> &ipairs,
243 const std::vector<std::pair<Label, Label>> &opairs,
247 input_map_(ipairs.begin(), ipairs.end()),
248 output_map_(opairs.begin(), opairs.end()),
249 relabel_input_(!ipairs.empty()),
250 relabel_output_(!opairs.empty()) {
261 relabel_input_(false),
262 relabel_output_(false) {
267 if (old_isymbols && new_isymbols &&
269 for (
const auto &sitem : *old_isymbols) {
270 input_map_[sitem.Label()] = new_isymbols->
Find(sitem.Symbol());
273 relabel_input_ =
true;
275 if (old_osymbols && new_osymbols &&
277 for (
const auto &sitem : *old_osymbols) {
278 output_map_[sitem.Label()] = new_osymbols->
Find(sitem.Symbol());
281 relabel_output_ =
true;
287 fst_(impl.fst_->Copy(true)),
288 input_map_(impl.input_map_),
289 output_map_(impl.output_map_),
290 relabel_input_(impl.relabel_input_),
291 relabel_output_(impl.relabel_output_) {
327 if ((mask &
kError) && fst_->Properties(kError,
false)) {
340 auto arc = aiter.Value();
341 if (relabel_input_) {
342 if (
auto it = input_map_.find(arc.ilabel); it != input_map_.end()) {
343 arc.ilabel = it->second;
346 if (relabel_output_) {
347 if (
auto it = output_map_.find(arc.olabel); it != output_map_.end()) {
348 arc.olabel = it->second;
357 std::unique_ptr<const Fst<Arc>> fst_;
359 std::unordered_map<Label, Label> input_map_;
360 std::unordered_map<Label, Label> output_map_;
362 bool relabel_output_;
385 const std::vector<std::pair<Label, Label>> &ipairs,
386 const std::vector<std::pair<Label, Label>> &opairs,
402 old_osymbols, new_osymbols,
417 return GetMutableImpl()->InitArcIterator(s, data);
434 : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {}
436 bool Done() const final {
return siter_.Done(); }
441 if (!siter_.Done()) {
476 data->
base = std::make_unique<StateIterator<RelabelFst<Arc>>>(*this);
484 #endif // FST_RELABEL_H_ void InitStateIterator(StateIteratorData< Arc > *data) const override
void SetProperties(uint64_t props)
size_t NumInputEpsilons(StateId s) const
virtual uint64_t Properties(uint64_t mask, bool test) const =0
bool HasFinal(StateId s) const
StateIterator(const RelabelFst< Arc > &fst)
void SetFinal(StateId s, Weight weight=Weight::One())
const SymbolTable * OutputSymbols() const
const SymbolTable * InputSymbols() const override=0
constexpr uint64_t kError
virtual void SetInputSymbols(const SymbolTable *isyms)=0
typename Arc::StateId StateId
void SetOutputSymbols(const SymbolTable *osyms)
typename Arc::StateId StateId
typename Arc::Weight Weight
uint64_t RelabelProperties(uint64_t inprops)
size_t NumArcs(StateId s) const
StateId Value() const final
void Relabel(MutableFst< Arc > *fst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &ipairs, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &opairs)
uint64_t Properties(uint64_t mask) const override
const SymbolTable * OutputSymbols() const override=0
RelabelFstImpl(const RelabelFstImpl< Arc > &impl)
virtual uint64_t Properties() const
constexpr uint64_t kCopyProperties
virtual void SetProperties(uint64_t props, uint64_t mask)=0
RelabelFst * Copy(bool safe=false) const override
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
size_t NumInputEpsilons(StateId s)
std::unique_ptr< StateIteratorBase< Arc > > base
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const
void SetInputSymbols(const SymbolTable *isyms)
bool HasArcs(StateId s) const
RelabelFstImpl(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &ipairs, const std::vector< std::pair< Label, Label >> &opairs, const RelabelFstOptions &opts)
size_t NumOutputEpsilons(StateId s)
constexpr uint64_t kFstProperties
void PushArc(StateId s, const Arc &arc)
ArcIterator(const RelabelFst< Arc > &fst, StateId s)
size_t NumArcs(StateId s)
RelabelFst(const Fst< Arc > &fst, const SymbolTable *old_isymbols, const SymbolTable *new_isymbols, const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, const RelabelFstOptions &opts=RelabelFstOptions())
typename RelabelFst< Arc >::Arc Arc
typename Arc::StateId StateId
const SymbolTable * InputSymbols() const
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
void SetType(std::string_view type)
Weight Final(StateId s) const
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
typename Arc::Label Label
typename Arc::Label Label
typename Store::State State
RelabelFst(const Fst< Arc > &fst, const SymbolTable *new_isymbols, const SymbolTable *new_osymbols, const RelabelFstOptions &opts=RelabelFstOptions())
RelabelFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &ipairs, const std::vector< std::pair< Label, Label >> &opairs, const RelabelFstOptions &opts=RelabelFstOptions())
std::string Find(int64_t key) const
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
const std::string & LabeledCheckSum() const
RelabelFstImpl(const Fst< Arc > &fst, const SymbolTable *old_isymbols, const SymbolTable *new_isymbols, const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, const RelabelFstOptions &opts)
typename CacheState< Arc >::Arc Arc
internal::RelabelFstImpl< A > * GetMutableImpl() const
typename Arc::StateId StateId
RelabelFst(const RelabelFst &fst, bool safe=false)
size_t NumOutputEpsilons(StateId s) const
CacheOptions RelabelFstOptions
const internal::RelabelFstImpl< A > * GetImpl() const
uint64_t Properties() const override
typename Arc::Weight Weight