FST  openfst-1.7.2
OpenFst Library
relabel.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions and classes to relabel an FST (either on input or output).
5 
6 #ifndef FST_RELABEL_H_
7 #define FST_RELABEL_H_
8 
9 #include <string>
10 #include <unordered_map>
11 #include <utility>
12 #include <vector>
13 
14 #include <fst/log.h>
15 
16 #include <fst/cache.h>
17 #include <fst/test-properties.h>
18 
19 
20 #include <unordered_map>
21 
22 namespace fst {
23 
24 // Relabels either the input labels or output labels. The old to
25 // new labels are specified using a vector of std::pair<Label, Label>.
26 // Any label associations not specified are assumed to be identity
27 // mapping. The destination labels must be valid labels (e.g., not kNoLabel).
28 template <class Arc>
29 void Relabel(
31  const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
32  &ipairs,
33  const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
34  &opairs) {
35  using Label = typename Arc::Label;
36  const auto props = fst->Properties(kFstProperties, false);
37  // Constructs label-to-label maps.
38  const std::unordered_map<Label, Label> input_map(
39  ipairs.begin(), ipairs.end());
40  const std::unordered_map<Label, Label> output_map(
41  opairs.begin(), opairs.end());
42  for (StateIterator<MutableFst<Arc>> siter(*fst); !siter.Done();
43  siter.Next()) {
44  for (MutableArcIterator<MutableFst<Arc>> aiter(fst, siter.Value());
45  !aiter.Done(); aiter.Next()) {
46  auto arc = aiter.Value();
47  // Relabels input.
48  auto it = input_map.find(arc.ilabel);
49  if (it != input_map.end()) {
50  if (it->second == kNoLabel) {
51  FSTERROR() << "Input symbol ID " << arc.ilabel
52  << " missing from target vocabulary";
54  return;
55  }
56  arc.ilabel = it->second;
57  }
58  // Relabels output.
59  it = output_map.find(arc.olabel);
60  if (it != output_map.end()) {
61  if (it->second == kNoLabel) {
62  FSTERROR() << "Output symbol id " << arc.olabel
63  << " missing from target vocabulary";
65  return;
66  }
67  arc.olabel = it->second;
68  }
69  aiter.SetValue(arc);
70  }
71  }
73 }
74 
75 // Relabels either the input labels or output labels. The old to
76 // new labels are specified using pairs of old and new symbol tables.
77 // The tables must contain (at least) all labels on the appropriate side of the
78 // FST. If the 'unknown_i(o)symbol' is non-empty, it is used to label any
79 // missing symbol in new_i(o)symbols table.
80 template <class Arc>
82  const SymbolTable *old_isymbols, const SymbolTable *new_isymbols,
83  const string &unknown_isymbol, bool attach_new_isymbols,
84  const SymbolTable *old_osymbols, const SymbolTable *new_osymbols,
85  const string &unknown_osymbol, bool attach_new_osymbols) {
86  using Label = typename Arc::Label;
87  // Constructs vectors of input-side label pairs.
88  std::vector<std::pair<Label, Label>> ipairs;
89  if (old_isymbols && new_isymbols) {
90  size_t num_missing_syms = 0;
91  Label unknown_ilabel = kNoLabel;
92  if (!unknown_isymbol.empty()) {
93  unknown_ilabel = new_isymbols->Find(unknown_isymbol);
94  if (unknown_ilabel == kNoLabel) {
95  VLOG(1) << "Input symbol '" << unknown_isymbol
96  << "' missing from target symbol table";
97  ++num_missing_syms;
98  }
99  }
100 
101  for (SymbolTableIterator siter(*old_isymbols); !siter.Done();
102  siter.Next()) {
103  const auto old_index = siter.Value();
104  const auto symbol = siter.Symbol();
105  auto new_index = new_isymbols->Find(siter.Symbol());
106  if (new_index == kNoLabel) {
107  if (unknown_ilabel != kNoLabel) {
108  new_index = unknown_ilabel;
109  } else {
110  VLOG(1) << "Input symbol ID " << old_index << " symbol '" << symbol
111  << "' missing from target symbol table";
112  ++num_missing_syms;
113  }
114  }
115  ipairs.push_back(std::make_pair(old_index, new_index));
116  }
117  if (num_missing_syms > 0) {
118  LOG(WARNING) << "Target symbol table missing: " << num_missing_syms
119  << " input symbols";
120  }
121  if (attach_new_isymbols) fst->SetInputSymbols(new_isymbols);
122  }
123  // Constructs vectors of output-side label pairs.
124  std::vector<std::pair<Label, Label>> opairs;
125  if (old_osymbols && new_osymbols) {
126  size_t num_missing_syms = 0;
127  Label unknown_olabel = kNoLabel;
128  if (!unknown_osymbol.empty()) {
129  unknown_olabel = new_osymbols->Find(unknown_osymbol);
130  if (unknown_olabel == kNoLabel) {
131  VLOG(1) << "Output symbol '" << unknown_osymbol
132  << "' missing from target symbol table";
133  ++num_missing_syms;
134  }
135  }
136 
137  for (SymbolTableIterator siter(*old_osymbols); !siter.Done();
138  siter.Next()) {
139  const auto old_index = siter.Value();
140  const auto symbol = siter.Symbol();
141  auto new_index = new_osymbols->Find(siter.Symbol());
142  if (new_index == kNoLabel) {
143  if (unknown_olabel != kNoLabel) {
144  new_index = unknown_olabel;
145  } else {
146  VLOG(1) << "Output symbol ID " << old_index << " symbol '" << symbol
147  << "' missing from target symbol table";
148  ++num_missing_syms;
149  }
150  }
151  opairs.push_back(std::make_pair(old_index, new_index));
152  }
153  if (num_missing_syms > 0) {
154  LOG(WARNING) << "Target symbol table missing: " << num_missing_syms
155  << " output symbols";
156  }
157  if (attach_new_osymbols) fst->SetOutputSymbols(new_osymbols);
158  }
159  // Calls relabel using vector of relabel pairs.
160  Relabel(fst, ipairs, opairs);
161 }
162 
163 // Same as previous but no special allowance for unknown symbols. Kept
164 // for backward compat.
165 template <class Arc>
166 void Relabel(MutableFst<Arc> *fst, const SymbolTable *old_isymbols,
167  const SymbolTable *new_isymbols, bool attach_new_isymbols,
168  const SymbolTable *old_osymbols, const SymbolTable *new_osymbols,
169  bool attach_new_osymbols) {
170  Relabel(fst,
171  old_isymbols, new_isymbols, "" /* no unknown isymbol */,
172  attach_new_isymbols,
173  old_osymbols, new_osymbols, "" /* no unknown ioymbol */,
174  attach_new_osymbols);
175 }
176 
177 
178 // Relabels either the input labels or output labels. The old to
179 // new labels are specified using symbol tables. Any label associations not
180 // specified are assumed to be identity mapping.
181 template <class Arc>
182 void Relabel(MutableFst<Arc> *fst, const SymbolTable *new_isymbols,
183  const SymbolTable *new_osymbols) {
184  Relabel(fst, fst->InputSymbols(), new_isymbols, true, fst->OutputSymbols(),
185  new_osymbols, true);
186 }
187 
189 
190 template <class Arc>
191 class RelabelFst;
192 
193 namespace internal {
194 
195 // Relabels an FST from one symbol set to another. Relabeling can either be on
196 // input or output space. RelabelFst implements a delayed version of the
197 // relabel. Arcs are relabeled lazily: a state's arcs are computed by Expand()
198 // the first time they are requested and then stored in the cache, so later
// requests for that state reuse the cached arcs while they remain cached (each
// arc query below first checks HasArcs()).
//
// NOTE(review): this rendered listing omits several source lines (judging by
// the gaps in the embedded line numbers: 206, 210-213, 215-221, 225, 235, 239,
// 250, 273, 281-283, 286, 291, 301, 303, 306, 308, 321 and 323), including some
// using-declarations and function headers; consult the original header before
// editing this class.
199 template <class Arc>
200 class RelabelFstImpl : public CacheImpl<Arc> {
201  public:
202  using Label = typename Arc::Label;
203  using StateId = typename Arc::StateId;
204  using Weight = typename Arc::Weight;
205
207  using State = typename Store::State;
208
209  using FstImpl<Arc>::SetType;
214
222
223  friend class StateIterator<RelabelFst<Arc>>;
224
  // Constructs from explicit (old label, new label) pairs. A side is relabeled
  // only if its pair vector is non-empty; labels without a pair are unchanged.
  // (The first line of this declaration is omitted from this listing.)
226  const std::vector<std::pair<Label, Label>> &ipairs,
227  const std::vector<std::pair<Label, Label>> &opairs,
228  const RelabelFstOptions &opts)
229  : CacheImpl<Arc>(opts),
230  fst_(fst.Copy()),
231  input_map_(ipairs.begin(), ipairs.end()),
232  output_map_(opairs.begin(), opairs.end()),
233  relabel_input_(!ipairs.empty()),
234  relabel_output_(!opairs.empty()) {
236  SetType("relabel");
237  }
238
  // Constructs from old/new symbol tables. The old tables are reported as the
  // FST's symbol tables at first. For each side where both tables are given and
  // their LabeledCheckSum()s differ, three things happen: every old symbol is
  // looked up by name in the new table, the new table is reported instead, and
  // that side is relabeled.
  // NOTE(review): a symbol absent from the new table is stored as whatever
  // Find() returns for it (the other Relabel() overloads treat that as
  // kNoLabel). Nothing in the visible code reports it as an error, unlike the
  // eager Relabel().
240  const SymbolTable *old_isymbols,
241  const SymbolTable *new_isymbols,
242  const SymbolTable *old_osymbols,
243  const SymbolTable *new_osymbols,
244  const RelabelFstOptions &opts)
245  : CacheImpl<Arc>(opts),
246  fst_(fst.Copy()),
247  relabel_input_(false),
248  relabel_output_(false) {
249  SetType("relabel");
251  SetInputSymbols(old_isymbols);
252  SetOutputSymbols(old_osymbols);
253  if (old_isymbols && new_isymbols &&
254  old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) {
255  for (SymbolTableIterator siter(*old_isymbols); !siter.Done();
256  siter.Next()) {
257  input_map_[siter.Value()] = new_isymbols->Find(siter.Symbol());
258  }
259  SetInputSymbols(new_isymbols);
260  relabel_input_ = true;
261  }
262  if (old_osymbols && new_osymbols &&
263  old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) {
264  for (SymbolTableIterator siter(*old_osymbols); !siter.Done();
265  siter.Next()) {
266  output_map_[siter.Value()] = new_osymbols->Find(siter.Symbol());
267  }
268  SetOutputSymbols(new_osymbols);
269  relabel_output_ = true;
270  }
271  }
272
  // Copy constructor: copies the wrapped FST via Copy(true) and duplicates the
  // label maps and relabel flags. (Its declaration line is omitted from this
  // listing.)
274  : CacheImpl<Arc>(impl),
275  fst_(impl.fst_->Copy(true)),
276  input_map_(impl.input_map_),
277  output_map_(impl.output_map_),
278  relabel_input_(impl.relabel_input_),
279  relabel_output_(impl.relabel_output_) {
280  SetType("relabel");
284  }
285
  // Start(): reads the start state from the wrapped FST once and caches it.
  // (Declaration line omitted from this listing.)
287  if (!HasStart()) SetStart(fst_->Start());
288  return CacheImpl<Arc>::Start();
289  }
290
  // Final(): final weights come unchanged from the wrapped FST and are cached.
  // (Declaration line omitted from this listing.)
292  if (!HasFinal(s)) SetFinal(s, fst_->Final(s));
293  return CacheImpl<Arc>::Final(s);
294  }
295
  // NumArcs() and the two epsilon-count queries that follow expand the state on
  // first use, so the counts reflect the relabeled arcs. (The declaration and
  // return lines of the epsilon-count queries are omitted from this listing.)
296  size_t NumArcs(StateId s) {
297  if (!HasArcs(s)) Expand(s);
298  return CacheImpl<Arc>::NumArcs(s);
299  }
300
302  if (!HasArcs(s)) Expand(s);
304  }
305
307  if (!HasArcs(s)) Expand(s);
309  }
310
311  uint64 Properties() const override { return Properties(kFstProperties); }
312
313  // Sets error if found, and returns other FST impl properties.
314  uint64 Properties(uint64 mask) const override {
315  if ((mask & kError) && fst_->Properties(kError, false)) {
316  SetProperties(kError, kError);
317  }
318  return FstImpl<Arc>::Properties(mask);
319  }
320
  // InitArcIterator(): expands the state if needed. It presumably then
  // delegates to CacheImpl<Arc>::InitArcIterator() (declaration and delegation
  // lines are omitted from this listing).
322  if (!HasArcs(s)) Expand(s);
324  }
325
  // Fills the cache for state s. Each arc of the wrapped FST is copied; its
  // ilabel is mapped through input_map_ (only if relabel_input_) and its olabel
  // through output_map_ (only if relabel_output_). Labels absent from a map are
  // kept unchanged. Finally SetArcs(s) marks the state's arcs as computed.
326  void Expand(StateId s) {
327  for (ArcIterator<Fst<Arc>> aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
328  auto arc = aiter.Value();
329  if (relabel_input_) {
330  auto it = input_map_.find(arc.ilabel);
331  if (it != input_map_.end()) arc.ilabel = it->second;
332  }
333  if (relabel_output_) {
334  auto it = output_map_.find(arc.olabel);
335  if (it != output_map_.end()) {
336  arc.olabel = it->second;
337  }
338  }
339  PushArc(s, std::move(arc));
340  }
341  SetArcs(s);
342  }
343
344  private:
  // This object's own copy of the FST being relabeled.
345  std::unique_ptr<const Fst<Arc>> fst_;
346
  // Old label -> new label, one map per side.
347  std::unordered_map<Label, Label> input_map_;
348  std::unordered_map<Label, Label> output_map_;
  // Whether each side is relabeled at all; Expand() skips a side when false.
349  bool relabel_input_;
350  bool relabel_output_;
351 };
352 
353 } // namespace internal
354 
355 // This class attaches interface to implementation and handles
356 // reference counting, delegating most methods to ImplToFst.
// NOTE(review): this rendered listing omits source lines 365, 367, 372 and
// 409-410 (see the gaps in the embedded line numbers). They presumably declare
// the `Store`/`Impl` aliases, the first line of the label-pair constructor, and
// the private GetImpl()/GetMutableImpl() using-declarations used below.
357 template <class A>
358 class RelabelFst : public ImplToFst<internal::RelabelFstImpl<A>> {
359  public:
360  using Arc = A;
361  using Label = typename Arc::Label;
362  using StateId = typename Arc::StateId;
363  using Weight = typename Arc::Weight;
364
366  using State = typename Store::State;
368
369  friend class ArcIterator<RelabelFst<A>>;
370  friend class StateIterator<RelabelFst<A>>;
371
  // Lazily relabels `fst` using explicit (old label, new label) pairs; labels
  // without a pair are unchanged. (First declaration line omitted from this
  // listing.)
373  const std::vector<std::pair<Label, Label>> &ipairs,
374  const std::vector<std::pair<Label, Label>> &opairs,
375  const RelabelFstOptions &opts = RelabelFstOptions())
376  : ImplToFst<Impl>(std::make_shared<Impl>(fst, ipairs, opairs, opts)) {}
377
  // Lazily relabels `fst` from its own input/output symbol tables to the given
  // new tables.
378  RelabelFst(const Fst<Arc> &fst, const SymbolTable *new_isymbols,
379  const SymbolTable *new_osymbols,
380  const RelabelFstOptions &opts = RelabelFstOptions())
381  : ImplToFst<Impl>(
382  std::make_shared<Impl>(fst, fst.InputSymbols(), new_isymbols,
383  fst.OutputSymbols(), new_osymbols, opts)) {}
384
  // Lazily relabels `fst` from explicitly given old symbol tables to new ones.
385  RelabelFst(const Fst<Arc> &fst, const SymbolTable *old_isymbols,
386  const SymbolTable *new_isymbols, const SymbolTable *old_osymbols,
387  const SymbolTable *new_osymbols,
388  const RelabelFstOptions &opts = RelabelFstOptions())
389  : ImplToFst<Impl>(std::make_shared<Impl>(fst, old_isymbols, new_isymbols,
390  old_osymbols, new_osymbols,
391  opts)) {}
392
393  // See Fst<>::Copy() for doc.
394  RelabelFst(const RelabelFst<Arc> &fst, bool safe = false)
395  : ImplToFst<Impl>(fst, safe) {}
396
397  // Gets a copy of this RelabelFst. See Fst<>::Copy() for further doc.
398  RelabelFst<Arc> *Copy(bool safe = false) const override {
399  return new RelabelFst<Arc>(*this, safe);
400  }
401
402  void InitStateIterator(StateIteratorData<Arc> *data) const override;
403
  // Arc iteration goes through the implementation, which expands and caches
  // the state on demand. That mutates the cache, hence GetMutableImpl() inside
  // this const method.
404  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
405  return GetMutableImpl()->InitArcIterator(s, data);
406  }
407
408  private:
411
412  RelabelFst &operator=(const RelabelFst &) = delete;
413 };
414 
415 // Specialization for RelabelFst.
// Walks the states of the wrapped FST directly; iterating states does not
// expand them. (The class-header line, source line 417, is omitted from this
// listing.)
// NOTE(review): Value() returns a counter that starts at 0 and is incremented
// on each Next(), not siter_.Value(). This presumes the wrapped FST's state
// IDs are 0, 1, 2, ... in iteration order -- confirm.
416 template <class Arc>
418  public:
419  using StateId = typename Arc::StateId;
420
421  explicit StateIterator(const RelabelFst<Arc> &fst)
422  : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {}
423
424  bool Done() const final { return siter_.Done(); }
425
426  StateId Value() const final { return s_; }
427
428  void Next() final {
429  if (!siter_.Done()) {
430  ++s_;
431  siter_.Next();
432  }
433  }
434
435  void Reset() final {
436  s_ = 0;
437  siter_.Reset();
438  }
439
440  private:
  // Non-owning pointer to the FST's implementation; its wrapped FST is what
  // siter_ iterates over.
441  const internal::RelabelFstImpl<Arc>* impl_;
442  StateIterator<Fst<Arc>> siter_;
443  StateId s_;
444
445  StateIterator(const StateIterator &) = delete;
446  StateIterator &operator=(const StateIterator &) = delete;
447 };
448 
449 // Specialization for RelabelFst.
// Iterates over the cached (already relabeled) arcs of state s. If the state
// has not been expanded yet, the constructor body expands it; this runs after
// the CacheArcIterator base has been constructed. (The constructor's
// declaration line, source line 455, is omitted from this listing.)
450 template <class Arc>
451 class ArcIterator<RelabelFst<Arc>> : public CacheArcIterator<RelabelFst<Arc>> {
452  public:
453  using StateId = typename Arc::StateId;
454
456  : CacheArcIterator<RelabelFst<Arc>>(fst.GetMutableImpl(), s) {
457  if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
458  }
459 };
460 
// Out-of-line definition of RelabelFst<Arc>::InitStateIterator(); its first
// line, source line 462, is omitted from this listing. It installs a
// RelabelFst-specific StateIterator in `data->base`.
// NOTE(review): ownership of the raw pointer presumably passes to `data` --
// confirm against StateIteratorData.
461 template <class Arc>
463  StateIteratorData<Arc> *data) const {
464  data->base = new StateIterator<RelabelFst<Arc>>(*this);
465 }
466 
467 // Useful alias when using StdArc.
469 
470 } // namespace fst
471 
472 #endif // FST_RELABEL_H_
void SetFinal(StateId s, Weight weight)
Definition: cache.h:905
void InitStateIterator(StateIteratorData< Arc > *data) const override
Definition: relabel.h:462
constexpr int kNoLabel
Definition: fst.h:179
virtual const string & LabeledCheckSum() const
Definition: symbol-table.h:308
RelabelFst(const RelabelFst< Arc > &fst, bool safe=false)
Definition: relabel.h:394
uint64_t uint64
Definition: types.h:32
StateIterator(const RelabelFst< Arc > &fst)
Definition: relabel.h:421
const SymbolTable * OutputSymbols() const
Definition: fst.h:690
const SymbolTable * InputSymbols() const override=0
void Expand(StateId s)
Definition: relabel.h:326
virtual void SetInputSymbols(const SymbolTable *isyms)=0
#define LOG(type)
Definition: log.h:48
typename Arc::StateId StateId
Definition: relabel.h:203
void SetOutputSymbols(const SymbolTable *osyms)
Definition: fst.h:700
typename Arc::StateId StateId
Definition: fst.h:330
typename Arc::Weight Weight
Definition: relabel.h:204
constexpr uint64 kFstProperties
Definition: properties.h:301
constexpr uint64 kCopyProperties
Definition: properties.h:138
virtual uint64 Properties() const
Definition: fst.h:666
void Relabel(MutableFst< Arc > *fst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &ipairs, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &opairs)
Definition: relabel.h:29
virtual uint64 Properties(uint64 mask, bool test) const =0
void SetType(const string &type)
Definition: fst.h:664
uint64 RelabelProperties(uint64 inprops)
Definition: properties.cc:307
#define FSTERROR()
Definition: util.h:35
const SymbolTable * OutputSymbols() const override=0
StateIteratorBase< Arc > * base
Definition: fst.h:351
RelabelFstImpl(const RelabelFstImpl< Arc > &impl)
Definition: relabel.h:273
uint64 Properties(uint64 mask) const override
Definition: relabel.h:314
Weight Final(StateId s)
Definition: relabel.h:291
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
Definition: cache.h:648
RelabelFst< Arc > * Copy(bool safe=false) const override
Definition: relabel.h:398
#define VLOG(level)
Definition: log.h:49
size_t NumInputEpsilons(StateId s)
Definition: relabel.h:301
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const
Definition: cache.h:1041
void SetInputSymbols(const SymbolTable *isyms)
Definition: fst.h:696
RelabelFstImpl(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &ipairs, const std::vector< std::pair< Label, Label >> &opairs, const RelabelFstOptions &opts)
Definition: relabel.h:225
size_t NumOutputEpsilons(StateId s)
Definition: relabel.h:306
virtual string Find(int64 key) const
Definition: symbol-table.h:301
void PushArc(StateId s, const Arc &arc)
Definition: cache.h:932
ArcIterator(const RelabelFst< Arc > &fst, StateId s)
Definition: relabel.h:455
size_t NumArcs(StateId s)
Definition: relabel.h:296
RelabelFst(const Fst< Arc > &fst, const SymbolTable *old_isymbols, const SymbolTable *new_isymbols, const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, const RelabelFstOptions &opts=RelabelFstOptions())
Definition: relabel.h:385
uint64 Properties() const override
Definition: relabel.h:311
typename RelabelFst< Arc >::Arc Arc
Definition: cache.h:1194
typename Arc::StateId StateId
Definition: relabel.h:362
const SymbolTable * InputSymbols() const
Definition: fst.h:688
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
Definition: relabel.h:321
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
Definition: relabel.h:404
typename Arc::Label Label
Definition: relabel.h:361
typename Arc::Label Label
Definition: relabel.h:202
constexpr uint64 kError
Definition: properties.h:33
typename Store::State State
Definition: relabel.h:366
RelabelFst(const Fst< Arc > &fst, const SymbolTable *new_isymbols, const SymbolTable *new_osymbols, const RelabelFstOptions &opts=RelabelFstOptions())
Definition: relabel.h:378
RelabelFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &ipairs, const std::vector< std::pair< Label, Label >> &opairs, const RelabelFstOptions &opts=RelabelFstOptions())
Definition: relabel.h:372
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
RelabelFstImpl(const Fst< Arc > &fst, const SymbolTable *old_isymbols, const SymbolTable *new_isymbols, const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, const RelabelFstOptions &opts)
Definition: relabel.h:239
typename CacheState< Arc >::Arc Arc
Definition: cache.h:837
internal::RelabelFstImpl< A > * GetMutableImpl() const
Definition: fst.h:947
typename Arc::StateId StateId
Definition: relabel.h:453
CacheOptions RelabelFstOptions
Definition: relabel.h:188
const internal::RelabelFstImpl< A > * GetImpl() const
Definition: fst.h:945
typename Arc::Weight Weight
Definition: relabel.h:363
virtual void SetProperties(uint64 props, uint64 mask)=0