FST  openfst-1.8.2
OpenFst Library
reweight.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Function to reweight an FST.
19 
20 #ifndef FST_REWEIGHT_H_
21 #define FST_REWEIGHT_H_
22 
23 #include <cstdint>
24 #include <vector>
25 
26 #include <fst/log.h>
27 
28 #include <fst/mutable-fst.h>
29 #include <fst/properties.h>
30 
31 
32 namespace fst {

// Direction in which Reweight() moves the potentials: towards the initial
// state (REWEIGHT_TO_INITIAL) or towards the final states (REWEIGHT_TO_FINAL).
enum ReweightType { REWEIGHT_TO_INITIAL, REWEIGHT_TO_FINAL };

36 // Reweights an FST according to a vector of potentials in a given direction.
37 // The weight must be left distributive when reweighting towards the initial
38 // state and right distributive when reweighting towards the final states.
39 //
40 // An arc of weight w, with an origin state of potential p and destination state
41 // of potential q, is reweighted by p^-1 \otimes (w \otimes q) when reweighting
42 // torwards the initial state, and by (p \otimes w) \otimes q^-1 when
43 // reweighting towards the final states.
44 template <class Arc>
46  const std::vector<typename Arc::Weight> &potential,
47  ReweightType type) {
48  using Weight = typename Arc::Weight;
49  if (fst->NumStates() == 0) return;
50  // TODO(kbg): Make this a compile-time static_assert once we have a pleasant
51  // way to "deregister" this operation for non-distributive semirings so an
52  // informative error message is produced.
53  if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring)) {
54  FSTERROR() << "Reweight: Reweighting to the final states requires "
55  << "Weight to be right distributive: " << Weight::Type();
57  return;
58  }
59  // TODO(kbg): Make this a compile-time static_assert once we have a pleasant
60  // way to "deregister" this operation for non-distributive semirings so an
61  // informative error message is produced.
62  if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring)) {
63  FSTERROR() << "Reweight: Reweighting to the initial state requires "
64  << "Weight to be left distributive: " << Weight::Type();
66  return;
67  }
68  const uint64_t input_props = fst->Properties(kFstProperties, false);
69  StateIterator<MutableFst<Arc>> siter(*fst);
70  for (; !siter.Done(); siter.Next()) {
71  const auto s = siter.Value();
72  if (s == potential.size()) break;
73  const auto &weight = potential[s];
74  if (weight != Weight::Zero()) {
75  for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
76  aiter.Next()) {
77  auto arc = aiter.Value();
78  if (arc.nextstate >= potential.size()) continue;
79  const auto &nextweight = potential[arc.nextstate];
80  if (nextweight == Weight::Zero()) continue;
81  if (type == REWEIGHT_TO_INITIAL) {
82  arc.weight =
83  Divide(Times(arc.weight, nextweight), weight, DIVIDE_LEFT);
84  }
85  if (type == REWEIGHT_TO_FINAL) {
86  arc.weight =
87  Divide(Times(weight, arc.weight), nextweight, DIVIDE_RIGHT);
88  }
89  aiter.SetValue(arc);
90  }
91  if (type == REWEIGHT_TO_INITIAL) {
92  fst->SetFinal(s, Divide(fst->Final(s), weight, DIVIDE_LEFT));
93  }
94  }
95  if (type == REWEIGHT_TO_FINAL) {
96  fst->SetFinal(s, Times(weight, fst->Final(s)));
97  }
98  }
99  // This handles elements past the end of the potentials array.
100  for (; !siter.Done(); siter.Next()) {
101  const auto s = siter.Value();
102  if (type == REWEIGHT_TO_FINAL) {
103  fst->SetFinal(s, Times(Weight::Zero(), fst->Final(s)));
104  }
105  }
106  const auto startweight = fst->Start() < potential.size()
107  ? potential[fst->Start()]
108  : Weight::Zero();
109  bool added_start_epsilon = false;
110  if ((startweight != Weight::One()) && (startweight != Weight::Zero())) {
111  if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) {
112  const auto s = fst->Start();
113  for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
114  aiter.Next()) {
115  auto arc = aiter.Value();
116  if (type == REWEIGHT_TO_INITIAL) {
117  arc.weight = Times(startweight, arc.weight);
118  } else {
119  arc.weight = Times(Divide(Weight::One(), startweight, DIVIDE_RIGHT),
120  arc.weight);
121  }
122  aiter.SetValue(arc);
123  }
124  if (type == REWEIGHT_TO_INITIAL) {
125  fst->SetFinal(s, Times(startweight, fst->Final(s)));
126  } else {
127  fst->SetFinal(s, Times(Divide(Weight::One(), startweight, DIVIDE_RIGHT),
128  fst->Final(s)));
129  }
130  } else {
131  const auto s = fst->AddState();
132  const auto weight =
133  (type == REWEIGHT_TO_INITIAL)
134  ? startweight
135  : Divide(Weight::One(), startweight, DIVIDE_RIGHT);
136  fst->AddArc(s, Arc(0, 0, weight, fst->Start()));
137  fst->SetStart(s);
138  added_start_epsilon = true;
139  }
140  }
141  fst->SetProperties(ReweightProperties(input_props, added_start_epsilon) |
142  fst->Properties(kFstProperties, false),
144 }
145 
146 } // namespace fst
147 
148 #endif // FST_REWEIGHT_H_
virtual uint64_t Properties(uint64_t mask, bool test) const =0
ReweightType
Definition: reweight.h:34
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:63
constexpr uint64_t kError
Definition: properties.h:51
constexpr uint64_t kInitialAcyclic
Definition: properties.h:115
virtual Weight Final(StateId) const =0
virtual void SetStart(StateId)=0
constexpr uint64_t kRightSemiring
Definition: weight.h:136
#define FSTERROR()
Definition: util.h:53
virtual void SetProperties(uint64_t props, uint64_t mask)=0
uint64_t ReweightProperties(uint64_t inprops, bool added_start_epsilon)
Definition: properties.cc:348
virtual StateId Start() const =0
StateId Value() const
Definition: fst.h:419
constexpr uint64_t kFstProperties
Definition: properties.h:325
virtual void AddArc(StateId, const Arc &)=0
void Next()
Definition: fst.h:421
bool Done() const
Definition: fst.h:415
ErrorWeight Divide(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:66
virtual StateId AddState()=0
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
virtual StateId NumStates() const =0
constexpr uint64_t kLeftSemiring
Definition: weight.h:133
void Reweight(MutableFst< Arc > *fst, const std::vector< typename Arc::Weight > &potential, ReweightType type)
Definition: reweight.h:45