FST  openfst-1.7.1
OpenFst Library
factor-weight.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Classes to factor weights in an FST.
5 
6 #ifndef FST_FACTOR_WEIGHT_H_
7 #define FST_FACTOR_WEIGHT_H_
8 
9 #include <algorithm>
10 #include <string>
11 #include <unordered_map>
12 #include <utility>
13 #include <vector>
14 
15 #include <fst/log.h>
16 
17 #include <fst/cache.h>
18 #include <fst/test-properties.h>
19 
20 
21 namespace fst {
22 
23 constexpr uint32 kFactorFinalWeights = 0x00000001;
24 constexpr uint32 kFactorArcWeights = 0x00000002;
25 
26 template <class Arc>
28  using Label = typename Arc::Label;
29 
30  float delta;
31  uint32 mode; // Factor arc weights and/or final weights.
32  Label final_ilabel; // Input label of arc when factoring final weights.
33  Label final_olabel; // Output label of arc when factoring final weights.
34  bool increment_final_ilabel; // When factoring final w' results in > 1 arcs
35  bool increment_final_olabel; // at state, increment labels to make distinct?
36 
37  explicit FactorWeightOptions(const CacheOptions &opts, float delta = kDelta,
38  uint32 mode = kFactorArcWeights |
39  kFactorFinalWeights,
40  Label final_ilabel = 0, Label final_olabel = 0,
41  bool increment_final_ilabel = false,
42  bool increment_final_olabel = false)
43  : CacheOptions(opts),
44  delta(delta),
45  mode(mode),
46  final_ilabel(final_ilabel),
47  final_olabel(final_olabel),
48  increment_final_ilabel(increment_final_ilabel),
49  increment_final_olabel(increment_final_olabel) {}
50 
51  explicit FactorWeightOptions(float delta = kDelta,
52  uint32 mode = kFactorArcWeights |
53  kFactorFinalWeights,
54  Label final_ilabel = 0, Label final_olabel = 0,
55  bool increment_final_ilabel = false,
56  bool increment_final_olabel = false)
57  : delta(delta),
58  mode(mode),
59  final_ilabel(final_ilabel),
60  final_olabel(final_olabel),
61  increment_final_ilabel(increment_final_ilabel),
62  increment_final_olabel(increment_final_olabel) {}
63 };
64 
65 // A factor iterator takes as argument a weight w and returns a sequence of
66 // pairs of weights (xi, yi) such that the sum of the products xi times yi is
67 // equal to w. If w is fully factored, the iterator should return nothing.
68 //
69 // template <class W>
70 // class FactorIterator {
71 // public:
72 // explicit FactorIterator(W w);
73 //
74 // bool Done() const;
75 //
76 // void Next();
77 //
78 // std::pair<W, W> Value() const;
79 //
80 // void Reset();
81 // }
82 
83 // Factors trivially.
84 template <class W>
86  public:
87  explicit IdentityFactor(const W &weight) {}
88 
89  bool Done() const { return true; }
90 
91  void Next() {}
92 
93  std::pair<W, W> Value() const { return std::make_pair(W::One(), W::One()); }
94 
95  void Reset() {}
96 };
97 
98 // Factors a StringWeight w as 'ab' where 'a' is a label.
99 template <typename Label, StringType S = STRING_LEFT>
101  public:
102  explicit StringFactor(const StringWeight<Label, S> &weight)
103  : weight_(weight), done_(weight.Size() <= 1) {}
104 
105  bool Done() const { return done_; }
106 
107  void Next() { done_ = true; }
108 
109  std::pair<StringWeight<Label, S>, StringWeight<Label, S>> Value() const {
110  using Weight = StringWeight<Label, S>;
111  typename Weight::Iterator siter(weight_);
112  Weight w1(siter.Value());
113  Weight w2;
114  for (siter.Next(); !siter.Done(); siter.Next()) w2.PushBack(siter.Value());
115  return std::make_pair(w1, w2);
116  }
117 
118  void Reset() { done_ = weight_.Size() <= 1; }
119 
120  private:
121  const StringWeight<Label, S> weight_;
122  bool done_;
123 };
124 
125 // Factor a GallicWeight using StringFactor.
126 template <class Label, class W, GallicType G = GALLIC_LEFT>
128  public:
130 
131  explicit GallicFactor(const GW &weight)
132  : weight_(weight), done_(weight.Value1().Size() <= 1) {}
133 
134  bool Done() const { return done_; }
135 
136  void Next() { done_ = true; }
137 
138  std::pair<GW, GW> Value() const {
139  StringFactor<Label, GallicStringType(G)> siter(weight_.Value1());
140  GW w1(siter.Value().first, weight_.Value2());
141  GW w2(siter.Value().second, W::One());
142  return std::make_pair(w1, w2);
143  }
144 
145  void Reset() { done_ = weight_.Value1().Size() <= 1; }
146 
147  private:
148  const GW weight_;
149  bool done_;
150 };
151 
152 // Specialization for the (general) GALLIC type GallicWeight.
153 template <class Label, class W>
155  public:
158 
159  explicit GallicFactor(const GW &weight)
160  : iter_(weight),
161  done_(weight.Size() == 0 ||
162  (weight.Size() == 1 && weight.Back().Value1().Size() <= 1)) {}
163 
164  bool Done() const { return done_ || iter_.Done(); }
165 
166  void Next() { iter_.Next(); }
167 
168  void Reset() { iter_.Reset(); }
169 
170  std::pair<GW, GW> Value() const {
171  const auto weight = iter_.Value();
173  weight.Value1());
174  GRW w1(siter.Value().first, weight.Value2());
175  GRW w2(siter.Value().second, W::One());
176  return std::make_pair(GW(w1), GW(w2));
177  }
178 
179  private:
181  bool done_;
182 };
183 
184 namespace internal {
185 
186 // Implementation class for FactorWeight
187 template <class Arc, class FactorIterator>
188 class FactorWeightFstImpl : public CacheImpl<Arc> {
189  public:
190  using Label = typename Arc::Label;
191  using StateId = typename Arc::StateId;
192  using Weight = typename Arc::Weight;
193 
194  using FstImpl<Arc>::SetType;
198 
199  using CacheBaseImpl<CacheState<Arc>>::EmplaceArc;
200  using CacheBaseImpl<CacheState<Arc>>::HasArcs;
201  using CacheBaseImpl<CacheState<Arc>>::HasFinal;
202  using CacheBaseImpl<CacheState<Arc>>::HasStart;
203  using CacheBaseImpl<CacheState<Arc>>::SetArcs;
204  using CacheBaseImpl<CacheState<Arc>>::SetFinal;
205  using CacheBaseImpl<CacheState<Arc>>::SetStart;
206 
207  struct Element {
208  Element() {}
209 
210  Element(StateId s, Weight weight_) : state(s), weight(std::move(weight_)) {}
211 
212  StateId state; // Input state ID.
213  Weight weight; // Residual weight.
214  };
215 
217  : CacheImpl<Arc>(opts),
218  fst_(fst.Copy()),
219  delta_(opts.delta),
220  mode_(opts.mode),
221  final_ilabel_(opts.final_ilabel),
222  final_olabel_(opts.final_olabel),
223  increment_final_ilabel_(opts.increment_final_ilabel),
224  increment_final_olabel_(opts.increment_final_olabel) {
225  SetType("factor_weight");
226  const auto props = fst.Properties(kFstProperties, false);
227  SetProperties(FactorWeightProperties(props), kCopyProperties);
228  SetInputSymbols(fst.InputSymbols());
229  SetOutputSymbols(fst.OutputSymbols());
230  if (mode_ == 0) {
231  LOG(WARNING) << "FactorWeightFst: Factor mode is set to 0; "
232  << "factoring neither arc weights nor final weights";
233  }
234  }
235 
237  : CacheImpl<Arc>(impl),
238  fst_(impl.fst_->Copy(true)),
239  delta_(impl.delta_),
240  mode_(impl.mode_),
241  final_ilabel_(impl.final_ilabel_),
242  final_olabel_(impl.final_olabel_),
243  increment_final_ilabel_(impl.increment_final_ilabel_),
244  increment_final_olabel_(impl.increment_final_olabel_) {
245  SetType("factor_weight");
246  SetProperties(impl.Properties(), kCopyProperties);
247  SetInputSymbols(impl.InputSymbols());
248  SetOutputSymbols(impl.OutputSymbols());
249  }
250 
252  if (!HasStart()) {
253  const auto s = fst_->Start();
254  if (s == kNoStateId) return kNoStateId;
255  SetStart(FindState(Element(fst_->Start(), Weight::One())));
256  }
257  return CacheImpl<Arc>::Start();
258  }
259 
261  if (!HasFinal(s)) {
262  const auto &element = elements_[s];
263  // TODO(sorenj): fix so cast is unnecessary
264  const auto weight =
265  element.state == kNoStateId
266  ? element.weight
267  : (Weight)Times(element.weight, fst_->Final(element.state));
268  FactorIterator siter(weight);
269  if (!(mode_ & kFactorFinalWeights) || siter.Done()) {
270  SetFinal(s, weight);
271  } else {
272  SetFinal(s, Weight::Zero());
273  }
274  }
275  return CacheImpl<Arc>::Final(s);
276  }
277 
278  size_t NumArcs(StateId s) {
279  if (!HasArcs(s)) Expand(s);
280  return CacheImpl<Arc>::NumArcs(s);
281  }
282 
284  if (!HasArcs(s)) Expand(s);
286  }
287 
289  if (!HasArcs(s)) Expand(s);
291  }
292 
293  uint64 Properties() const override { return Properties(kFstProperties); }
294 
295  // Sets error if found, and returns other FST impl properties.
296  uint64 Properties(uint64 mask) const override {
297  if ((mask & kError) && fst_->Properties(kError, false)) {
298  SetProperties(kError, kError);
299  }
300  return FstImpl<Arc>::Properties(mask);
301  }
302 
304  if (!HasArcs(s)) Expand(s);
306  }
307 
308  // Finds state corresponding to an element, creating new state if element not
309  // found.
310  StateId FindState(const Element &element) {
311  if (!(mode_ & kFactorArcWeights) && element.weight == Weight::One() &&
312  element.state != kNoStateId) {
313  while (unfactored_.size() <= element.state)
314  unfactored_.push_back(kNoStateId);
315  if (unfactored_[element.state] == kNoStateId) {
316  unfactored_[element.state] = elements_.size();
317  elements_.push_back(element);
318  }
319  return unfactored_[element.state];
320  } else {
321  const auto insert_result =
322  element_map_.insert(std::make_pair(element, elements_.size()));
323  if (insert_result.second) {
324  elements_.push_back(element);
325  }
326  return insert_result.first->second;
327  }
328  }
329 
330  // Computes the outgoing transitions from a state, creating new destination
331  // states as needed.
332  void Expand(StateId s) {
333  const auto element = elements_[s];
334  if (element.state != kNoStateId) {
335  for (ArcIterator<Fst<Arc>> ait(*fst_, element.state); !ait.Done();
336  ait.Next()) {
337  const auto &arc = ait.Value();
338  auto weight = Times(element.weight, arc.weight);
339  FactorIterator fiter(weight);
340  if (!(mode_ & kFactorArcWeights) || fiter.Done()) {
341  const auto dest = FindState(Element(arc.nextstate, Weight::One()));
342  EmplaceArc(s, arc.ilabel, arc.olabel, std::move(weight), dest);
343  } else {
344  for (; !fiter.Done(); fiter.Next()) {
345  auto pair = fiter.Value();
346  const auto dest =
347  FindState(Element(arc.nextstate, pair.second.Quantize(delta_)));
348  EmplaceArc(s, arc.ilabel, arc.olabel, std::move(pair.first), dest);
349  }
350  }
351  }
352  }
353  if ((mode_ & kFactorFinalWeights) &&
354  ((element.state == kNoStateId) ||
355  (fst_->Final(element.state) != Weight::Zero()))) {
356  const auto weight =
357  element.state == kNoStateId
358  ? element.weight
359  : Times(element.weight, fst_->Final(element.state));
360  auto ilabel = final_ilabel_;
361  auto olabel = final_olabel_;
362  for (FactorIterator fiter(weight); !fiter.Done(); fiter.Next()) {
363  auto pair = fiter.Value();
364  const auto dest =
365  FindState(Element(kNoStateId, pair.second.Quantize(delta_)));
366  EmplaceArc(s, ilabel, olabel, std::move(pair.first), dest);
367  if (increment_final_ilabel_) ++ilabel;
368  if (increment_final_olabel_) ++olabel;
369  }
370  }
371  SetArcs(s);
372  }
373 
374  private:
375  // Equality function for Elements, assume weights have been quantized.
376  class ElementEqual {
377  public:
378  bool operator()(const Element &x, const Element &y) const {
379  return x.state == y.state && x.weight == y.weight;
380  }
381  };
382 
383  // Hash function for Elements to Fst states.
384  class ElementKey {
385  public:
386  size_t operator()(const Element &x) const {
387  static constexpr auto prime = 7853;
388  return static_cast<size_t>(x.state * prime + x.weight.Hash());
389  }
390  };
391 
392  using ElementMap =
393  std::unordered_map<Element, StateId, ElementKey, ElementEqual>;
394 
395  std::unique_ptr<const Fst<Arc>> fst_;
396  float delta_;
397  uint32 mode_; // Factoring arc and/or final weights.
398  Label final_ilabel_; // ilabel of arc created when factoring final weights.
399  Label final_olabel_; // olabel of arc created when factoring final weights.
400  bool increment_final_ilabel_; // When factoring final weights results in
401  bool increment_final_olabel_; // mutiple arcs, increment labels?
402  std::vector<Element> elements_; // mapping from FST state to Element.
403  ElementMap element_map_; // mapping from Element to FST state.
404  // Mapping between old/new StateId for states that do not need to be factored
405  // when mode_ is 0 or kFactorFinalWeights.
406  std::vector<StateId> unfactored_;
407 };
408 
409 } // namespace internal
410 
411 // FactorWeightFst takes as template parameter a FactorIterator as defined
412 // above. The result of weight factoring is a transducer equivalent to the
413 // input whose path weights have been factored according to the FactorIterator.
414 // States and transitions will be added as necessary. The algorithm is a
415 // generalization to arbitrary weights of the second step of the input
416 // epsilon-normalization algorithm.
417 //
418 // This class attaches interface to implementation and handles reference
419 // counting, delegating most methods to ImplToFst.
420 template <class A, class FactorIterator>
422  : public ImplToFst<internal::FactorWeightFstImpl<A, FactorIterator>> {
423  public:
424  using Arc = A;
425  using StateId = typename Arc::StateId;
426  using Weight = typename Arc::Weight;
427 
429  using State = typename Store::State;
431 
432  friend class ArcIterator<FactorWeightFst<Arc, FactorIterator>>;
433  friend class StateIterator<FactorWeightFst<Arc, FactorIterator>>;
434 
435  explicit FactorWeightFst(const Fst<Arc> &fst)
436  : ImplToFst<Impl>(
437  std::make_shared<Impl>(fst, FactorWeightOptions<Arc>())) {}
438 
440  : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
441 
442  // See Fst<>::Copy() for doc.
444  : ImplToFst<Impl>(fst, copy) {}
445 
446  // Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc.
447  FactorWeightFst<Arc, FactorIterator> *Copy(bool copy = false) const override {
448  return new FactorWeightFst<Arc, FactorIterator>(*this, copy);
449  }
450 
451  inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
452 
453  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
454  GetMutableImpl()->InitArcIterator(s, data);
455  }
456 
457  private:
460 
461  FactorWeightFst &operator=(const FactorWeightFst &) = delete;
462 };
463 
464 // Specialization for FactorWeightFst.
465 template <class Arc, class FactorIterator>
466 class StateIterator<FactorWeightFst<Arc, FactorIterator>>
467  : public CacheStateIterator<FactorWeightFst<Arc, FactorIterator>> {
468  public:
470  : CacheStateIterator<FactorWeightFst<Arc, FactorIterator>>(
471  fst, fst.GetMutableImpl()) {}
472 };
473 
474 // Specialization for FactorWeightFst.
475 template <class Arc, class FactorIterator>
476 class ArcIterator<FactorWeightFst<Arc, FactorIterator>>
477  : public CacheArcIterator<FactorWeightFst<Arc, FactorIterator>> {
478  public:
479  using StateId = typename Arc::StateId;
480 
482  : CacheArcIterator<FactorWeightFst<Arc, FactorIterator>>(
483  fst.GetMutableImpl(), s) {
484  if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
485  }
486 };
487 
488 template <class Arc, class FactorIterator>
490  StateIteratorData<Arc> *data) const {
492 }
493 
494 } // namespace fst
495 
496 #endif // FST_FACTOR_WEIGHT_H_
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:99
bool Done() const
Definition: factor-weight.h:89
FactorWeightFst(const Fst< Arc > &fst, const FactorWeightOptions< Arc > &opts)
constexpr uint32 kFactorArcWeights
Definition: factor-weight.h:24
uint64_t uint64
Definition: types.h:32
uint64 Properties() const override
FactorWeightOptions(float delta=kDelta, uint32 mode=kFactorArcWeights|kFactorFinalWeights, Label final_ilabel=0, Label final_olabel=0, bool increment_final_ilabel=false, bool increment_final_olabel=false)
Definition: factor-weight.h:51
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:82
const SymbolTable * OutputSymbols() const
Definition: fst.h:690
StateId FindState(const Element &element)
#define LOG(type)
Definition: log.h:48
ExpectationWeight< X1, X2 > Times(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
SetType
Definition: set-weight.h:37
typename FactorWeightFst< Arc, FactorIterator >::Arc Arc
Definition: cache.h:1143
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:88
constexpr uint64 kFstProperties
Definition: properties.h:301
constexpr uint64 kCopyProperties
Definition: properties.h:138
constexpr int kNoStateId
Definition: fst.h:180
std::pair< W, W > Value() const
Definition: factor-weight.h:93
constexpr uint32 kFactorFinalWeights
Definition: factor-weight.h:23
StringFactor(const StringWeight< Label, S > &weight)
virtual uint64 Properties(uint64 mask, bool test) const =0
bool Done() const
StateIteratorBase< Arc > * base
Definition: fst.h:351
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:93
FactorWeightFstImpl(const FactorWeightFstImpl< Arc, FactorIterator > &impl)
FactorWeightFst(const FactorWeightFst< Arc, FactorIterator > &fst, bool copy)
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
Definition: cache.h:644
std::pair< GW, GW > Value() const
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
FactorWeightOptions(const CacheOptions &opts, float delta=kDelta, uint32 mode=kFactorArcWeights|kFactorFinalWeights, Label final_ilabel=0, Label final_olabel=0, bool increment_final_ilabel=false, bool increment_final_olabel=false)
Definition: factor-weight.h:37
FactorWeightFst(const Fst< Arc > &fst)
uint32_t uint32
Definition: types.h:31
bool Done() const
typename FactorWeightFst< Arc, FactorIterator >::Arc Arc
Definition: cache.h:1189
virtual const SymbolTable * InputSymbols() const =0
const SymbolTable * InputSymbols() const
Definition: fst.h:688
FactorWeightFst< Arc, FactorIterator > * Copy(bool copy=false) const override
std::pair< StringWeight< Label, S >, StringWeight< Label, S > > Value() const
constexpr uint64 kError
Definition: properties.h:33
FactorWeightFstImpl(const Fst< Arc > &fst, const FactorWeightOptions< Arc > &opts)
typename Arc::Label Label
Definition: factor-weight.h:28
void InitStateIterator(StateIteratorData< Arc > *data) const override
uint64 Properties(uint64 mask) const override
GallicFactor(const GW &weight)
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
Definition: expand.h:302
typename CacheState< Arc >::Arc Arc
Definition: cache.h:833
Impl * GetMutableImpl() const
Definition: fst.h:947
ArcIterator(const FactorWeightFst< Arc, FactorIterator > &fst, StateId s)
constexpr float kDelta
Definition: weight.h:109
IdentityFactor(const W &weight)
Definition: factor-weight.h:87
uint64 FactorWeightProperties(uint64 inprops)
Definition: properties.cc:137
std::pair< GW, GW > Value() const
StateIterator(const FactorWeightFst< Arc, FactorIterator > &fst)
const Impl * GetImpl() const
Definition: fst.h:945
virtual const SymbolTable * OutputSymbols() const =0