FST  openfst-1.8.2
OpenFst Library
factor-weight.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Classes to factor weights in an FST.
19 
20 #ifndef FST_FACTOR_WEIGHT_H_
21 #define FST_FACTOR_WEIGHT_H_
22 
23 #include <algorithm>
24 #include <cstdint>
25 #include <string>
26 #include <utility>
27 #include <vector>
28 
29 #include <fst/log.h>
30 
31 #include <fst/cache.h>
32 #include <fst/test-properties.h>
33 
34 #include <unordered_map>
35 
36 namespace fst {
37 
// Bit flags for FactorWeightOptions::mode. They are tested with bitwise AND in
// the implementation, so both may be OR-ed together.
38 inline constexpr uint8_t kFactorFinalWeights = 0x01;  // Factor final weights.
39 inline constexpr uint8_t kFactorArcWeights = 0x02;  // Factor arc weights.
40 
41 template <class Arc>
43  using Label = typename Arc::Label;
44 
45  float delta;
46  uint8_t mode; // Factor arc weights and/or final weights.
47  Label final_ilabel; // Input label of arc when factoring final weights.
48  Label final_olabel; // Output label of arc when factoring final weights.
49  bool increment_final_ilabel; // When factoring final w' results in > 1 arcs
50  bool increment_final_olabel; // at state, increment labels to make distinct?
51 
52  explicit FactorWeightOptions(const CacheOptions &opts, float delta = kDelta,
53  uint8_t mode = kFactorArcWeights |
54  kFactorFinalWeights,
55  Label final_ilabel = 0, Label final_olabel = 0,
56  bool increment_final_ilabel = false,
57  bool increment_final_olabel = false)
58  : CacheOptions(opts),
59  delta(delta),
60  mode(mode),
61  final_ilabel(final_ilabel),
62  final_olabel(final_olabel),
63  increment_final_ilabel(increment_final_ilabel),
64  increment_final_olabel(increment_final_olabel) {}
65 
66  explicit FactorWeightOptions(float delta = kDelta,
67  uint8_t mode = kFactorArcWeights |
68  kFactorFinalWeights,
69  Label final_ilabel = 0, Label final_olabel = 0,
70  bool increment_final_ilabel = false,
71  bool increment_final_olabel = false)
72  : delta(delta),
73  mode(mode),
74  final_ilabel(final_ilabel),
75  final_olabel(final_olabel),
76  increment_final_ilabel(increment_final_ilabel),
77  increment_final_olabel(increment_final_olabel) {}
78 };
79 
80 // A factor iterator takes as argument a weight w and returns a sequence of
81 // pairs of weights (xi, yi) such that the sum of the products xi times yi is
82 // equal to w. If w is fully factored, the iterator should return nothing.
83 //
84 // template <class W>
85 // class FactorIterator {
86 // public:
87 // explicit FactorIterator(W w);
88 //
89 // bool Done() const;
90 //
91 // void Next();
92 //
93 // std::pair<W, W> Value() const;
94 //
95 // void Reset();
96 // };
97 
// Factors trivially: the iterator is always done, so no weight is ever split.
// Value() is provided only to satisfy the FactorIterator interface and yields
// the pair (One, One).
template <class W>
class IdentityFactor {
 public:
  // The weight is ignored; every weight counts as fully factored.
  explicit IdentityFactor(const W &weight) {}

  bool Done() const { return true; }

  void Next() {}

  std::pair<W, W> Value() const { return std::make_pair(W::One(), W::One()); }

  void Reset() {}
};
112 
// Factor iterator that yields at most one factorization of a weight w, namely
// the pair (One, w). It yields nothing when w already equals One. Used to
// unfold an FST so that every two paths leading to the same state have the
// same weight. Requires applying only to arc weights
// (FactorWeightOptions::mode == kFactorArcWeights).
template <class W>
class OneFactor {
 public:
  explicit OneFactor(const W &w) : residual_(w), exhausted_(w == W::One()) {}

  // True once the single factorization has been consumed, or if none exists.
  bool Done() const { return exhausted_; }

  // Advances past the only factorization.
  void Next() { exhausted_ = true; }

  // The factorization (One, w).
  std::pair<W, W> Value() const {
    return std::pair<W, W>(W::One(), residual_);
  }

  // Rewinds to the state the iterator had right after construction.
  void Reset() {
    const bool trivial = (residual_ == W::One());
    exhausted_ = trivial;
  }

 private:
  W residual_;      // The weight being factored.
  bool exhausted_;  // No (more) factorization to report.
};
133 
134 // Factors a StringWeight w as 'ab' where 'a' is a label.
135 template <typename Label, StringType S = STRING_LEFT>
137  public:
138  explicit StringFactor(const StringWeight<Label, S> &weight)
139  : weight_(weight), done_(weight.Size() <= 1) {}
140 
141  bool Done() const { return done_; }
142 
143  void Next() { done_ = true; }
144 
145  std::pair<StringWeight<Label, S>, StringWeight<Label, S>> Value() const {
146  using Weight = StringWeight<Label, S>;
147  typename Weight::Iterator siter(weight_);
148  Weight w1(siter.Value());
149  Weight w2;
150  for (siter.Next(); !siter.Done(); siter.Next()) w2.PushBack(siter.Value());
151  return std::make_pair(w1, w2);
152  }
153 
154  void Reset() { done_ = weight_.Size() <= 1; }
155 
156  private:
157  const StringWeight<Label, S> weight_;
158  bool done_;
159 };
160 
161 // Factor a GallicWeight using StringFactor.
// The string component (Value1()) is split by StringFactor; the first factor
// keeps the original weight component (Value2()) and the second factor carries
// W::One(). A weight whose string component has at most one label is treated
// as fully factored.
// NOTE(review): listing lines 163 (the class head) and 165 (the declaration of
// the `GW` alias) are absent from this copy -- restore them from the upstream
// header before compiling.
162 template <class Label, class W, GallicType G = GALLIC_LEFT>
164  public:
166 
167  explicit GallicFactor(const GW &weight)
168  : weight_(weight), done_(weight.Value1().Size() <= 1) {}
169 
170  bool Done() const { return done_; }
171 
172  void Next() { done_ = true; }
173 
// Returns ((first label, original weight), (remaining labels, One)).
174  std::pair<GW, GW> Value() const {
175  StringFactor<Label, GallicStringType(G)> siter(weight_.Value1());
176  GW w1(siter.Value().first, weight_.Value2());
177  GW w2(siter.Value().second, W::One());
178  return std::make_pair(w1, w2);
179  }
180 
181  void Reset() { done_ = weight_.Value1().Size() <= 1; }
182 
183  private:
184  const GW weight_;
185  bool done_;
186 };
187 
188 // Specialization for the (general) GALLIC type GallicWeight.
// Iterates over the elements exposed by iter_ and factors each one the same
// way as the primary template: (first label, element weight) and
// (remaining labels, One), both wrapped back into GW.
// NOTE(review): listing lines 190 (class head), 192-193 (presumably the `GW`
// and `GRW` aliases), 208 (start of the `siter` declaration) and 216 (the
// `iter_` member) are absent from this copy -- restore them from the upstream
// header before compiling.
189 template <class Label, class W>
191  public:
194 
// Fully factored up front when the weight is empty, or when it has a single
// element whose string component has at most one label.
195  explicit GallicFactor(const GW &weight)
196  : iter_(weight),
197  done_(weight.Size() == 0 ||
198  (weight.Size() == 1 && weight.Back().Value1().Size() <= 1)) {}
199 
200  bool Done() const { return done_ || iter_.Done(); }
201 
202  void Next() { iter_.Next(); }
203 
204  void Reset() { iter_.Reset(); }
205 
206  std::pair<GW, GW> Value() const {
207  const auto weight = iter_.Value();
209  weight.Value1());
210  GRW w1(siter.Value().first, weight.Value2());
211  GRW w2(siter.Value().second, W::One());
212  return std::make_pair(GW(w1), GW(w2));
213  }
214 
215  private:
217  bool done_;
218 };
219 
220 namespace internal {
221 
222 // Implementation class for FactorWeight
223 template <class Arc, class FactorIterator>
224 class FactorWeightFstImpl : public CacheImpl<Arc> {
225  public:
226  using Label = typename Arc::Label;
227  using StateId = typename Arc::StateId;
228  using Weight = typename Arc::Weight;
229 
230  using FstImpl<Arc>::SetType;
234 
235  using CacheBaseImpl<CacheState<Arc>>::EmplaceArc;
236  using CacheBaseImpl<CacheState<Arc>>::HasArcs;
237  using CacheBaseImpl<CacheState<Arc>>::HasFinal;
238  using CacheBaseImpl<CacheState<Arc>>::HasStart;
239  using CacheBaseImpl<CacheState<Arc>>::SetArcs;
240  using CacheBaseImpl<CacheState<Arc>>::SetFinal;
241  using CacheBaseImpl<CacheState<Arc>>::SetStart;
242 
243  struct Element {
244  Element() {}
245 
246  Element(StateId s, Weight weight_) : state(s), weight(std::move(weight_)) {}
247 
248  StateId state; // Input state ID.
249  Weight weight; // Residual weight.
250  };
251 
253  : CacheImpl<Arc>(opts),
254  fst_(fst.Copy()),
255  delta_(opts.delta),
256  mode_(opts.mode),
257  final_ilabel_(opts.final_ilabel),
258  final_olabel_(opts.final_olabel),
259  increment_final_ilabel_(opts.increment_final_ilabel),
260  increment_final_olabel_(opts.increment_final_olabel) {
261  SetType("factor_weight");
262  const auto props = fst.Properties(kFstProperties, false);
263  SetProperties(FactorWeightProperties(props), kCopyProperties);
264  SetInputSymbols(fst.InputSymbols());
265  SetOutputSymbols(fst.OutputSymbols());
266  if (mode_ == 0) {
267  LOG(WARNING) << "FactorWeightFst: Factor mode is set to 0; "
268  << "factoring neither arc weights nor final weights";
269  }
270  }
271 
273  : CacheImpl<Arc>(impl),
274  fst_(impl.fst_->Copy(true)),
275  delta_(impl.delta_),
276  mode_(impl.mode_),
277  final_ilabel_(impl.final_ilabel_),
278  final_olabel_(impl.final_olabel_),
279  increment_final_ilabel_(impl.increment_final_ilabel_),
280  increment_final_olabel_(impl.increment_final_olabel_) {
281  SetType("factor_weight");
282  SetProperties(impl.Properties(), kCopyProperties);
283  SetInputSymbols(impl.InputSymbols());
284  SetOutputSymbols(impl.OutputSymbols());
285  }
286 
288  if (!HasStart()) {
289  const auto s = fst_->Start();
290  if (s == kNoStateId) return kNoStateId;
291  SetStart(FindState(Element(fst_->Start(), Weight::One())));
292  }
293  return CacheImpl<Arc>::Start();
294  }
295 
297  if (!HasFinal(s)) {
298  const auto &element = elements_[s];
299  const auto weight =
300  element.state == kNoStateId
301  ? element.weight
302  : Times(element.weight, fst_->Final(element.state));
303  FactorIterator siter(weight);
304  if (!(mode_ & kFactorFinalWeights) || siter.Done()) {
305  SetFinal(s, weight);
306  } else {
307  SetFinal(s, Weight::Zero());
308  }
309  }
310  return CacheImpl<Arc>::Final(s);
311  }
312 
313  size_t NumArcs(StateId s) {
314  if (!HasArcs(s)) Expand(s);
315  return CacheImpl<Arc>::NumArcs(s);
316  }
317 
319  if (!HasArcs(s)) Expand(s);
321  }
322 
324  if (!HasArcs(s)) Expand(s);
326  }
327 
328  uint64_t Properties() const override { return Properties(kFstProperties); }
329 
330  // Sets error if found, and returns other FST impl properties.
331  uint64_t Properties(uint64_t mask) const override {
332  if ((mask & kError) && fst_->Properties(kError, false)) {
333  SetProperties(kError, kError);
334  }
335  return FstImpl<Arc>::Properties(mask);
336  }
337 
339  if (!HasArcs(s)) Expand(s);
341  }
342 
343  // Finds state corresponding to an element, creating new state if element not
344  // found.
345  StateId FindState(const Element &element) {
346  if (!(mode_ & kFactorArcWeights) && element.weight == Weight::One() &&
347  element.state != kNoStateId) {
348  while (unfactored_.size() <= element.state)
349  unfactored_.push_back(kNoStateId);
350  if (unfactored_[element.state] == kNoStateId) {
351  unfactored_[element.state] = elements_.size();
352  elements_.push_back(element);
353  }
354  return unfactored_[element.state];
355  } else {
356  const auto insert_result =
357  element_map_.emplace(element, elements_.size());
358  if (insert_result.second) {
359  elements_.push_back(element);
360  }
361  return insert_result.first->second;
362  }
363  }
364 
365  // Computes the outgoing transitions from a state, creating new destination
366  // states as needed.
367  void Expand(StateId s) {
368  const auto element = elements_[s];
369  if (element.state != kNoStateId) {
370  for (ArcIterator<Fst<Arc>> ait(*fst_, element.state); !ait.Done();
371  ait.Next()) {
372  const auto &arc = ait.Value();
373  auto weight = Times(element.weight, arc.weight);
374  FactorIterator fiter(weight);
375  if (!(mode_ & kFactorArcWeights) || fiter.Done()) {
376  const auto dest = FindState(Element(arc.nextstate, Weight::One()));
377  EmplaceArc(s, arc.ilabel, arc.olabel, std::move(weight), dest);
378  } else {
379  for (; !fiter.Done(); fiter.Next()) {
380  auto pair = fiter.Value();
381  const auto dest =
382  FindState(Element(arc.nextstate, pair.second.Quantize(delta_)));
383  EmplaceArc(s, arc.ilabel, arc.olabel, std::move(pair.first), dest);
384  }
385  }
386  }
387  }
388  if ((mode_ & kFactorFinalWeights) &&
389  ((element.state == kNoStateId) ||
390  (fst_->Final(element.state) != Weight::Zero()))) {
391  const auto weight =
392  element.state == kNoStateId
393  ? element.weight
394  : Times(element.weight, fst_->Final(element.state));
395  auto ilabel = final_ilabel_;
396  auto olabel = final_olabel_;
397  for (FactorIterator fiter(weight); !fiter.Done(); fiter.Next()) {
398  auto pair = fiter.Value();
399  const auto dest =
400  FindState(Element(kNoStateId, pair.second.Quantize(delta_)));
401  EmplaceArc(s, ilabel, olabel, std::move(pair.first), dest);
402  if (increment_final_ilabel_) ++ilabel;
403  if (increment_final_olabel_) ++olabel;
404  }
405  }
406  SetArcs(s);
407  }
408 
409  private:
410  // Equality function for Elements, assume weights have been quantized.
411  class ElementEqual {
412  public:
413  bool operator()(const Element &x, const Element &y) const {
414  return x.state == y.state && x.weight == y.weight;
415  }
416  };
417 
418  // Hash function for Elements to Fst states.
419  class ElementKey {
420  public:
421  size_t operator()(const Element &x) const {
422  static constexpr auto prime = 7853;
423  return static_cast<size_t>(x.state * prime + x.weight.Hash());
424  }
425  };
426 
427  using ElementMap =
428  std::unordered_map<Element, StateId, ElementKey, ElementEqual>;
429 
430  std::unique_ptr<const Fst<Arc>> fst_;  // Private copy of the input FST.
431  float delta_;  // Quantization delta applied to residual weights.
432  uint8_t mode_; // Factoring arc and/or final weights.
433  Label final_ilabel_; // ilabel of arc created when factoring final weights.
434  Label final_olabel_; // olabel of arc created when factoring final weights.
435  bool increment_final_ilabel_; // When factoring final weights results in
436  bool increment_final_olabel_; // multiple arcs, increment labels?
437  std::vector<Element> elements_; // Mapping from FST state to Element.
438  ElementMap element_map_; // Mapping from Element to FST state.
439  // Mapping between old/new StateId for states that do not need to be factored
440  // when mode_ is 0 or kFactorFinalWeights.
// Indexed by input state ID; kNoStateId marks a slot not yet assigned.
441  std::vector<StateId> unfactored_;
442 };
443 
444 } // namespace internal
445 
446 // FactorWeightFst takes as template parameter a FactorIterator as defined
447 // above. The result of weight factoring is a transducer equivalent to the
448 // input whose path weights have been factored according to the FactorIterator.
449 // States and transitions will be added as necessary. The algorithm is a
450 // generalization to arbitrary weights of the second step of the input
451 // epsilon-normalization algorithm.
452 //
453 // This class attaches interface to implementation and handles reference
454 // counting, delegating most methods to ImplToFst.
// NOTE(review): several lines of this class are absent from this copy of the
// listing: 456 (class head), 463 and 465 (presumably the `Store` and `Impl`
// aliases used below), 474 and 478 (the signatures of the second and third
// constructors) and 493-494 (private declarations; GetMutableImpl is used
// below without qualification). Restore them from the upstream header before
// compiling.
455 template <class A, class FactorIterator>
457  : public ImplToFst<internal::FactorWeightFstImpl<A, FactorIterator>> {
458  public:
459  using Arc = A;
460  using StateId = typename Arc::StateId;
461  using Weight = typename Arc::Weight;
462 
464  using State = typename Store::State;
466 
467  friend class ArcIterator<FactorWeightFst<Arc, FactorIterator>>;
468  friend class StateIterator<FactorWeightFst<Arc, FactorIterator>>;
469 
// Factors `fst` using default-constructed FactorWeightOptions.
470  explicit FactorWeightFst(const Fst<Arc> &fst)
471  : ImplToFst<Impl>(
472  std::make_shared<Impl>(fst, FactorWeightOptions<Arc>())) {}
473 
// Initializer of the constructor that takes explicit options (`opts`).
475  : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
476 
477  // See Fst<>::Copy() for doc.
479  : ImplToFst<Impl>(fst, copy) {}
480 
481  // Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc.
// The caller receives a raw pointer to a newly allocated object.
482  FactorWeightFst *Copy(bool copy = false) const override {
483  return new FactorWeightFst(*this, copy);
484  }
485 
// Defined out of line, after the StateIterator specialization.
486  inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
487 
// Forwards to the implementation, which is reached through GetMutableImpl()
// from this const method.
488  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
489  GetMutableImpl()->InitArcIterator(s, data);
490  }
491 
492  private:
495 
// Assignment is disabled.
496  FactorWeightFst &operator=(const FactorWeightFst &) = delete;
497 };
498 
499 // Specialization for FactorWeightFst.
500 template <class Arc, class FactorIterator>
501 class StateIterator<FactorWeightFst<Arc, FactorIterator>>
502  : public CacheStateIterator<FactorWeightFst<Arc, FactorIterator>> {
503  public:
505  : CacheStateIterator<FactorWeightFst<Arc, FactorIterator>>(
506  fst, fst.GetMutableImpl()) {}
507 };
508 
509 // Specialization for FactorWeightFst.
510 template <class Arc, class FactorIterator>
511 class ArcIterator<FactorWeightFst<Arc, FactorIterator>>
512  : public CacheArcIterator<FactorWeightFst<Arc, FactorIterator>> {
513  public:
514  using StateId = typename Arc::StateId;
515 
517  : CacheArcIterator<FactorWeightFst<Arc, FactorIterator>>(
518  fst.GetMutableImpl(), s) {
519  if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
520  }
521 };
522 
523 template <class Arc, class FactorIterator>
525  StateIteratorData<Arc> *data) const {
526  data->base =
527  std::make_unique<StateIterator<FactorWeightFst<Arc, FactorIterator>>>(
528  *this);
529 }
530 
531 } // namespace fst
532 
533 #endif // FST_FACTOR_WEIGHT_H_
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:114
bool Done() const
FactorWeightFst(const Fst< Arc > &fst, const FactorWeightOptions< Arc > &opts)
FactorWeightFst(const FactorWeightFst &fst, bool copy)
std::pair< W, W > Value() const
virtual uint64_t Properties(uint64_t mask, bool test) const =0
uint64_t Properties(uint64_t mask) const override
uint64_t Properties() const override
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:97
const SymbolTable * OutputSymbols() const
Definition: fst.h:762
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:63
constexpr uint64_t kError
Definition: properties.h:51
StateId FindState(const Element &element)
#define LOG(type)
Definition: log.h:49
SetType
Definition: set-weight.h:52
typename FactorWeightFst< Arc, FactorIterator >::Arc Arc
Definition: cache.h:1149
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:103
constexpr int kNoStateId
Definition: fst.h:202
std::pair< W, W > Value() const
StringFactor(const StringWeight< Label, S > &weight)
bool Done() const
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:108
FactorWeightFstImpl(const FactorWeightFstImpl< Arc, FactorIterator > &impl)
constexpr uint64_t kCopyProperties
Definition: properties.h:162
uint64_t FactorWeightProperties(uint64_t inprops)
Definition: properties.cc:156
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
Definition: cache.h:666
std::pair< GW, GW > Value() const
std::unique_ptr< StateIteratorBase< Arc > > base
Definition: fst.h:382
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
FactorWeightOptions(float delta=kDelta, uint8_t mode=kFactorArcWeights|kFactorFinalWeights, Label final_ilabel=0, Label final_olabel=0, bool increment_final_ilabel=false, bool increment_final_olabel=false)
Definition: factor-weight.h:66
FactorWeightOptions(const CacheOptions &opts, float delta=kDelta, uint8_t mode=kFactorArcWeights|kFactorFinalWeights, Label final_ilabel=0, Label final_olabel=0, bool increment_final_ilabel=false, bool increment_final_olabel=false)
Definition: factor-weight.h:52
FactorWeightFst(const Fst< Arc > &fst)
constexpr uint64_t kFstProperties
Definition: properties.h:325
constexpr uint8_t kFactorFinalWeights
Definition: factor-weight.h:38
bool Done() const
typename FactorWeightFst< Arc, FactorIterator >::Arc Arc
Definition: cache.h:1195
virtual const SymbolTable * InputSymbols() const =0
const SymbolTable * InputSymbols() const
Definition: fst.h:760
std::pair< StringWeight< Label, S >, StringWeight< Label, S > > Value() const
FactorWeightFstImpl(const Fst< Arc > &fst, const FactorWeightOptions< Arc > &opts)
FactorWeightFst * Copy(bool copy=false) const override
OneFactor(const W &w)
typename Arc::Label Label
Definition: factor-weight.h:43
void InitStateIterator(StateIteratorData< Arc > *data) const override
constexpr uint8_t kFactorArcWeights
Definition: factor-weight.h:39
GallicFactor(const GW &weight)
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
Definition: expand.h:316
bool Done() const
typename CacheState< Arc >::Arc Arc
Definition: cache.h:852
Impl * GetMutableImpl() const
Definition: fst.h:1028
ArcIterator(const FactorWeightFst< Arc, FactorIterator > &fst, StateId s)
constexpr float kDelta
Definition: weight.h:130
IdentityFactor(const W &weight)
std::pair< GW, GW > Value() const
StateIterator(const FactorWeightFst< Arc, FactorIterator > &fst)
const Impl * GetImpl() const
Definition: fst.h:1026
virtual const SymbolTable * OutputSymbols() const =0