FST  openfst-1.8.3
OpenFst Library
factor-weight.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Classes to factor weights in an FST.
19 
20 #ifndef FST_FACTOR_WEIGHT_H_
21 #define FST_FACTOR_WEIGHT_H_
22 
23 #include <algorithm>
24 #include <cstddef>
25 #include <cstdint>
26 #include <memory>
27 #include <string>
28 #include <utility>
29 #include <vector>
30 
31 #include <fst/log.h>
32 #include <fst/cache.h>
33 #include <fst/fst.h>
34 #include <fst/impl-to-fst.h>
35 #include <fst/properties.h>
36 #include <fst/string-weight.h>
37 #include <fst/union-weight.h>
38 #include <fst/weight.h>
39 #include <unordered_map>
40 
41 namespace fst {
42 
43 inline constexpr uint8_t kFactorFinalWeights = 0x01;
44 inline constexpr uint8_t kFactorArcWeights = 0x02;
45 
46 template <class Arc>
48  using Label = typename Arc::Label;
49 
50  float delta;
51  uint8_t mode; // Factor arc weights and/or final weights.
52  Label final_ilabel; // Input label of arc when factoring final weights.
53  Label final_olabel; // Output label of arc when factoring final weights.
54  bool increment_final_ilabel; // When factoring final w' results in > 1 arcs
55  bool increment_final_olabel; // at state, increment labels to make distinct?
56 
57  explicit FactorWeightOptions(const CacheOptions &opts, float delta = kDelta,
58  uint8_t mode = kFactorArcWeights |
59  kFactorFinalWeights,
60  Label final_ilabel = 0, Label final_olabel = 0,
61  bool increment_final_ilabel = false,
62  bool increment_final_olabel = false)
63  : CacheOptions(opts),
64  delta(delta),
65  mode(mode),
66  final_ilabel(final_ilabel),
67  final_olabel(final_olabel),
68  increment_final_ilabel(increment_final_ilabel),
69  increment_final_olabel(increment_final_olabel) {}
70 
71  explicit FactorWeightOptions(float delta = kDelta,
72  uint8_t mode = kFactorArcWeights |
73  kFactorFinalWeights,
74  Label final_ilabel = 0, Label final_olabel = 0,
75  bool increment_final_ilabel = false,
76  bool increment_final_olabel = false)
77  : delta(delta),
78  mode(mode),
79  final_ilabel(final_ilabel),
80  final_olabel(final_olabel),
81  increment_final_ilabel(increment_final_ilabel),
82  increment_final_olabel(increment_final_olabel) {}
83 };
84 
85 // A factor iterator takes as argument a weight w and returns a sequence of
86 // pairs of weights (xi, yi) such that the sum of the products xi times yi is
87 // equal to w. If w is fully factored, the iterator should return nothing.
88 //
89 // template <class W>
90 // class FactorIterator {
91 // public:
92 // explicit FactorIterator(W w);
93 //
94 // bool Done() const;
95 //
96 // void Next();
97 //
98 // std::pair<W, W> Value() const;
99 //
100 // void Reset();
101 // }
102 
103 // Factors trivially.
104 template <class W>
106  public:
107  explicit IdentityFactor(const W &weight) {}
108 
109  bool Done() const { return true; }
110 
111  void Next() {}
112 
113  std::pair<W, W> Value() const { return std::make_pair(W::One(), W::One()); }
114 
115  void Reset() {}
116 };
117 
// Factor iterator that yields at most one factorization of a weight w: the
// pair (W::One(), w). A weight equal to W::One() is treated as fully factored
// and yields nothing.
//
// Used to factor the Fst to unfold it as needed so that every two paths
// leading to the same state have the same weight. Requires applying only to
// arc weights (FactorWeightOptions::mode == kFactorArcWeights).
template <class W>
class OneFactor {
 public:
  explicit OneFactor(const W &w) : weight_(w) { Reset(); }

  // True once the single factorization has been consumed, or if there is none.
  bool Done() const { return done_; }

  // Advances past the only factorization.
  void Next() { done_ = true; }

  // Returns (W::One(), w): the whole weight is carried by the second factor.
  std::pair<W, W> Value() const { return std::pair<W, W>(W::One(), weight_); }

  // Rewinds the iterator; it is immediately done when the weight is W::One().
  void Reset() { done_ = (weight_ == W::One()); }

 private:
  W weight_;   // Weight being factored.
  bool done_;  // Whether iteration has finished.
};
138 
139 // Factors a StringWeight w as 'ab' where 'a' is a label.
140 template <typename Label, StringType S = STRING_LEFT>
142  public:
143  explicit StringFactor(const StringWeight<Label, S> &weight)
144  : weight_(weight), done_(weight.Size() <= 1) {}
145 
146  bool Done() const { return done_; }
147 
148  void Next() { done_ = true; }
149 
150  std::pair<StringWeight<Label, S>, StringWeight<Label, S>> Value() const {
151  using Weight = StringWeight<Label, S>;
152  typename Weight::Iterator siter(weight_);
153  Weight w1(siter.Value());
154  Weight w2;
155  for (siter.Next(); !siter.Done(); siter.Next()) w2.PushBack(siter.Value());
156  return std::make_pair(w1, w2);
157  }
158 
159  void Reset() { done_ = weight_.Size() <= 1; }
160 
161  private:
162  const StringWeight<Label, S> weight_;
163  bool done_;
164 };
165 
166 // Factor a GallicWeight using StringFactor.
167 template <class Label, class W, GallicType G = GALLIC_LEFT>
169  public:
171 
172  explicit GallicFactor(const GW &weight)
173  : weight_(weight), done_(weight.Value1().Size() <= 1) {}
174 
175  bool Done() const { return done_; }
176 
177  void Next() { done_ = true; }
178 
179  std::pair<GW, GW> Value() const {
180  StringFactor<Label, GallicStringType(G)> siter(weight_.Value1());
181  GW w1(siter.Value().first, weight_.Value2());
182  GW w2(siter.Value().second, W::One());
183  return std::make_pair(w1, w2);
184  }
185 
186  void Reset() { done_ = weight_.Value1().Size() <= 1; }
187 
188  private:
189  const GW weight_;
190  bool done_;
191 };
192 
193 // Specialization for the (general) GALLIC type GallicWeight.
194 template <class Label, class W>
196  public:
199 
200  explicit GallicFactor(const GW &weight)
201  : iter_(weight),
202  done_(weight.Size() == 0 ||
203  (weight.Size() == 1 && weight.Back().Value1().Size() <= 1)) {}
204 
205  bool Done() const { return done_ || iter_.Done(); }
206 
207  void Next() { iter_.Next(); }
208 
209  void Reset() { iter_.Reset(); }
210 
211  std::pair<GW, GW> Value() const {
212  const auto weight = iter_.Value();
214  weight.Value1());
215  GRW w1(siter.Value().first, weight.Value2());
216  GRW w2(siter.Value().second, W::One());
217  return std::make_pair(GW(w1), GW(w2));
218  }
219 
220  private:
222  bool done_;
223 };
224 
225 namespace internal {
226 
227 // Implementation class for FactorWeight
228 template <class Arc, class FactorIterator>
229 class FactorWeightFstImpl : public CacheImpl<Arc> {
230  public:
231  using Label = typename Arc::Label;
232  using StateId = typename Arc::StateId;
233  using Weight = typename Arc::Weight;
234 
235  using FstImpl<Arc>::SetType;
239 
240  using CacheBaseImpl<CacheState<Arc>>::EmplaceArc;
241  using CacheBaseImpl<CacheState<Arc>>::HasArcs;
242  using CacheBaseImpl<CacheState<Arc>>::HasFinal;
243  using CacheBaseImpl<CacheState<Arc>>::HasStart;
244  using CacheBaseImpl<CacheState<Arc>>::SetArcs;
245  using CacheBaseImpl<CacheState<Arc>>::SetFinal;
246  using CacheBaseImpl<CacheState<Arc>>::SetStart;
247 
248  struct Element {
249  Element() = default;
250 
251  Element(StateId s, Weight weight_) : state(s), weight(std::move(weight_)) {}
252 
253  StateId state; // Input state ID.
254  Weight weight; // Residual weight.
255  };
256 
258  : CacheImpl<Arc>(opts),
259  fst_(fst.Copy()),
260  delta_(opts.delta),
261  mode_(opts.mode),
262  final_ilabel_(opts.final_ilabel),
263  final_olabel_(opts.final_olabel),
264  increment_final_ilabel_(opts.increment_final_ilabel),
265  increment_final_olabel_(opts.increment_final_olabel) {
266  SetType("factor_weight");
267  const auto props = fst.Properties(kFstProperties, false);
268  SetProperties(FactorWeightProperties(props), kCopyProperties);
269  SetInputSymbols(fst.InputSymbols());
270  SetOutputSymbols(fst.OutputSymbols());
271  if (mode_ == 0) {
272  LOG(WARNING) << "FactorWeightFst: Factor mode is set to 0; "
273  << "factoring neither arc weights nor final weights";
274  }
275  }
276 
278  : CacheImpl<Arc>(impl),
279  fst_(impl.fst_->Copy(true)),
280  delta_(impl.delta_),
281  mode_(impl.mode_),
282  final_ilabel_(impl.final_ilabel_),
283  final_olabel_(impl.final_olabel_),
284  increment_final_ilabel_(impl.increment_final_ilabel_),
285  increment_final_olabel_(impl.increment_final_olabel_) {
286  SetType("factor_weight");
287  SetProperties(impl.Properties(), kCopyProperties);
288  SetInputSymbols(impl.InputSymbols());
289  SetOutputSymbols(impl.OutputSymbols());
290  }
291 
293  if (!HasStart()) {
294  const auto s = fst_->Start();
295  if (s == kNoStateId) return kNoStateId;
296  SetStart(FindState(Element(fst_->Start(), Weight::One())));
297  }
298  return CacheImpl<Arc>::Start();
299  }
300 
302  if (!HasFinal(s)) {
303  const auto &element = elements_[s];
304  const auto weight =
305  element.state == kNoStateId
306  ? element.weight
307  : Times(element.weight, fst_->Final(element.state));
308  FactorIterator siter(weight);
309  if (!(mode_ & kFactorFinalWeights) || siter.Done()) {
310  SetFinal(s, weight);
311  } else {
312  SetFinal(s, Weight::Zero());
313  }
314  }
315  return CacheImpl<Arc>::Final(s);
316  }
317 
318  size_t NumArcs(StateId s) {
319  if (!HasArcs(s)) Expand(s);
320  return CacheImpl<Arc>::NumArcs(s);
321  }
322 
324  if (!HasArcs(s)) Expand(s);
326  }
327 
329  if (!HasArcs(s)) Expand(s);
331  }
332 
333  uint64_t Properties() const override { return Properties(kFstProperties); }
334 
335  // Sets error if found, and returns other FST impl properties.
336  uint64_t Properties(uint64_t mask) const override {
337  if ((mask & kError) && fst_->Properties(kError, false)) {
338  SetProperties(kError, kError);
339  }
340  return FstImpl<Arc>::Properties(mask);
341  }
342 
344  if (!HasArcs(s)) Expand(s);
346  }
347 
348  // Finds state corresponding to an element, creating new state if element not
349  // found.
350  StateId FindState(const Element &element) {
351  if (!(mode_ & kFactorArcWeights) && element.weight == Weight::One() &&
352  element.state != kNoStateId) {
353  while (unfactored_.size() <= element.state)
354  unfactored_.push_back(kNoStateId);
355  if (unfactored_[element.state] == kNoStateId) {
356  unfactored_[element.state] = elements_.size();
357  elements_.push_back(element);
358  }
359  return unfactored_[element.state];
360  } else {
361  const auto insert_result =
362  element_map_.emplace(element, elements_.size());
363  if (insert_result.second) {
364  elements_.push_back(element);
365  }
366  return insert_result.first->second;
367  }
368  }
369 
370  // Computes the outgoing transitions from a state, creating new destination
371  // states as needed.
372  void Expand(StateId s) {
373  const auto element = elements_[s];
374  if (element.state != kNoStateId) {
375  for (ArcIterator<Fst<Arc>> ait(*fst_, element.state); !ait.Done();
376  ait.Next()) {
377  const auto &arc = ait.Value();
378  auto weight = Times(element.weight, arc.weight);
379  FactorIterator fiter(weight);
380  if (!(mode_ & kFactorArcWeights) || fiter.Done()) {
381  const auto dest = FindState(Element(arc.nextstate, Weight::One()));
382  EmplaceArc(s, arc.ilabel, arc.olabel, std::move(weight), dest);
383  } else {
384  for (; !fiter.Done(); fiter.Next()) {
385  auto pair = fiter.Value();
386  const auto dest =
387  FindState(Element(arc.nextstate, pair.second.Quantize(delta_)));
388  EmplaceArc(s, arc.ilabel, arc.olabel, std::move(pair.first), dest);
389  }
390  }
391  }
392  }
393  if ((mode_ & kFactorFinalWeights) &&
394  ((element.state == kNoStateId) ||
395  (fst_->Final(element.state) != Weight::Zero()))) {
396  const auto weight =
397  element.state == kNoStateId
398  ? element.weight
399  : Times(element.weight, fst_->Final(element.state));
400  auto ilabel = final_ilabel_;
401  auto olabel = final_olabel_;
402  for (FactorIterator fiter(weight); !fiter.Done(); fiter.Next()) {
403  auto pair = fiter.Value();
404  const auto dest =
405  FindState(Element(kNoStateId, pair.second.Quantize(delta_)));
406  EmplaceArc(s, ilabel, olabel, std::move(pair.first), dest);
407  if (increment_final_ilabel_) ++ilabel;
408  if (increment_final_olabel_) ++olabel;
409  }
410  }
411  SetArcs(s);
412  }
413 
414  private:
415  // Equality function for Elements, assume weights have been quantized.
416  class ElementEqual {
417  public:
418  bool operator()(const Element &x, const Element &y) const {
419  return x.state == y.state && x.weight == y.weight;
420  }
421  };
422 
423  // Hash function for Elements to Fst states.
424  class ElementKey {
425  public:
426  size_t operator()(const Element &x) const {
427  static constexpr auto prime = 7853;
428  return static_cast<size_t>(x.state * prime + x.weight.Hash());
429  }
430  };
431 
432  using ElementMap =
433  std::unordered_map<Element, StateId, ElementKey, ElementEqual>;
434 
435  std::unique_ptr<const Fst<Arc>> fst_;
436  float delta_;
437  uint8_t mode_; // Factoring arc and/or final weights.
438  Label final_ilabel_; // ilabel of arc created when factoring final weights.
439  Label final_olabel_; // olabel of arc created when factoring final weights.
440  bool increment_final_ilabel_; // When factoring final weights results in
441  bool increment_final_olabel_; // multiple arcs, increment labels?
442  std::vector<Element> elements_; // Mapping from FST state to Element.
443  ElementMap element_map_; // Mapping from Element to FST state.
444  // Mapping between old/new StateId for states that do not need to be factored
445  // when mode_ is 0 or kFactorFinalWeights.
446  std::vector<StateId> unfactored_;
447 };
448 
449 } // namespace internal
450 
451 // FactorWeightFst takes as template parameter a FactorIterator as defined
452 // above. The result of weight factoring is a transducer equivalent to the
453 // input whose path weights have been factored according to the FactorIterator.
454 // States and transitions will be added as necessary. The algorithm is a
455 // generalization to arbitrary weights of the second step of the input
456 // epsilon-normalization algorithm.
457 //
458 // This class attaches interface to implementation and handles reference
459 // counting, delegating most methods to ImplToFst.
460 template <class A, class FactorIterator>
462  : public ImplToFst<internal::FactorWeightFstImpl<A, FactorIterator>> {
463  public:
464  using Arc = A;
465  using StateId = typename Arc::StateId;
466  using Weight = typename Arc::Weight;
467 
469  using State = typename Store::State;
471 
472  friend class ArcIterator<FactorWeightFst<Arc, FactorIterator>>;
473  friend class StateIterator<FactorWeightFst<Arc, FactorIterator>>;
474 
475  explicit FactorWeightFst(const Fst<Arc> &fst)
476  : ImplToFst<Impl>(
477  std::make_shared<Impl>(fst, FactorWeightOptions<Arc>())) {}
478 
480  : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
481 
482  // See Fst<>::Copy() for doc.
484  : ImplToFst<Impl>(fst, copy) {}
485 
486  // Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc.
487  FactorWeightFst *Copy(bool copy = false) const override {
488  return new FactorWeightFst(*this, copy);
489  }
490 
491  inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
492 
493  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
494  GetMutableImpl()->InitArcIterator(s, data);
495  }
496 
497  private:
500 
501  FactorWeightFst &operator=(const FactorWeightFst &) = delete;
502 };
503 
504 // Specialization for FactorWeightFst.
505 template <class Arc, class FactorIterator>
506 class StateIterator<FactorWeightFst<Arc, FactorIterator>>
507  : public CacheStateIterator<FactorWeightFst<Arc, FactorIterator>> {
508  public:
510  : CacheStateIterator<FactorWeightFst<Arc, FactorIterator>>(
511  fst, fst.GetMutableImpl()) {}
512 };
513 
514 // Specialization for FactorWeightFst.
515 template <class Arc, class FactorIterator>
516 class ArcIterator<FactorWeightFst<Arc, FactorIterator>>
517  : public CacheArcIterator<FactorWeightFst<Arc, FactorIterator>> {
518  public:
519  using StateId = typename Arc::StateId;
520 
522  : CacheArcIterator<FactorWeightFst<Arc, FactorIterator>>(
523  fst.GetMutableImpl(), s) {
524  if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
525  }
526 };
527 
528 template <class Arc, class FactorIterator>
530  StateIteratorData<Arc> *data) const {
531  data->base =
532  std::make_unique<StateIterator<FactorWeightFst<Arc, FactorIterator>>>(
533  *this);
534 }
535 
536 } // namespace fst
537 
538 #endif // FST_FACTOR_WEIGHT_H_
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:124
bool Done() const
FactorWeightFst(const Fst< Arc > &fst, const FactorWeightOptions< Arc > &opts)
FactorWeightFst(const FactorWeightFst &fst, bool copy)
std::pair< W, W > Value() const
virtual uint64_t Properties(uint64_t mask, bool test) const =0
uint64_t Properties(uint64_t mask) const override
uint64_t Properties() const override
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:107
const SymbolTable * OutputSymbols() const
Definition: fst.h:761
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:64
constexpr uint64_t kError
Definition: properties.h:52
StateId FindState(const Element &element)
#define LOG(type)
Definition: log.h:53
SetType
Definition: set-weight.h:59
typename FactorWeightFst< Arc, FactorIterator >::Arc Arc
Definition: cache.h:1156
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:113
constexpr int kNoStateId
Definition: fst.h:196
std::pair< W, W > Value() const
StringFactor(const StringWeight< Label, S > &weight)
bool Done() const
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:118
FactorWeightFstImpl(const FactorWeightFstImpl< Arc, FactorIterator > &impl)
constexpr uint64_t kCopyProperties
Definition: properties.h:163
uint64_t FactorWeightProperties(uint64_t inprops)
Definition: properties.cc:155
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
Definition: cache.h:673
std::pair< GW, GW > Value() const
std::unique_ptr< StateIteratorBase< Arc > > base
Definition: fst.h:382
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
FactorWeightOptions(float delta=kDelta, uint8_t mode=kFactorArcWeights|kFactorFinalWeights, Label final_ilabel=0, Label final_olabel=0, bool increment_final_ilabel=false, bool increment_final_olabel=false)
Definition: factor-weight.h:71
FactorWeightOptions(const CacheOptions &opts, float delta=kDelta, uint8_t mode=kFactorArcWeights|kFactorFinalWeights, Label final_ilabel=0, Label final_olabel=0, bool increment_final_ilabel=false, bool increment_final_olabel=false)
Definition: factor-weight.h:57
FactorWeightFst(const Fst< Arc > &fst)
constexpr uint64_t kFstProperties
Definition: properties.h:326
constexpr uint8_t kFactorFinalWeights
Definition: factor-weight.h:43
bool Done() const
typename FactorWeightFst< Arc, FactorIterator >::Arc Arc
Definition: cache.h:1202
virtual const SymbolTable * InputSymbols() const =0
const SymbolTable * InputSymbols() const
Definition: fst.h:759
std::pair< StringWeight< Label, S >, StringWeight< Label, S > > Value() const
FactorWeightFstImpl(const Fst< Arc > &fst, const FactorWeightOptions< Arc > &opts)
FactorWeightFst * Copy(bool copy=false) const override
OneFactor(const W &w)
typename Arc::Label Label
Definition: factor-weight.h:48
void InitStateIterator(StateIteratorData< Arc > *data) const override
constexpr uint8_t kFactorArcWeights
Definition: factor-weight.h:44
GallicFactor(const GW &weight)
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
Definition: expand.h:323
bool Done() const
typename CacheState< Arc >::Arc Arc
Definition: cache.h:859
Impl * GetMutableImpl() const
Definition: impl-to-fst.h:125
ArcIterator(const FactorWeightFst< Arc, FactorIterator > &fst, StateId s)
constexpr float kDelta
Definition: weight.h:133
IdentityFactor(const W &weight)
std::pair< GW, GW > Value() const
StateIterator(const FactorWeightFst< Arc, FactorIterator > &fst)
const Impl * GetImpl() const
Definition: impl-to-fst.h:123
virtual const SymbolTable * OutputSymbols() const =0