FST  openfst-1.8.3
OpenFst Library
reweight.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Function to reweight an FST.
19 
20 #ifndef FST_REWEIGHT_H_
21 #define FST_REWEIGHT_H_
22 
23 #include <cstdint>
24 #include <vector>
25 
26 #include <fst/log.h>
27 #include <fst/fst.h>
28 #include <fst/mutable-fst.h>
29 #include <fst/properties.h>
30 #include <fst/util.h>
31 #include <fst/weight.h>
32 
33 namespace fst {
34 
36 
37 // Reweights an FST according to a vector of potentials in a given direction.
38 // The weight must be left distributive when reweighting towards the initial
39 // state and right distributive when reweighting towards the final states.
40 //
41 // An arc of weight w, with an origin state of potential p and destination state
42 // of potential q, is reweighted by p^-1 \otimes (w \otimes q) when reweighting
43 // torwards the initial state, and by (p \otimes w) \otimes q^-1 when
44 // reweighting towards the final states.
45 template <class Arc>
47  const std::vector<typename Arc::Weight> &potential,
48  ReweightType type) {
49  using Weight = typename Arc::Weight;
50  if (fst->NumStates() == 0) return;
51  // TODO(kbg): Make this a compile-time static_assert once we have a pleasant
52  // way to "deregister" this operation for non-distributive semirings so an
53  // informative error message is produced.
54  if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring)) {
55  FSTERROR() << "Reweight: Reweighting to the final states requires "
56  << "Weight to be right distributive: " << Weight::Type();
58  return;
59  }
60  // TODO(kbg): Make this a compile-time static_assert once we have a pleasant
61  // way to "deregister" this operation for non-distributive semirings so an
62  // informative error message is produced.
63  if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring)) {
64  FSTERROR() << "Reweight: Reweighting to the initial state requires "
65  << "Weight to be left distributive: " << Weight::Type();
67  return;
68  }
69  const uint64_t input_props = fst->Properties(kFstProperties, false);
70  StateIterator<MutableFst<Arc>> siter(*fst);
71  for (; !siter.Done(); siter.Next()) {
72  const auto s = siter.Value();
73  if (s == potential.size()) break;
74  const auto &weight = potential[s];
75  if (weight != Weight::Zero()) {
76  for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
77  aiter.Next()) {
78  auto arc = aiter.Value();
79  if (arc.nextstate >= potential.size()) continue;
80  const auto &nextweight = potential[arc.nextstate];
81  if (nextweight == Weight::Zero()) continue;
82  if (type == REWEIGHT_TO_INITIAL) {
83  arc.weight =
84  Divide(Times(arc.weight, nextweight), weight, DIVIDE_LEFT);
85  }
86  if (type == REWEIGHT_TO_FINAL) {
87  arc.weight =
88  Divide(Times(weight, arc.weight), nextweight, DIVIDE_RIGHT);
89  }
90  aiter.SetValue(arc);
91  }
92  if (type == REWEIGHT_TO_INITIAL) {
93  fst->SetFinal(s, Divide(fst->Final(s), weight, DIVIDE_LEFT));
94  }
95  }
96  if (type == REWEIGHT_TO_FINAL) {
97  fst->SetFinal(s, Times(weight, fst->Final(s)));
98  }
99  }
100  // This handles elements past the end of the potentials array.
101  for (; !siter.Done(); siter.Next()) {
102  const auto s = siter.Value();
103  if (type == REWEIGHT_TO_FINAL) {
104  fst->SetFinal(s, Times(Weight::Zero(), fst->Final(s)));
105  }
106  }
107  const auto startweight = fst->Start() < potential.size()
108  ? potential[fst->Start()]
109  : Weight::Zero();
110  bool added_start_epsilon = false;
111  if ((startweight != Weight::One()) && (startweight != Weight::Zero())) {
112  if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) {
113  const auto s = fst->Start();
114  for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
115  aiter.Next()) {
116  auto arc = aiter.Value();
117  if (type == REWEIGHT_TO_INITIAL) {
118  arc.weight = Times(startweight, arc.weight);
119  } else {
120  arc.weight = Times(Divide(Weight::One(), startweight, DIVIDE_RIGHT),
121  arc.weight);
122  }
123  aiter.SetValue(arc);
124  }
125  if (type == REWEIGHT_TO_INITIAL) {
126  fst->SetFinal(s, Times(startweight, fst->Final(s)));
127  } else {
128  fst->SetFinal(s, Times(Divide(Weight::One(), startweight, DIVIDE_RIGHT),
129  fst->Final(s)));
130  }
131  } else {
132  const auto s = fst->AddState();
133  const auto weight =
134  (type == REWEIGHT_TO_INITIAL)
135  ? startweight
136  : Divide(Weight::One(), startweight, DIVIDE_RIGHT);
137  fst->AddArc(s, Arc(0, 0, weight, fst->Start()));
138  fst->SetStart(s);
139  added_start_epsilon = true;
140  }
141  }
142  fst->SetProperties(ReweightProperties(input_props, added_start_epsilon) |
143  fst->Properties(kFstProperties, false),
145 }
146 
147 } // namespace fst
148 
149 #endif // FST_REWEIGHT_H_
virtual uint64_t Properties(uint64_t mask, bool test) const =0
ReweightType
Definition: reweight.h:35
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:64
constexpr uint64_t kError
Definition: properties.h:52
constexpr uint64_t kInitialAcyclic
Definition: properties.h:116
virtual Weight Final(StateId) const =0
virtual void SetStart(StateId)=0
constexpr uint64_t kRightSemiring
Definition: weight.h:139
#define FSTERROR()
Definition: util.h:56
virtual void SetProperties(uint64_t props, uint64_t mask)=0
uint64_t ReweightProperties(uint64_t inprops, bool added_start_epsilon)
Definition: properties.cc:347
virtual StateId Start() const =0
StateId Value() const
Definition: fst.h:419
constexpr uint64_t kFstProperties
Definition: properties.h:326
virtual void AddArc(StateId, const Arc &)=0
void Next()
Definition: fst.h:421
bool Done() const
Definition: fst.h:415
ErrorWeight Divide(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:67
virtual StateId AddState()=0
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
virtual StateId NumStates() const =0
constexpr uint64_t kLeftSemiring
Definition: weight.h:136
void Reweight(MutableFst< Arc > *fst, const std::vector< typename Arc::Weight > &potential, ReweightType type)
Definition: reweight.h:46