FST  openfst-1.7.1
OpenFst Library
push.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Class to reweight/push an FST, and utility functions to weigh and reweight
5 // an FST.
6 
7 #ifndef FST_PUSH_H_
8 #define FST_PUSH_H_
9 
10 #include <vector>
11 
12 #include <fst/log.h>
13 
14 #include <fst/arc-map.h>
15 #include <fst/factor-weight.h>
16 #include <fst/fst.h>
17 #include <fst/reweight.h>
18 #include <fst/shortest-distance.h>
19 
20 
21 namespace fst {
22 
23 // Computes the total weight (sum of the weights of all accepting paths) from
24 // the output of ShortestDistance, using the shortest distance from the final
25 // state when reverse is true and from the initial state otherwise.
26 template <class Arc>
27 typename Arc::Weight ComputeTotalWeight(
28  const Fst<Arc> &fst, const std::vector<typename Arc::Weight> &distance,
29  bool reverse) {
30  if (reverse) {
31  return fst.Start() < distance.size() ? distance[fst.Start()]
32  : Arc::Weight::Zero();
33  }
34  auto sum = Arc::Weight::Zero();
35  for (typename Arc::StateId s = 0; s < distance.size(); ++s) {
36  sum = Plus(sum, Times(distance[s], fst.Final(s)));
37  }
38  return sum;
39 }
40 
41 // Divides the weight of every accepting path by a fixed weight. This weight
42 // is also divided at the final state if at_final is true and at the initial
43 // state otherwise.
44 template <class Arc>
45 void RemoveWeight(MutableFst<Arc> *fst, const typename Arc::Weight &weight,
46  bool at_final) {
47  using Weight = typename Arc::Weight;
48  if ((weight == Weight::One()) || (weight == Weight::Zero())) return;
49  if (at_final) {
50  for (StateIterator<MutableFst<Arc>> siter(*fst); !siter.Done();
51  siter.Next()) {
52  fst->SetFinal(siter.Value(),
53  Divide(fst->Final(siter.Value()), weight, DIVIDE_RIGHT));
54  }
55  } else {
56  const auto start = fst->Start();
57  for (MutableArcIterator<MutableFst<Arc>> aiter(fst, start); !aiter.Done();
58  aiter.Next()) {
59  auto arc = aiter.Value();
60  arc.weight = Divide(arc.weight, weight, DIVIDE_LEFT);
61  aiter.SetValue(arc);
62  }
63  fst->SetFinal(start, Divide(fst->Final(start), weight, DIVIDE_LEFT));
64  }
65 }
66 
67 // Pushes the weights in FST in the direction defined by TYPE. If
68 // pushing towards the initial state, the sum of the weight of the
69 // outgoing transitions and final weight at a non-initial state is
70 // equal to One() in the resulting machine. If pushing towards the
71 // final state, the same property holds on the reverse machine.
72 //
73 // Weight needs to be left distributive when pushing towards the
74 // initial state and right distributive when pushing towards the final
75 // states.
76 template <class Arc>
77 void Push(MutableFst<Arc> *fst, ReweightType type, float delta = kDelta,
78  bool remove_total_weight = false) {
79  using Weight = typename Arc::Weight;
80  std::vector<Weight> distance;
81  ShortestDistance(*fst, &distance, type == REWEIGHT_TO_INITIAL, delta);
82  auto total_weight = Weight::One();
83  if (remove_total_weight) {
84  total_weight =
85  ComputeTotalWeight(*fst, distance, type == REWEIGHT_TO_INITIAL);
86  }
87  Reweight(fst, distance, type);
88  if (remove_total_weight) {
89  RemoveWeight(fst, total_weight, type == REWEIGHT_TO_FINAL);
90  }
91 }
92 
93 constexpr uint32 kPushWeights = 0x0001;
94 constexpr uint32 kPushLabels = 0x0002;
95 constexpr uint32 kPushRemoveTotalWeight = 0x0004;
96 constexpr uint32 kPushRemoveCommonAffix = 0x0008;
97 
98 // Pushes the weights and/or labels of the input FST into the output
99 // mutable FST by pushing weights and/or labels (as determined by the
100 // ptype argument) towards the initial state or final states (as
101 // determined by the rtype template parameter). The weight type must
102 // be left distributive when pushing weights towards the initial state, and
103 // right distribution when pushing weights towards the final states.
104 template <class Arc, ReweightType rtype>
105 void Push(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, uint32 ptype,
106  float delta = kDelta) {
107  using Label = typename Arc::Label;
108  using Weight = typename Arc::Weight;
109  if ((ptype & (kPushWeights | kPushLabels)) == kPushWeights) {
110  *ofst = ifst;
111  Push(ofst, rtype, delta, ptype & kPushRemoveTotalWeight);
112  } else if (ptype & kPushLabels) {
113  const auto gtype =
116  std::vector<GallicWeight> gdistance;
118  ArcMap(ifst, &gfst, ToGallicMapper<Arc, gtype>());
119  if (ptype & kPushWeights) {
120  ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
121  } else {
125  uwfst, ToGallicMapper<Arc, gtype>());
126  ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
127  }
128  auto total_weight = GallicWeight::One();
129  if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) {
130  total_weight =
131  ComputeTotalWeight(gfst, gdistance, rtype == REWEIGHT_TO_INITIAL);
132  total_weight = GallicWeight(
133  ptype & kPushRemoveCommonAffix
134  ? total_weight.Value1()
136  ptype & kPushRemoveTotalWeight ? total_weight.Value2()
137  : Weight::One());
138  }
139  Reweight(&gfst, gdistance, rtype);
140  if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) {
141  RemoveWeight(&gfst, total_weight, rtype == REWEIGHT_TO_FINAL);
142  }
144  fwfst(gfst);
145  ArcMap(fwfst, ofst, FromGallicMapper<Arc, gtype>());
146  ofst->SetOutputSymbols(ifst.OutputSymbols());
147  } else {
148  LOG(WARNING) << "Push: pushing type is set to 0, so not pushing";
149  *ofst = ifst;
150  }
151 }
152 
153 } // namespace fst
154 
155 #endif // FST_PUSH_H_
void ArcMap(MutableFst< A > *fst, C *mapper)
Definition: arc-map.h:94
ExpectationWeight< X1, X2 > Divide(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2, DivideType typ=DIVIDE_ANY)
void RemoveWeight(MutableFst< Arc > *fst, const typename Arc::Weight &weight, bool at_final)
Definition: push.h:45
ReweightType
Definition: reweight.h:17
#define LOG(type)
Definition: log.h:48
virtual Weight Final(StateId) const =0
ExpectationWeight< X1, X2 > Times(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
constexpr uint32 kPushRemoveTotalWeight
Definition: push.h:95
Arc::Weight ComputeTotalWeight(const Fst< Arc > &fst, const std::vector< typename Arc::Weight > &distance, bool reverse)
Definition: push.h:27
virtual void SetFinal(StateId, Weight)=0
ExpectationWeight< X1, X2 > Plus(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
virtual StateId Start() const =0
static const GallicWeight & One()
uint32_t uint32
Definition: types.h:31
constexpr uint32 kPushRemoveCommonAffix
Definition: push.h:96
constexpr uint32 kPushWeights
Definition: push.h:93
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
void Push(MutableFst< Arc > *fst, ReweightType type, float delta=kDelta, bool remove_total_weight=false)
Definition: push.h:77
constexpr float kDelta
Definition: weight.h:109
void Reweight(MutableFst< Arc > *fst, const std::vector< typename Arc::Weight > &potential, ReweightType type)
Definition: reweight.h:28
constexpr uint32 kPushLabels
Definition: push.h:94
virtual const SymbolTable * OutputSymbols() const =0