FST  openfst-1.8.3
OpenFst Library
push.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Class to reweight/push an FST, and utility functions to weigh and reweight
19 // an FST.
20 
21 #ifndef FST_PUSH_H_
22 #define FST_PUSH_H_
23 
24 #include <cstdint>
25 #include <vector>
26 
27 #include <fst/log.h>
28 #include <fst/arc-map.h>
29 #include <fst/arc.h>
30 #include <fst/factor-weight.h>
31 #include <fst/fst.h>
32 #include <fst/mutable-fst.h>
33 #include <fst/reweight.h>
34 #include <fst/shortest-distance.h>
35 #include <fst/string-weight.h>
36 #include <fst/vector-fst.h>
37 #include <fst/weight.h>
38 
39 namespace fst {
40 
41 // Computes the total weight (sum of the weights of all accepting paths) from
42 // the output of ShortestDistance, using the shortest distance from the final
43 // state when reverse is true and from the initial state otherwise.
44 template <class Arc>
45 typename Arc::Weight ComputeTotalWeight(
46  const Fst<Arc> &fst, const std::vector<typename Arc::Weight> &distance,
47  bool reverse) {
48  if (reverse) {
49  return fst.Start() < distance.size() ? distance[fst.Start()]
50  : Arc::Weight::Zero();
51  }
52  auto sum = Arc::Weight::Zero();
53  for (typename Arc::StateId s = 0; s < distance.size(); ++s) {
54  sum = Plus(sum, Times(distance[s], fst.Final(s)));
55  }
56  return sum;
57 }
58 
59 // Divides the weight of every accepting path by a fixed weight. This weight
60 // is also divided at the final state if at_final is true and at the initial
61 // state otherwise.
62 template <class Arc>
63 void RemoveWeight(MutableFst<Arc> *fst, const typename Arc::Weight &weight,
64  bool at_final) {
65  using Weight = typename Arc::Weight;
66  if ((weight == Weight::One()) || (weight == Weight::Zero())) return;
67  if (at_final) {
68  for (StateIterator<MutableFst<Arc>> siter(*fst); !siter.Done();
69  siter.Next()) {
70  fst->SetFinal(siter.Value(),
71  Divide(fst->Final(siter.Value()), weight, DIVIDE_RIGHT));
72  }
73  } else {
74  const auto start = fst->Start();
75  for (MutableArcIterator<MutableFst<Arc>> aiter(fst, start); !aiter.Done();
76  aiter.Next()) {
77  auto arc = aiter.Value();
78  arc.weight = Divide(arc.weight, weight, DIVIDE_LEFT);
79  aiter.SetValue(arc);
80  }
81  fst->SetFinal(start, Divide(fst->Final(start), weight, DIVIDE_LEFT));
82  }
83 }
84 
85 // Pushes the weights in FST in the requested direction. If pushing towards the
86 // initial state, the sum of the weight of the outgoing transitions and final
87 // weight at a non-initial state is equal to One() in the resulting machine. If
88 // pushing towards the final state, the same property holds on the reverse
89 // machine.
90 //
91 // Weight needs to be left distributive when pushing towards the initial state
92 // and right distributive when pushing towards the final states.
93 template <class Arc>
95  float delta = kShortestDelta, bool remove_total_weight = false) {
96  using Weight = typename Arc::Weight;
97  std::vector<Weight> distance;
98  const bool reverse = type == REWEIGHT_TO_INITIAL;
99  ShortestDistance(*fst, &distance, reverse, delta);
100  if (remove_total_weight) {
101  const auto total_weight = ComputeTotalWeight(*fst, distance, reverse);
102  Reweight(fst, distance, type);
103  RemoveWeight(fst, total_weight, !reverse);
104  } else {
105  Reweight(fst, distance, type);
106  }
107 }
108 
// Bit flags that are OR-ed together to form the ptype argument of the
// two-FST Push() function.
// Push weights.
inline constexpr uint8_t kPushWeights = 1 << 0;
// Push labels.
inline constexpr uint8_t kPushLabels = 1 << 1;
// Also divide out the total weight of the FST after pushing.
inline constexpr uint8_t kPushRemoveTotalWeight = 1 << 2;
// When pushing labels, also divide out the label component of the total
// weight.
inline constexpr uint8_t kPushRemoveCommonAffix = 1 << 3;
113 
114 // Pushes the weights and/or labels of the input FST into the output mutable FST
115 // by pushing weights and/or labels (as determined by the ptype argument)
116 // towards the initial state or final states (as determined by the rtype
117 // template parameter). The weight type must be left distributive when pushing
118 // weights towards the initial state, and right distribution when pushing
119 // weights towards the final states.
120 template <class Arc, ReweightType rtype>
121 void Push(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, uint8_t ptype,
122  float delta = kShortestDelta) {
123  using Label = typename Arc::Label;
124  using Weight = typename Arc::Weight;
125  if ((ptype & (kPushWeights | kPushLabels)) == kPushWeights) {
126  *ofst = ifst;
127  Push(ofst, rtype, delta, ptype & kPushRemoveTotalWeight);
128  } else if (ptype & kPushLabels) {
129  const auto gtype =
132  std::vector<GallicWeight> gdistance;
134  ArcMap(ifst, &gfst, ToGallicMapper<Arc, gtype>());
135  if (ptype & kPushWeights) {
136  ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
137  } else {
138  ArcMapFst uwfst(ifst, RmWeightMapper<Arc>());
139  ArcMapFst guwfst(uwfst, ToGallicMapper<Arc, gtype>());
140  ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
141  }
142  auto total_weight = GallicWeight::One();
143  if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) {
144  total_weight =
145  ComputeTotalWeight(gfst, gdistance, rtype == REWEIGHT_TO_INITIAL);
146  total_weight = GallicWeight(
147  ptype & kPushRemoveCommonAffix
148  ? total_weight.Value1()
150  ptype & kPushRemoveTotalWeight ? total_weight.Value2()
151  : Weight::One());
152  }
153  Reweight(&gfst, gdistance, rtype);
154  if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) {
155  RemoveWeight(&gfst, total_weight, rtype == REWEIGHT_TO_FINAL);
156  }
158  fwfst(gfst);
159  ArcMap(fwfst, ofst, FromGallicMapper<Arc, gtype>());
160  ofst->SetOutputSymbols(ifst.OutputSymbols());
161  } else {
162  LOG(WARNING) << "Push: pushing type is set to 0, so not pushing";
163  *ofst = ifst;
164  }
165 }
166 
167 } // namespace fst
168 
169 #endif // FST_PUSH_H_
void ArcMap(MutableFst< A > *fst, C *mapper)
Definition: arc-map.h:120
constexpr uint8_t kPushRemoveCommonAffix
Definition: push.h:112
constexpr uint8_t kPushLabels
Definition: push.h:110
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:61
void RemoveWeight(MutableFst< Arc > *fst, const typename Arc::Weight &weight, bool at_final)
Definition: push.h:63
ReweightType
Definition: reweight.h:35
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:64
#define LOG(type)
Definition: log.h:53
virtual Weight Final(StateId) const =0
Arc::Weight ComputeTotalWeight(const Fst< Arc > &fst, const std::vector< typename Arc::Weight > &distance, bool reverse)
Definition: push.h:45
constexpr uint8_t kPushWeights
Definition: push.h:109
virtual StateId Start() const =0
static const GallicWeight & One()
constexpr float kShortestDelta
void Push(MutableFst< Arc > *fst, ReweightType type=REWEIGHT_TO_INITIAL, float delta=kShortestDelta, bool remove_total_weight=false)
Definition: push.h:94
ErrorWeight Divide(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:67
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
constexpr uint8_t kPushRemoveTotalWeight
Definition: push.h:111
void Reweight(MutableFst< Arc > *fst, const std::vector< typename Arc::Weight > &potential, ReweightType type)
Definition: reweight.h:46
virtual const SymbolTable * OutputSymbols() const =0