FST  openfst-1.8.2
OpenFst Library
push.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Class to reweight/push an FST, and utility functions to weigh and reweight
19 // an FST.
20 
21 #ifndef FST_PUSH_H_
22 #define FST_PUSH_H_
23 
24 #include <cstdint>
25 #include <vector>
26 
27 #include <fst/log.h>
28 
29 #include <fst/arc-map.h>
30 #include <fst/factor-weight.h>
31 #include <fst/fst.h>
32 #include <fst/reweight.h>
33 #include <fst/shortest-distance.h>
34 
35 
36 namespace fst {
37 
38 // Computes the total weight (sum of the weights of all accepting paths) from
39 // the output of ShortestDistance, using the shortest distance from the final
40 // state when reverse is true and from the initial state otherwise.
41 template <class Arc>
42 typename Arc::Weight ComputeTotalWeight(
43  const Fst<Arc> &fst, const std::vector<typename Arc::Weight> &distance,
44  bool reverse) {
45  if (reverse) {
46  return fst.Start() < distance.size() ? distance[fst.Start()]
47  : Arc::Weight::Zero();
48  }
49  auto sum = Arc::Weight::Zero();
50  for (typename Arc::StateId s = 0; s < distance.size(); ++s) {
51  sum = Plus(sum, Times(distance[s], fst.Final(s)));
52  }
53  return sum;
54 }
55 
56 // Divides the weight of every accepting path by a fixed weight. This weight
57 // is also divided at the final state if at_final is true and at the initial
58 // state otherwise.
59 template <class Arc>
60 void RemoveWeight(MutableFst<Arc> *fst, const typename Arc::Weight &weight,
61  bool at_final) {
62  using Weight = typename Arc::Weight;
63  if ((weight == Weight::One()) || (weight == Weight::Zero())) return;
64  if (at_final) {
65  for (StateIterator<MutableFst<Arc>> siter(*fst); !siter.Done();
66  siter.Next()) {
67  fst->SetFinal(siter.Value(),
68  Divide(fst->Final(siter.Value()), weight, DIVIDE_RIGHT));
69  }
70  } else {
71  const auto start = fst->Start();
72  for (MutableArcIterator<MutableFst<Arc>> aiter(fst, start); !aiter.Done();
73  aiter.Next()) {
74  auto arc = aiter.Value();
75  arc.weight = Divide(arc.weight, weight, DIVIDE_LEFT);
76  aiter.SetValue(arc);
77  }
78  fst->SetFinal(start, Divide(fst->Final(start), weight, DIVIDE_LEFT));
79  }
80 }
81 
82 // Pushes the weights in FST in the requested direction. If pushing towards the
83 // initial state, the sum of the weight of the outgoing transitions and final
84 // weight at a non-initial state is equal to One() in the resulting machine. If
85 // pushing towards the final state, the same property holds on the reverse
86 // machine.
87 //
88 // Weight needs to be left distributive when pushing towards the initial state
89 // and right distributive when pushing towards the final states.
90 template <class Arc>
92  float delta = kShortestDelta, bool remove_total_weight = false) {
93  using Weight = typename Arc::Weight;
94  std::vector<Weight> distance;
95  const bool reverse = type == REWEIGHT_TO_INITIAL;
96  ShortestDistance(*fst, &distance, reverse, delta);
97  if (remove_total_weight) {
98  const auto total_weight = ComputeTotalWeight(*fst, distance, reverse);
99  Reweight(fst, distance, type);
100  RemoveWeight(fst, total_weight, !reverse);
101  } else {
102  Reweight(fst, distance, type);
103  }
104 }
105 
// Bit flags for the ptype argument of Push(). Combine them with bitwise OR.

// Push weights.
inline constexpr uint8_t kPushWeights = 1u << 0;
// Push labels.
inline constexpr uint8_t kPushLabels = 1u << 1;
// Divide the total weight (its weight component when labels are pushed) out
// of the result.
inline constexpr uint8_t kPushRemoveTotalWeight = 1u << 2;
// When labels are pushed, also divide out the label component of the total
// (gallic) weight, i.e. the labels common to all paths.
inline constexpr uint8_t kPushRemoveCommonAffix = 1u << 3;
110 
111 // Pushes the weights and/or labels of the input FST into the output mutable FST
112 // by pushing weights and/or labels (as determined by the ptype argument)
113 // towards the initial state or final states (as determined by the rtype
114 // template parameter). The weight type must be left distributive when pushing
115 // weights towards the initial state, and right distributive when pushing
116 // weights towards the final states.
//
// ptype is a bitwise OR of the kPush* flags:
//  - kPushRemoveTotalWeight divides the total weight out of the result.
//  - With kPushLabels, kPushRemoveCommonAffix does the same for the label
//    component of the total (gallic) weight.
// delta is forwarded to ShortestDistance. If neither kPushWeights nor
// kPushLabels is set, a warning is logged and ifst is copied to ofst unchanged.
//
// NOTE(review): this listing skips original lines 127-128, 130, 146 and 154
// (see the gaps in the embedded numbering). Presumably they hold the gtype
// initializer, the GallicWeight alias, the gfst declaration, the Value1
// fallback in the ternary, and fwfst's type. Confirm against the real header.
117 template <class Arc, ReweightType rtype>
118 void Push(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, uint8_t ptype,
119  float delta = kShortestDelta) {
120  using Label = typename Arc::Label;
121  using Weight = typename Arc::Weight;
// Weights only: copy ifst and push weights in place.
122  if ((ptype & (kPushWeights | kPushLabels)) == kPushWeights) {
123  *ofst = ifst;
124  Push(ofst, rtype, delta, ptype & kPushRemoveTotalWeight);
125  } else if (ptype & kPushLabels) {
// Labels (possibly with weights): map to gallic arcs so that labels can be
// pushed like weights. Presumably each gallic weight pairs a label string
// (Value1) with the original Weight (Value2).
126  const auto gtype =
129  std::vector<GallicWeight> gdistance;
131  ArcMap(ifst, &gfst, ToGallicMapper<Arc, gtype>());
// With kPushWeights, distances are computed on the full gallic FST. Otherwise
// arc weights are first removed (RmWeightMapper), so only labels are pushed.
132  if (ptype & kPushWeights) {
133  ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
134  } else {
135  ArcMapFst uwfst(ifst, RmWeightMapper<Arc>());
136  ArcMapFst guwfst(uwfst, ToGallicMapper<Arc, gtype>());
137  ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
138  }
// Build the total weight to divide out. Each component is kept only if its
// flag is set. Value2 otherwise becomes Weight::One(); the Value1 fallback is
// on a line missing from this listing.
139  auto total_weight = GallicWeight::One();
140  if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) {
141  total_weight =
142  ComputeTotalWeight(gfst, gdistance, rtype == REWEIGHT_TO_INITIAL);
143  total_weight = GallicWeight(
144  ptype & kPushRemoveCommonAffix
145  ? total_weight.Value1()
147  ptype & kPushRemoveTotalWeight ? total_weight.Value2()
148  : Weight::One());
149  }
// Reweight after ComputeTotalWeight, which reads gfst's pre-push weights.
150  Reweight(&gfst, gdistance, rtype);
151  if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) {
// Divide at the final states when pushing towards them, else at the initial
// state.
152  RemoveWeight(&gfst, total_weight, rtype == REWEIGHT_TO_FINAL);
153  }
// fwfst is built from gfst. Its type is on a line missing from this listing;
// presumably it is a FactorWeightFst from the included factor-weight.h. The
// result is mapped back to Arc, and ofst's output symbols are set from ifst.
155  fwfst(gfst);
156  ArcMap(fwfst, ofst, FromGallicMapper<Arc, gtype>());
157  ofst->SetOutputSymbols(ifst.OutputSymbols());
158  } else {
// Neither kPushWeights nor kPushLabels is set.
159  LOG(WARNING) << "Push: pushing type is set to 0, so not pushing";
160  *ofst = ifst;
161  }
162 }
163 
164 } // namespace fst
165 
166 #endif // FST_PUSH_H_
void ArcMap(MutableFst< A > *fst, C *mapper)
Definition: arc-map.h:110
constexpr uint8_t kPushRemoveCommonAffix
Definition: push.h:109
constexpr uint8_t kPushLabels
Definition: push.h:107
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:60
void RemoveWeight(MutableFst< Arc > *fst, const typename Arc::Weight &weight, bool at_final)
Definition: push.h:60
ReweightType
Definition: reweight.h:34
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:63
#define LOG(type)
Definition: log.h:49
virtual Weight Final(StateId) const =0
Arc::Weight ComputeTotalWeight(const Fst< Arc > &fst, const std::vector< typename Arc::Weight > &distance, bool reverse)
Definition: push.h:42
constexpr uint8_t kPushWeights
Definition: push.h:106
virtual StateId Start() const =0
static const GallicWeight & One()
constexpr float kShortestDelta
void Push(MutableFst< Arc > *fst, ReweightType type=REWEIGHT_TO_INITIAL, float delta=kShortestDelta, bool remove_total_weight=false)
Definition: push.h:91
ErrorWeight Divide(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:66
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
constexpr uint8_t kPushRemoveTotalWeight
Definition: push.h:108
void Reweight(MutableFst< Arc > *fst, const std::vector< typename Arc::Weight > &potential, ReweightType type)
Definition: reweight.h:45
virtual const SymbolTable * OutputSymbols() const =0