FST  openfst-1.8.2
OpenFst Library
relabel.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Functions and classes to relabel an FST (either on input or output).
19 
20 #ifndef FST_RELABEL_H_
21 #define FST_RELABEL_H_
22 
23 #include <cstdint>
24 #include <string>
25 #include <utility>
26 #include <vector>
27 
28 #include <fst/log.h>
29 
30 #include <fst/cache.h>
31 #include <fst/test-properties.h>
32 
33 #include <unordered_map>
34 
35 namespace fst {
36 
37 // Relabels either the input labels or output labels. The old to
38 // new labels are specified using a vector of std::pair<Label, Label>.
39 // Any label associations not specified are assumed to be identity
40 // mapping. The destination labels must be valid labels (e.g., not kNoLabel).
41 template <class Arc>
42 void Relabel(
44  const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
45  &ipairs,
46  const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
47  &opairs) {
48  using Label = typename Arc::Label;
49  const auto props = fst->Properties(kFstProperties, false);
50  // Constructs label-to-label maps.
51  const std::unordered_map<Label, Label> input_map(
52  ipairs.begin(), ipairs.end());
53  const std::unordered_map<Label, Label> output_map(
54  opairs.begin(), opairs.end());
55  for (StateIterator<MutableFst<Arc>> siter(*fst); !siter.Done();
56  siter.Next()) {
57  for (MutableArcIterator<MutableFst<Arc>> aiter(fst, siter.Value());
58  !aiter.Done(); aiter.Next()) {
59  auto arc = aiter.Value();
60  // dense_hash_map does not support find on the empty_key_val.
61  // These labels should never be in an FST anyway.
62  DCHECK_NE(arc.ilabel, kNoLabel);
63  DCHECK_NE(arc.olabel, kNoLabel);
64  // Relabels input.
65  auto it = input_map.find(arc.ilabel);
66  if (it != input_map.end()) {
67  if (it->second == kNoLabel) {
68  FSTERROR() << "Input symbol ID " << arc.ilabel
69  << " missing from target vocabulary";
71  return;
72  }
73  arc.ilabel = it->second;
74  }
75  // Relabels output.
76  it = output_map.find(arc.olabel);
77  if (it != output_map.end()) {
78  if (it->second == kNoLabel) {
79  FSTERROR() << "Output symbol id " << arc.olabel
80  << " missing from target vocabulary";
82  return;
83  }
84  arc.olabel = it->second;
85  }
86  aiter.SetValue(arc);
87  }
88  }
90 }
91 
92 // Relabels either the input labels or output labels. The old to
93 // new labels are specified using pairs of old and new symbol tables.
94 // The tables must contain (at least) all labels on the appropriate side of the
95 // FST. If the 'unknown_i(o)symbol' is non-empty, it is used to label any
96 // missing symbol in new_i(o)symbols table.
97 template <class Arc>
98 void Relabel(MutableFst<Arc> *fst, const SymbolTable *old_isymbols,
99  const SymbolTable *new_isymbols,
100  const std::string &unknown_isymbol, bool attach_new_isymbols,
101  const SymbolTable *old_osymbols, const SymbolTable *new_osymbols,
102  const std::string &unknown_osymbol, bool attach_new_osymbols) {
103  using Label = typename Arc::Label;
104  // Constructs vectors of input-side label pairs.
105  std::vector<std::pair<Label, Label>> ipairs;
106  if (old_isymbols && new_isymbols) {
107  size_t num_missing_syms = 0;
108  Label unknown_ilabel = kNoLabel;
109  if (!unknown_isymbol.empty()) {
110  unknown_ilabel = new_isymbols->Find(unknown_isymbol);
111  if (unknown_ilabel == kNoLabel) {
112  VLOG(1) << "Input symbol '" << unknown_isymbol
113  << "' missing from target symbol table";
114  ++num_missing_syms;
115  }
116  }
117 
118  for (const auto &sitem : *old_isymbols) {
119  const auto old_index = sitem.Label();
120  const auto symbol = sitem.Symbol();
121  auto new_index = new_isymbols->Find(symbol);
122  if (new_index == kNoLabel) {
123  if (unknown_ilabel != kNoLabel) {
124  new_index = unknown_ilabel;
125  } else {
126  VLOG(1) << "Input symbol ID " << old_index << " symbol '" << symbol
127  << "' missing from target symbol table";
128  ++num_missing_syms;
129  }
130  }
131  ipairs.emplace_back(old_index, new_index);
132  }
133  if (num_missing_syms > 0) {
134  LOG(WARNING) << "Target symbol table missing: " << num_missing_syms
135  << " input symbols";
136  }
137  if (attach_new_isymbols) fst->SetInputSymbols(new_isymbols);
138  }
139  // Constructs vectors of output-side label pairs.
140  std::vector<std::pair<Label, Label>> opairs;
141  if (old_osymbols && new_osymbols) {
142  size_t num_missing_syms = 0;
143  Label unknown_olabel = kNoLabel;
144  if (!unknown_osymbol.empty()) {
145  unknown_olabel = new_osymbols->Find(unknown_osymbol);
146  if (unknown_olabel == kNoLabel) {
147  VLOG(1) << "Output symbol '" << unknown_osymbol
148  << "' missing from target symbol table";
149  ++num_missing_syms;
150  }
151  }
152  for (const auto &sitem : *old_osymbols) {
153  const auto old_index = sitem.Label();
154  const auto symbol = sitem.Symbol();
155  auto new_index = new_osymbols->Find(symbol);
156  if (new_index == kNoLabel) {
157  if (unknown_olabel != kNoLabel) {
158  new_index = unknown_olabel;
159  } else {
160  VLOG(1) << "Output symbol ID " << old_index << " symbol '" << symbol
161  << "' missing from target symbol table";
162  ++num_missing_syms;
163  }
164  }
165  opairs.emplace_back(old_index, new_index);
166  }
167  if (num_missing_syms > 0) {
168  LOG(WARNING) << "Target symbol table missing: " << num_missing_syms
169  << " output symbols";
170  }
171  if (attach_new_osymbols) fst->SetOutputSymbols(new_osymbols);
172  }
173  // Calls relabel using vector of relabel pairs.
174  Relabel(fst, ipairs, opairs);
175 }
176 
177 // Same as previous but no special allowance for unknown symbols. Kept
178 // for backward compat.
179 template <class Arc>
180 void Relabel(MutableFst<Arc> *fst, const SymbolTable *old_isymbols,
181  const SymbolTable *new_isymbols, bool attach_new_isymbols,
182  const SymbolTable *old_osymbols, const SymbolTable *new_osymbols,
183  bool attach_new_osymbols) {
184  Relabel(fst, old_isymbols, new_isymbols, "" /* no unknown isymbol */,
185  attach_new_isymbols, old_osymbols, new_osymbols,
186  "" /* no unknown osymbol */, attach_new_osymbols);
187 }
188 
189 // Relabels either the input labels or output labels. The old to
190 // new labels are specified using symbol tables. Any label associations not
191 // specified are assumed to be identity mapping.
192 template <class Arc>
193 void Relabel(MutableFst<Arc> *fst, const SymbolTable *new_isymbols,
194  const SymbolTable *new_osymbols) {
195  Relabel(fst, fst->InputSymbols(), new_isymbols, true, fst->OutputSymbols(),
196  new_osymbols, true);
197 }
198 
200 
201 template <class Arc>
202 class RelabelFst;
203 
204 namespace internal {
205 
206 // Relabels an FST from one symbol set to another. Relabeling can either be on
207 // input or output space. RelabelFst implements a delayed version of the
208 // relabel. Arcs are relabeled on the fly and not cached; i.e., each request is
209 // recomputed.
210 template <class Arc>
211 class RelabelFstImpl : public CacheImpl<Arc> {
212  public:
213  using Label = typename Arc::Label;
214  using StateId = typename Arc::StateId;
215  using Weight = typename Arc::Weight;
216 
218  using State = typename Store::State;
219 
220  using FstImpl<Arc>::SetType;
225 
233 
234  friend class StateIterator<RelabelFst<Arc>>;
235 
237  const std::vector<std::pair<Label, Label>> &ipairs,
238  const std::vector<std::pair<Label, Label>> &opairs,
239  const RelabelFstOptions &opts)
240  : CacheImpl<Arc>(opts),
241  fst_(fst.Copy()),
242  input_map_(ipairs.begin(), ipairs.end()),
243  output_map_(opairs.begin(), opairs.end()),
244  relabel_input_(!ipairs.empty()),
245  relabel_output_(!opairs.empty()) {
247  SetType("relabel");
248  }
249 
250  RelabelFstImpl(const Fst<Arc> &fst, const SymbolTable *old_isymbols,
251  const SymbolTable *new_isymbols,
252  const SymbolTable *old_osymbols,
253  const SymbolTable *new_osymbols, const RelabelFstOptions &opts)
254  : CacheImpl<Arc>(opts),
255  fst_(fst.Copy()),
256  relabel_input_(false),
257  relabel_output_(false) {
258  SetType("relabel");
260  SetInputSymbols(old_isymbols);
261  SetOutputSymbols(old_osymbols);
262  if (old_isymbols && new_isymbols &&
263  old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) {
264  for (const auto &sitem : *old_isymbols) {
265  input_map_[sitem.Label()] = new_isymbols->Find(sitem.Symbol());
266  }
267  SetInputSymbols(new_isymbols);
268  relabel_input_ = true;
269  }
270  if (old_osymbols && new_osymbols &&
271  old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) {
272  for (const auto &sitem : *old_osymbols) {
273  output_map_[sitem.Label()] = new_osymbols->Find(sitem.Symbol());
274  }
275  SetOutputSymbols(new_osymbols);
276  relabel_output_ = true;
277  }
278  }
279 
281  : CacheImpl<Arc>(impl),
282  fst_(impl.fst_->Copy(true)),
283  input_map_(impl.input_map_),
284  output_map_(impl.output_map_),
285  relabel_input_(impl.relabel_input_),
286  relabel_output_(impl.relabel_output_) {
287  SetType("relabel");
291  }
292 
294  if (!HasStart()) SetStart(fst_->Start());
295  return CacheImpl<Arc>::Start();
296  }
297 
299  if (!HasFinal(s)) SetFinal(s, fst_->Final(s));
300  return CacheImpl<Arc>::Final(s);
301  }
302 
303  size_t NumArcs(StateId s) {
304  if (!HasArcs(s)) Expand(s);
305  return CacheImpl<Arc>::NumArcs(s);
306  }
307 
309  if (!HasArcs(s)) Expand(s);
311  }
312 
314  if (!HasArcs(s)) Expand(s);
316  }
317 
318  uint64_t Properties() const override { return Properties(kFstProperties); }
319 
320  // Sets error if found, and returns other FST impl properties.
321  uint64_t Properties(uint64_t mask) const override {
322  if ((mask & kError) && fst_->Properties(kError, false)) {
323  SetProperties(kError, kError);
324  }
325  return FstImpl<Arc>::Properties(mask);
326  }
327 
329  if (!HasArcs(s)) Expand(s);
331  }
332 
333  void Expand(StateId s) {
334  for (ArcIterator<Fst<Arc>> aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
335  auto arc = aiter.Value();
336  if (relabel_input_) {
337  auto it = input_map_.find(arc.ilabel);
338  if (it != input_map_.end()) arc.ilabel = it->second;
339  }
340  if (relabel_output_) {
341  auto it = output_map_.find(arc.olabel);
342  if (it != output_map_.end()) {
343  arc.olabel = it->second;
344  }
345  }
346  PushArc(s, std::move(arc));
347  }
348  SetArcs(s);
349  }
350 
351  private:
352  std::unique_ptr<const Fst<Arc>> fst_;
353 
354  std::unordered_map<Label, Label> input_map_;
355  std::unordered_map<Label, Label> output_map_;
356  bool relabel_input_;
357  bool relabel_output_;
358 };
359 
360 } // namespace internal
361 
362 // This class attaches interface to implementation and handles
363 // reference counting, delegating most methods to ImplToFst.
364 template <class A>
365 class RelabelFst : public ImplToFst<internal::RelabelFstImpl<A>> {
366  public:
367  using Arc = A;
368  using Label = typename Arc::Label;
369  using StateId = typename Arc::StateId;
370  using Weight = typename Arc::Weight;
371 
373  using State = typename Store::State;
375 
376  friend class ArcIterator<RelabelFst<A>>;
377  friend class StateIterator<RelabelFst<A>>;
378 
380  const std::vector<std::pair<Label, Label>> &ipairs,
381  const std::vector<std::pair<Label, Label>> &opairs,
382  const RelabelFstOptions &opts = RelabelFstOptions())
383  : ImplToFst<Impl>(std::make_shared<Impl>(fst, ipairs, opairs, opts)) {}
384 
385  RelabelFst(const Fst<Arc> &fst, const SymbolTable *new_isymbols,
386  const SymbolTable *new_osymbols,
387  const RelabelFstOptions &opts = RelabelFstOptions())
388  : ImplToFst<Impl>(
389  std::make_shared<Impl>(fst, fst.InputSymbols(), new_isymbols,
390  fst.OutputSymbols(), new_osymbols, opts)) {}
391 
392  RelabelFst(const Fst<Arc> &fst, const SymbolTable *old_isymbols,
393  const SymbolTable *new_isymbols, const SymbolTable *old_osymbols,
394  const SymbolTable *new_osymbols,
395  const RelabelFstOptions &opts = RelabelFstOptions())
396  : ImplToFst<Impl>(std::make_shared<Impl>(fst, old_isymbols, new_isymbols,
397  old_osymbols, new_osymbols,
398  opts)) {}
399 
400  // See Fst<>::Copy() for doc.
401  RelabelFst(const RelabelFst &fst, bool safe = false)
402  : ImplToFst<Impl>(fst, safe) {}
403 
404  // Gets a copy of this RelabelFst. See Fst<>::Copy() for further doc.
405  RelabelFst *Copy(bool safe = false) const override {
406  return new RelabelFst(*this, safe);
407  }
408 
409  void InitStateIterator(StateIteratorData<Arc> *data) const override;
410 
411  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
412  return GetMutableImpl()->InitArcIterator(s, data);
413  }
414 
415  private:
418 
419  RelabelFst &operator=(const RelabelFst &) = delete;
420 };
421 
422 // Specialization for RelabelFst.
423 template <class Arc>
425  public:
426  using StateId = typename Arc::StateId;
427 
428  explicit StateIterator(const RelabelFst<Arc> &fst)
429  : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {}
430 
431  bool Done() const final { return siter_.Done(); }
432 
433  StateId Value() const final { return s_; }
434 
435  void Next() final {
436  if (!siter_.Done()) {
437  ++s_;
438  siter_.Next();
439  }
440  }
441 
442  void Reset() final {
443  s_ = 0;
444  siter_.Reset();
445  }
446 
447  private:
448  const internal::RelabelFstImpl<Arc> *impl_;
449  StateIterator<Fst<Arc>> siter_;
450  StateId s_;
451 
452  StateIterator(const StateIterator &) = delete;
453  StateIterator &operator=(const StateIterator &) = delete;
454 };
455 
456 // Specialization for RelabelFst.
457 template <class Arc>
458 class ArcIterator<RelabelFst<Arc>> : public CacheArcIterator<RelabelFst<Arc>> {
459  public:
460  using StateId = typename Arc::StateId;
461 
463  : CacheArcIterator<RelabelFst<Arc>>(fst.GetMutableImpl(), s) {
464  if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
465  }
466 };
467 
468 template <class Arc>
470  StateIteratorData<Arc> *data) const {
471  data->base = std::make_unique<StateIterator<RelabelFst<Arc>>>(*this);
472 }
473 
474 // Useful alias when using StdArc.
476 
477 } // namespace fst
478 
479 #endif // FST_RELABEL_H_
void InitStateIterator(StateIteratorData< Arc > *data) const override
Definition: relabel.h:469
constexpr int kNoLabel
Definition: fst.h:201
virtual uint64_t Properties(uint64_t mask, bool test) const =0
StateIterator(const RelabelFst< Arc > &fst)
Definition: relabel.h:428
void SetFinal(StateId s, Weight weight=Weight::One())
Definition: cache.h:923
const SymbolTable * OutputSymbols() const
Definition: fst.h:762
const SymbolTable * InputSymbols() const override=0
constexpr uint64_t kError
Definition: properties.h:51
void Expand(StateId s)
Definition: relabel.h:333
virtual void SetInputSymbols(const SymbolTable *isyms)=0
#define LOG(type)
Definition: log.h:49
typename Arc::StateId StateId
Definition: relabel.h:214
void SetOutputSymbols(const SymbolTable *osyms)
Definition: fst.h:772
typename Arc::StateId StateId
Definition: fst.h:361
typename Arc::Weight Weight
Definition: relabel.h:215
uint64_t RelabelProperties(uint64_t inprops)
Definition: properties.cc:327
void Relabel(MutableFst< Arc > *fst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &ipairs, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &opairs)
Definition: relabel.h:42
uint64_t Properties(uint64_t mask) const override
Definition: relabel.h:321
#define FSTERROR()
Definition: util.h:53
const SymbolTable * OutputSymbols() const override=0
RelabelFstImpl(const RelabelFstImpl< Arc > &impl)
Definition: relabel.h:280
virtual uint64_t Properties() const
Definition: fst.h:702
constexpr uint64_t kCopyProperties
Definition: properties.h:162
virtual void SetProperties(uint64_t props, uint64_t mask)=0
RelabelFst * Copy(bool safe=false) const override
Definition: relabel.h:405
Weight Final(StateId s)
Definition: relabel.h:298
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
Definition: cache.h:666
#define VLOG(level)
Definition: log.h:50
size_t NumInputEpsilons(StateId s)
Definition: relabel.h:308
std::unique_ptr< StateIteratorBase< Arc > > base
Definition: fst.h:382
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const
Definition: cache.h:1042
void SetInputSymbols(const SymbolTable *isyms)
Definition: fst.h:768
RelabelFstImpl(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &ipairs, const std::vector< std::pair< Label, Label >> &opairs, const RelabelFstOptions &opts)
Definition: relabel.h:236
size_t NumOutputEpsilons(StateId s)
Definition: relabel.h:313
constexpr uint64_t kFstProperties
Definition: properties.h:325
void PushArc(StateId s, const Arc &arc)
Definition: cache.h:933
ArcIterator(const RelabelFst< Arc > &fst, StateId s)
Definition: relabel.h:462
size_t NumArcs(StateId s)
Definition: relabel.h:303
RelabelFst(const Fst< Arc > &fst, const SymbolTable *old_isymbols, const SymbolTable *new_isymbols, const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, const RelabelFstOptions &opts=RelabelFstOptions())
Definition: relabel.h:392
typename RelabelFst< Arc >::Arc Arc
Definition: cache.h:1195
typename Arc::StateId StateId
Definition: relabel.h:369
const SymbolTable * InputSymbols() const
Definition: fst.h:760
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
Definition: relabel.h:328
void SetType(std::string_view type)
Definition: fst.h:700
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
Definition: relabel.h:411
typename Arc::Label Label
Definition: relabel.h:368
typename Arc::Label Label
Definition: relabel.h:213
typename Store::State State
Definition: relabel.h:373
RelabelFst(const Fst< Arc > &fst, const SymbolTable *new_isymbols, const SymbolTable *new_osymbols, const RelabelFstOptions &opts=RelabelFstOptions())
Definition: relabel.h:385
RelabelFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &ipairs, const std::vector< std::pair< Label, Label >> &opairs, const RelabelFstOptions &opts=RelabelFstOptions())
Definition: relabel.h:379
std::string Find(int64_t key) const
Definition: symbol-table.h:446
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
const std::string & LabeledCheckSum() const
Definition: symbol-table.h:453
RelabelFstImpl(const Fst< Arc > &fst, const SymbolTable *old_isymbols, const SymbolTable *new_isymbols, const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, const RelabelFstOptions &opts)
Definition: relabel.h:250
#define DCHECK_NE(x, y)
Definition: log.h:76
typename CacheState< Arc >::Arc Arc
Definition: cache.h:852
internal::RelabelFstImpl< A > * GetMutableImpl() const
Definition: fst.h:1028
typename Arc::StateId StateId
Definition: relabel.h:460
RelabelFst(const RelabelFst &fst, bool safe=false)
Definition: relabel.h:401
CacheOptions RelabelFstOptions
Definition: relabel.h:199
const internal::RelabelFstImpl< A > * GetImpl() const
Definition: fst.h:1026
uint64_t Properties() const override
Definition: relabel.h:318
typename Arc::Weight Weight
Definition: relabel.h:370