FST  openfst-1.8.3
OpenFst Library
relabel.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Functions and classes to relabel an FST (either on input or output).
19 
20 #ifndef FST_RELABEL_H_
21 #define FST_RELABEL_H_
22 
23 #include <cstddef>
24 #include <cstdint>
25 #include <memory>
26 #include <string>
27 #include <utility>
28 #include <vector>
29 
30 #include <fst/log.h>
31 #include <fst/arc.h>
32 #include <fst/cache.h>
33 #include <fst/float-weight.h>
34 #include <fst/fst.h>
35 #include <fst/impl-to-fst.h>
36 #include <fst/mutable-fst.h>
37 #include <fst/properties.h>
38 #include <fst/symbol-table.h>
39 #include <fst/util.h>
40 #include <unordered_map>
41 
42 namespace fst {
43 
44 // Relabels either the input labels or output labels. The old to
45 // new labels are specified using a vector of std::pair<Label, Label>.
46 // Any label associations not specified are assumed to be identity
47 // mapping. The destination labels must be valid labels (e.g., not kNoLabel).
48 template <class Arc>
49 void Relabel(
51  const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
52  &ipairs,
53  const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
54  &opairs) {
55  using Label = typename Arc::Label;
56  const auto props = fst->Properties(kFstProperties, false);
57  // Constructs label-to-label maps.
58  const std::unordered_map<Label, Label> input_map(
59  ipairs.begin(), ipairs.end());
60  const std::unordered_map<Label, Label> output_map(
61  opairs.begin(), opairs.end());
62  for (StateIterator<MutableFst<Arc>> siter(*fst); !siter.Done();
63  siter.Next()) {
64  for (MutableArcIterator<MutableFst<Arc>> aiter(fst, siter.Value());
65  !aiter.Done(); aiter.Next()) {
66  auto arc = aiter.Value();
67 // kNoLabel is reserved to mean 'no such label' and should never appear
68 // on an arc of an FST, so check for it before consulting the label maps.
69  DCHECK_NE(arc.ilabel, kNoLabel);
70  DCHECK_NE(arc.olabel, kNoLabel);
71  // Relabels input.
72  if (auto it = input_map.find(arc.ilabel); it != input_map.end()) {
73  if (it->second == kNoLabel) {
74  FSTERROR() << "Input symbol ID " << arc.ilabel
75  << " missing from target vocabulary";
77  return;
78  }
79  arc.ilabel = it->second;
80  }
81  // Relabels output.
82  if (auto it = output_map.find(arc.olabel); it != output_map.end()) {
83  if (it->second == kNoLabel) {
84  FSTERROR() << "Output symbol id " << arc.olabel
85  << " missing from target vocabulary";
87  return;
88  }
89  arc.olabel = it->second;
90  }
91  aiter.SetValue(arc);
92  }
93  }
95 }
96 
97 // Relabels either the input labels or output labels. The old to
98 // new labels are specified using pairs of old and new symbol tables.
99 // The tables must contain (at least) all labels on the appropriate side of the
100 // FST. If the 'unknown_i(o)symbol' is non-empty, it is used to label any
101 // missing symbol in new_i(o)symbols table.
102 template <class Arc>
103 void Relabel(MutableFst<Arc> *fst, const SymbolTable *old_isymbols,
104  const SymbolTable *new_isymbols,
105  const std::string &unknown_isymbol, bool attach_new_isymbols,
106  const SymbolTable *old_osymbols, const SymbolTable *new_osymbols,
107  const std::string &unknown_osymbol, bool attach_new_osymbols) {
108  using Label = typename Arc::Label;
109  // Constructs vectors of input-side label pairs.
110  std::vector<std::pair<Label, Label>> ipairs;
111  if (old_isymbols && new_isymbols) {
112  size_t num_missing_syms = 0;
113  Label unknown_ilabel = kNoLabel;
114  if (!unknown_isymbol.empty()) {
115  unknown_ilabel = new_isymbols->Find(unknown_isymbol);
116  if (unknown_ilabel == kNoLabel) {
117  VLOG(1) << "Input symbol '" << unknown_isymbol
118  << "' missing from target symbol table";
119  ++num_missing_syms;
120  }
121  }
122 
123  for (const auto &sitem : *old_isymbols) {
124  const auto old_index = sitem.Label();
125  const auto symbol = sitem.Symbol();
126  auto new_index = new_isymbols->Find(symbol);
127  if (new_index == kNoLabel) {
128  if (unknown_ilabel != kNoLabel) {
129  new_index = unknown_ilabel;
130  } else {
131  VLOG(1) << "Input symbol ID " << old_index << " symbol '" << symbol
132  << "' missing from target symbol table";
133  ++num_missing_syms;
134  }
135  }
136  ipairs.emplace_back(old_index, new_index);
137  }
138  if (num_missing_syms > 0) {
139  LOG(WARNING) << "Target symbol table missing: " << num_missing_syms
140  << " input symbols";
141  }
142  if (attach_new_isymbols) fst->SetInputSymbols(new_isymbols);
143  }
144  // Constructs vectors of output-side label pairs.
145  std::vector<std::pair<Label, Label>> opairs;
146  if (old_osymbols && new_osymbols) {
147  size_t num_missing_syms = 0;
148  Label unknown_olabel = kNoLabel;
149  if (!unknown_osymbol.empty()) {
150  unknown_olabel = new_osymbols->Find(unknown_osymbol);
151  if (unknown_olabel == kNoLabel) {
152  VLOG(1) << "Output symbol '" << unknown_osymbol
153  << "' missing from target symbol table";
154  ++num_missing_syms;
155  }
156  }
157  for (const auto &sitem : *old_osymbols) {
158  const auto old_index = sitem.Label();
159  const auto symbol = sitem.Symbol();
160  auto new_index = new_osymbols->Find(symbol);
161  if (new_index == kNoLabel) {
162  if (unknown_olabel != kNoLabel) {
163  new_index = unknown_olabel;
164  } else {
165  VLOG(1) << "Output symbol ID " << old_index << " symbol '" << symbol
166  << "' missing from target symbol table";
167  ++num_missing_syms;
168  }
169  }
170  opairs.emplace_back(old_index, new_index);
171  }
172  if (num_missing_syms > 0) {
173  LOG(WARNING) << "Target symbol table missing: " << num_missing_syms
174  << " output symbols";
175  }
176  if (attach_new_osymbols) fst->SetOutputSymbols(new_osymbols);
177  }
178  // Calls relabel using vector of relabel pairs.
179  Relabel(fst, ipairs, opairs);
180 }
181 
182 // Same as previous but no special allowance for unknown symbols. Kept
183 // for backward compat.
184 template <class Arc>
185 void Relabel(MutableFst<Arc> *fst, const SymbolTable *old_isymbols,
186  const SymbolTable *new_isymbols, bool attach_new_isymbols,
187  const SymbolTable *old_osymbols, const SymbolTable *new_osymbols,
188  bool attach_new_osymbols) {
189  Relabel(fst, old_isymbols, new_isymbols, "" /* no unknown isymbol */,
190  attach_new_isymbols, old_osymbols, new_osymbols,
191  "" /* no unknown osymbol */, attach_new_osymbols);
192 }
193 
194 // Relabels either the input labels or output labels. The old to
195 // new labels are specified using symbol tables. Any label associations not
196 // specified are assumed to be identity mapping.
197 template <class Arc>
198 void Relabel(MutableFst<Arc> *fst, const SymbolTable *new_isymbols,
199  const SymbolTable *new_osymbols) {
200  Relabel(fst, fst->InputSymbols(), new_isymbols, true, fst->OutputSymbols(),
201  new_osymbols, true);
202 }
203 
205 
206 template <class Arc>
207 class RelabelFst;
208 
209 namespace internal {
210 
211 // Relabels an FST from one symbol set to another. Relabeling can either be on
212 // input or output space. RelabelFst implements a delayed version of the
213 // relabel. Arcs are relabeled on the fly when a state is expanded; Expand()
214 // runs only if HasArcs(s) is false, so results go through the CacheImpl base.
215 template <class Arc>
216 class RelabelFstImpl : public CacheImpl<Arc> {
217  public:
218  using Label = typename Arc::Label;
219  using StateId = typename Arc::StateId;
220  using Weight = typename Arc::Weight;
221 
223  using State = typename Store::State;
224 
225  using FstImpl<Arc>::SetType;
230 
238 
239  friend class StateIterator<RelabelFst<Arc>>;
240 
242  const std::vector<std::pair<Label, Label>> &ipairs,
243  const std::vector<std::pair<Label, Label>> &opairs,
244  const RelabelFstOptions &opts)
245  : CacheImpl<Arc>(opts),
246  fst_(fst.Copy()),
247  input_map_(ipairs.begin(), ipairs.end()),
248  output_map_(opairs.begin(), opairs.end()),
249  relabel_input_(!ipairs.empty()),
250  relabel_output_(!opairs.empty()) {
252  SetType("relabel");
253  }
254 
255  RelabelFstImpl(const Fst<Arc> &fst, const SymbolTable *old_isymbols,
256  const SymbolTable *new_isymbols,
257  const SymbolTable *old_osymbols,
258  const SymbolTable *new_osymbols, const RelabelFstOptions &opts)
259  : CacheImpl<Arc>(opts),
260  fst_(fst.Copy()),
261  relabel_input_(false),
262  relabel_output_(false) {
263  SetType("relabel");
265  SetInputSymbols(old_isymbols);
266  SetOutputSymbols(old_osymbols);
267  if (old_isymbols && new_isymbols &&
268  old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) {
269  for (const auto &sitem : *old_isymbols) {
270  input_map_[sitem.Label()] = new_isymbols->Find(sitem.Symbol());
271  }
272  SetInputSymbols(new_isymbols);
273  relabel_input_ = true;
274  }
275  if (old_osymbols && new_osymbols &&
276  old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) {
277  for (const auto &sitem : *old_osymbols) {
278  output_map_[sitem.Label()] = new_osymbols->Find(sitem.Symbol());
279  }
280  SetOutputSymbols(new_osymbols);
281  relabel_output_ = true;
282  }
283  }
284 
286  : CacheImpl<Arc>(impl),
287  fst_(impl.fst_->Copy(true)),
288  input_map_(impl.input_map_),
289  output_map_(impl.output_map_),
290  relabel_input_(impl.relabel_input_),
291  relabel_output_(impl.relabel_output_) {
292  SetType("relabel");
296  }
297 
299  if (!HasStart()) SetStart(fst_->Start());
300  return CacheImpl<Arc>::Start();
301  }
302 
304  if (!HasFinal(s)) SetFinal(s, fst_->Final(s));
305  return CacheImpl<Arc>::Final(s);
306  }
307 
308  size_t NumArcs(StateId s) {
309  if (!HasArcs(s)) Expand(s);
310  return CacheImpl<Arc>::NumArcs(s);
311  }
312 
314  if (!HasArcs(s)) Expand(s);
316  }
317 
319  if (!HasArcs(s)) Expand(s);
321  }
322 
323  uint64_t Properties() const override { return Properties(kFstProperties); }
324 
325  // Sets error if found, and returns other FST impl properties.
326  uint64_t Properties(uint64_t mask) const override {
327  if ((mask & kError) && fst_->Properties(kError, false)) {
328  SetProperties(kError, kError);
329  }
330  return FstImpl<Arc>::Properties(mask);
331  }
332 
334  if (!HasArcs(s)) Expand(s);
336  }
337 
338  void Expand(StateId s) {
339  for (ArcIterator<Fst<Arc>> aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
340  auto arc = aiter.Value();
341  if (relabel_input_) {
342  if (auto it = input_map_.find(arc.ilabel); it != input_map_.end()) {
343  arc.ilabel = it->second;
344  }
345  }
346  if (relabel_output_) {
347  if (auto it = output_map_.find(arc.olabel); it != output_map_.end()) {
348  arc.olabel = it->second;
349  }
350  }
351  PushArc(s, std::move(arc));
352  }
353  SetArcs(s);
354  }
355 
356  private:
357  std::unique_ptr<const Fst<Arc>> fst_;
358 
359  std::unordered_map<Label, Label> input_map_;
360  std::unordered_map<Label, Label> output_map_;
361  bool relabel_input_;
362  bool relabel_output_;
363 };
364 
365 } // namespace internal
366 
367 // This class attaches interface to implementation and handles
368 // reference counting, delegating most methods to ImplToFst.
369 template <class A>
370 class RelabelFst : public ImplToFst<internal::RelabelFstImpl<A>> {
371  public:
372  using Arc = A;
373  using Label = typename Arc::Label;
374  using StateId = typename Arc::StateId;
375  using Weight = typename Arc::Weight;
376 
378  using State = typename Store::State;
380 
381  friend class ArcIterator<RelabelFst<A>>;
382  friend class StateIterator<RelabelFst<A>>;
383 
385  const std::vector<std::pair<Label, Label>> &ipairs,
386  const std::vector<std::pair<Label, Label>> &opairs,
387  const RelabelFstOptions &opts = RelabelFstOptions())
388  : ImplToFst<Impl>(std::make_shared<Impl>(fst, ipairs, opairs, opts)) {}
389 
390  RelabelFst(const Fst<Arc> &fst, const SymbolTable *new_isymbols,
391  const SymbolTable *new_osymbols,
392  const RelabelFstOptions &opts = RelabelFstOptions())
393  : ImplToFst<Impl>(
394  std::make_shared<Impl>(fst, fst.InputSymbols(), new_isymbols,
395  fst.OutputSymbols(), new_osymbols, opts)) {}
396 
397  RelabelFst(const Fst<Arc> &fst, const SymbolTable *old_isymbols,
398  const SymbolTable *new_isymbols, const SymbolTable *old_osymbols,
399  const SymbolTable *new_osymbols,
400  const RelabelFstOptions &opts = RelabelFstOptions())
401  : ImplToFst<Impl>(std::make_shared<Impl>(fst, old_isymbols, new_isymbols,
402  old_osymbols, new_osymbols,
403  opts)) {}
404 
405  // See Fst<>::Copy() for doc.
406  RelabelFst(const RelabelFst &fst, bool safe = false)
407  : ImplToFst<Impl>(fst, safe) {}
408 
409  // Gets a copy of this RelabelFst. See Fst<>::Copy() for further doc.
410  RelabelFst *Copy(bool safe = false) const override {
411  return new RelabelFst(*this, safe);
412  }
413 
414  void InitStateIterator(StateIteratorData<Arc> *data) const override;
415 
416  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
417  return GetMutableImpl()->InitArcIterator(s, data);
418  }
419 
420  private:
423 
424  RelabelFst &operator=(const RelabelFst &) = delete;
425 };
426 
427 // Specialization for RelabelFst.
428 template <class Arc>
430  public:
431  using StateId = typename Arc::StateId;
432 
433  explicit StateIterator(const RelabelFst<Arc> &fst)
434  : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {}
435 
436  bool Done() const final { return siter_.Done(); }
437 
438  StateId Value() const final { return s_; }
439 
440  void Next() final {
441  if (!siter_.Done()) {
442  ++s_;
443  siter_.Next();
444  }
445  }
446 
447  void Reset() final {
448  s_ = 0;
449  siter_.Reset();
450  }
451 
452  private:
453  const internal::RelabelFstImpl<Arc> *impl_;
454  StateIterator<Fst<Arc>> siter_;
455  StateId s_;
456 
457  StateIterator(const StateIterator &) = delete;
458  StateIterator &operator=(const StateIterator &) = delete;
459 };
460 
461 // Specialization for RelabelFst.
462 template <class Arc>
463 class ArcIterator<RelabelFst<Arc>> : public CacheArcIterator<RelabelFst<Arc>> {
464  public:
465  using StateId = typename Arc::StateId;
466 
468  : CacheArcIterator<RelabelFst<Arc>>(fst.GetMutableImpl(), s) {
469  if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
470  }
471 };
472 
473 template <class Arc>
475  StateIteratorData<Arc> *data) const {
476  data->base = std::make_unique<StateIterator<RelabelFst<Arc>>>(*this);
477 }
478 
479 // Useful alias when using StdArc.
481 
482 } // namespace fst
483 
484 #endif // FST_RELABEL_H_
void InitStateIterator(StateIteratorData< Arc > *data) const override
Definition: relabel.h:474
constexpr int kNoLabel
Definition: fst.h:195
virtual uint64_t Properties(uint64_t mask, bool test) const =0
StateIterator(const RelabelFst< Arc > &fst)
Definition: relabel.h:433
void SetFinal(StateId s, Weight weight=Weight::One())
Definition: cache.h:930
const SymbolTable * OutputSymbols() const
Definition: fst.h:761
const SymbolTable * InputSymbols() const override=0
constexpr uint64_t kError
Definition: properties.h:52
void Expand(StateId s)
Definition: relabel.h:338
virtual void SetInputSymbols(const SymbolTable *isyms)=0
#define LOG(type)
Definition: log.h:53
typename Arc::StateId StateId
Definition: relabel.h:219
void SetOutputSymbols(const SymbolTable *osyms)
Definition: fst.h:771
typename Arc::StateId StateId
Definition: fst.h:361
typename Arc::Weight Weight
Definition: relabel.h:220
uint64_t RelabelProperties(uint64_t inprops)
Definition: properties.cc:326
void Relabel(MutableFst< Arc > *fst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &ipairs, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &opairs)
Definition: relabel.h:49
uint64_t Properties(uint64_t mask) const override
Definition: relabel.h:326
#define FSTERROR()
Definition: util.h:56
const SymbolTable * OutputSymbols() const override=0
RelabelFstImpl(const RelabelFstImpl< Arc > &impl)
Definition: relabel.h:285
virtual uint64_t Properties() const
Definition: fst.h:701
constexpr uint64_t kCopyProperties
Definition: properties.h:163
virtual void SetProperties(uint64_t props, uint64_t mask)=0
RelabelFst * Copy(bool safe=false) const override
Definition: relabel.h:410
Weight Final(StateId s)
Definition: relabel.h:303
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
Definition: cache.h:673
#define VLOG(level)
Definition: log.h:54
size_t NumInputEpsilons(StateId s)
Definition: relabel.h:313
std::unique_ptr< StateIteratorBase< Arc > > base
Definition: fst.h:382
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const
Definition: cache.h:1049
void SetInputSymbols(const SymbolTable *isyms)
Definition: fst.h:767
RelabelFstImpl(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &ipairs, const std::vector< std::pair< Label, Label >> &opairs, const RelabelFstOptions &opts)
Definition: relabel.h:241
size_t NumOutputEpsilons(StateId s)
Definition: relabel.h:318
constexpr uint64_t kFstProperties
Definition: properties.h:326
void PushArc(StateId s, const Arc &arc)
Definition: cache.h:940
ArcIterator(const RelabelFst< Arc > &fst, StateId s)
Definition: relabel.h:467
size_t NumArcs(StateId s)
Definition: relabel.h:308
RelabelFst(const Fst< Arc > &fst, const SymbolTable *old_isymbols, const SymbolTable *new_isymbols, const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, const RelabelFstOptions &opts=RelabelFstOptions())
Definition: relabel.h:397
typename RelabelFst< Arc >::Arc Arc
Definition: cache.h:1202
typename Arc::StateId StateId
Definition: relabel.h:374
const SymbolTable * InputSymbols() const
Definition: fst.h:759
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
Definition: relabel.h:333
void SetType(std::string_view type)
Definition: fst.h:699
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
Definition: relabel.h:416
typename Arc::Label Label
Definition: relabel.h:373
typename Arc::Label Label
Definition: relabel.h:218
typename Store::State State
Definition: relabel.h:378
RelabelFst(const Fst< Arc > &fst, const SymbolTable *new_isymbols, const SymbolTable *new_osymbols, const RelabelFstOptions &opts=RelabelFstOptions())
Definition: relabel.h:390
RelabelFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &ipairs, const std::vector< std::pair< Label, Label >> &opairs, const RelabelFstOptions &opts=RelabelFstOptions())
Definition: relabel.h:384
std::string Find(int64_t key) const
Definition: symbol-table.h:450
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
const std::string & LabeledCheckSum() const
Definition: symbol-table.h:457
RelabelFstImpl(const Fst< Arc > &fst, const SymbolTable *old_isymbols, const SymbolTable *new_isymbols, const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, const RelabelFstOptions &opts)
Definition: relabel.h:255
#define DCHECK_NE(x, y)
Definition: log.h:80
typename CacheState< Arc >::Arc Arc
Definition: cache.h:859
internal::RelabelFstImpl< A > * GetMutableImpl() const
Definition: impl-to-fst.h:125
typename Arc::StateId StateId
Definition: relabel.h:465
RelabelFst(const RelabelFst &fst, bool safe=false)
Definition: relabel.h:406
CacheOptions RelabelFstOptions
Definition: relabel.h:204
const internal::RelabelFstImpl< A > * GetImpl() const
Definition: impl-to-fst.h:123
uint64_t Properties() const override
Definition: relabel.h:323
typename Arc::Weight Weight
Definition: relabel.h:375