FST  openfst-1.7.1
OpenFst Library
reweight.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Function to reweight an FST.
5 
6 #ifndef FST_REWEIGHT_H_
7 #define FST_REWEIGHT_H_
8 
9 #include <vector>
10 #include <fst/log.h>
11 
12 #include <fst/mutable-fst.h>
13 
14 
15 namespace fst {
16 
18 
19 // Reweights an FST according to a vector of potentials in a given direction.
20 // The weight must be left distributive when reweighting towards the initial
21 // state and right distributive when reweighting towards the final states.
22 //
23 // An arc of weight w, with an origin state of potential p and destination state
24 // of potential q, is reweighted by p^-1 \otimes (w \otimes q) when reweighting
25 // torwards the initial state, and by (p \otimes w) \otimes q^-1 when
26 // reweighting towards the final states.
27 template <class Arc>
29  const std::vector<typename Arc::Weight> &potential,
30  ReweightType type) {
31  using Weight = typename Arc::Weight;
32  if (fst->NumStates() == 0) return;
33  // TODO(kbg): Make this a compile-time static_assert once we have a pleasant
34  // way to "deregister" this operation for non-distributive semirings so an
35  // informative error message is produced.
36  if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring)) {
37  FSTERROR() << "Reweight: Reweighting to the final states requires "
38  << "Weight to be right distributive: " << Weight::Type();
40  return;
41  }
42  // TODO(kbg): Make this a compile-time static_assert once we have a pleasant
43  // way to "deregister" this operation for non-distributive semirings so an
44  // informative error message is produced.
45  if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring)) {
46  FSTERROR() << "Reweight: Reweighting to the initial state requires "
47  << "Weight to be left distributive: " << Weight::Type();
49  return;
50  }
51  StateIterator<MutableFst<Arc>> siter(*fst);
52  for (; !siter.Done(); siter.Next()) {
53  const auto s = siter.Value();
54  if (s == potential.size()) break;
55  const auto &weight = potential[s];
56  if (weight != Weight::Zero()) {
57  for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
58  aiter.Next()) {
59  auto arc = aiter.Value();
60  if (arc.nextstate >= potential.size()) continue;
61  const auto &nextweight = potential[arc.nextstate];
62  if (nextweight == Weight::Zero()) continue;
63  if (type == REWEIGHT_TO_INITIAL) {
64  arc.weight =
65  Divide(Times(arc.weight, nextweight), weight, DIVIDE_LEFT);
66  }
67  if (type == REWEIGHT_TO_FINAL) {
68  arc.weight =
69  Divide(Times(weight, arc.weight), nextweight, DIVIDE_RIGHT);
70  }
71  aiter.SetValue(arc);
72  }
73  if (type == REWEIGHT_TO_INITIAL) {
74  fst->SetFinal(s, Divide(fst->Final(s), weight, DIVIDE_LEFT));
75  }
76  }
77  if (type == REWEIGHT_TO_FINAL) {
78  fst->SetFinal(s, Times(weight, fst->Final(s)));
79  }
80  }
81  // This handles elements past the end of the potentials array.
82  for (; !siter.Done(); siter.Next()) {
83  const auto s = siter.Value();
84  if (type == REWEIGHT_TO_FINAL) {
85  fst->SetFinal(s, Times(Weight::Zero(), fst->Final(s)));
86  }
87  }
88  const auto startweight = fst->Start() < potential.size()
89  ? potential[fst->Start()]
90  : Weight::Zero();
91  if ((startweight != Weight::One()) && (startweight != Weight::Zero())) {
92  if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) {
93  const auto s = fst->Start();
94  for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
95  aiter.Next()) {
96  auto arc = aiter.Value();
97  if (type == REWEIGHT_TO_INITIAL) {
98  arc.weight = Times(startweight, arc.weight);
99  } else {
100  arc.weight = Times(Divide(Weight::One(), startweight, DIVIDE_RIGHT),
101  arc.weight);
102  }
103  aiter.SetValue(arc);
104  }
105  if (type == REWEIGHT_TO_INITIAL) {
106  fst->SetFinal(s, Times(startweight, fst->Final(s)));
107  } else {
108  fst->SetFinal(s, Times(Divide(Weight::One(), startweight, DIVIDE_RIGHT),
109  fst->Final(s)));
110  }
111  } else {
112  const auto s = fst->AddState();
113  const auto weight =
114  (type == REWEIGHT_TO_INITIAL)
115  ? startweight
116  : Divide(Weight::One(), startweight, DIVIDE_RIGHT);
117  fst->AddArc(s, Arc(0, 0, weight, fst->Start()));
118  fst->SetStart(s);
119  }
120  }
123 }
124 
125 } // namespace fst
126 
127 #endif // FST_REWEIGHT_H_
ExpectationWeight< X1, X2 > Divide(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2, DivideType typ=DIVIDE_ANY)
constexpr uint64 kInitialAcyclic
Definition: properties.h:97
constexpr uint64 kRightSemiring
Definition: weight.h:115
virtual void AddArc(StateId, const Arc &arc)=0
ReweightType
Definition: reweight.h:17
virtual Weight Final(StateId) const =0
ExpectationWeight< X1, X2 > Times(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
virtual void SetStart(StateId)=0
constexpr uint64 kLeftSemiring
Definition: weight.h:112
constexpr uint64 kFstProperties
Definition: properties.h:301
virtual uint64 Properties(uint64 mask, bool test) const =0
uint64 ReweightProperties(uint64 inprops)
Definition: properties.cc:328
#define FSTERROR()
Definition: util.h:35
virtual void SetFinal(StateId, Weight)=0
virtual StateId Start() const =0
StateId Value() const
Definition: fst.h:387
void Next()
Definition: fst.h:389
bool Done() const
Definition: fst.h:383
constexpr uint64 kError
Definition: properties.h:33
virtual StateId AddState()=0
virtual StateId NumStates() const =0
void Reweight(MutableFst< Arc > *fst, const std::vector< typename Arc::Weight > &potential, ReweightType type)
Definition: reweight.h:28
virtual void SetProperties(uint64 props, uint64 mask)=0