FST  openfst-1.7.2
OpenFst Library
prune.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions implementing pruning.
5 
6 #ifndef FST_PRUNE_H_
7 #define FST_PRUNE_H_
8 
9 #include <type_traits>
10 #include <utility>
11 #include <vector>
12 
13 #include <fst/log.h>
14 
15 #include <fst/arcfilter.h>
16 #include <fst/heap.h>
17 #include <fst/shortest-distance.h>
18 
19 
20 namespace fst {
21 namespace internal {
22 
23 template <class StateId, class Weight>
24 class PruneCompare {
25  public:
26  PruneCompare(const std::vector<Weight> &idistance,
27  const std::vector<Weight> &fdistance)
28  : idistance_(idistance), fdistance_(fdistance) {}
29 
30  bool operator()(const StateId x, const StateId y) const {
31  const auto wx = Times(IDistance(x), FDistance(x));
32  const auto wy = Times(IDistance(y), FDistance(y));
33  return less_(wx, wy);
34  }
35 
36  private:
37  Weight IDistance(const StateId s) const {
38  return s < idistance_.size() ? idistance_[s] : Weight::Zero();
39  }
40 
41  Weight FDistance(const StateId s) const {
42  return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
43  }
44 
45  const std::vector<Weight> &idistance_;
46  const std::vector<Weight> &fdistance_;
47  NaturalLess<Weight> less_;
48 };
49 
50 } // namespace internal
51 
52 template <class Arc, class ArcFilter>
53 struct PruneOptions {
54  using StateId = typename Arc::StateId;
55  using Weight = typename Arc::Weight;
56 
57  explicit PruneOptions(const Weight &weight_threshold = Weight::Zero(),
58  StateId state_threshold = kNoStateId,
59  ArcFilter filter = ArcFilter(),
60  std::vector<Weight> *distance = nullptr,
61  float delta = kDelta, bool threshold_initial = false)
62  : weight_threshold(std::move(weight_threshold)),
63  state_threshold(state_threshold),
64  filter(std::move(filter)),
65  distance(distance),
66  delta(delta),
67  threshold_initial(threshold_initial) {}
68 
69  // Pruning weight threshold.
71  // Pruning state threshold.
73  // Arc filter.
74  ArcFilter filter;
75  // If non-zero, passes in pre-computed shortest distance to final states.
76  const std::vector<Weight> *distance;
77  // Determines the degree of convergence required when computing shortest
78  // distances.
79  float delta;
80  // Determines if the shortest path weight is left (true) or right
81  // (false) multiplied by the threshold to get the limit for
82  // keeping a state or arc (matters if the semiring is not
83  // commutative).
85 };
86 
87 // Pruning algorithm: this version modifies its input and it takes an options
88 // class as an argument. After pruning the FST contains states and arcs that
89 // belong to a successful path in the FST whose weight is no more than the
90 // weight of the shortest path Times() the provided weight threshold. When the
91 // state threshold is not kNoStateId, the output FST is further restricted to
92 // have no more than the number of states in opts.state_threshold. Weights must
93 // have the path property. The weight of any cycle needs to be bounded; i.e.,
94 //
95 // Plus(weight, Weight::One()) == Weight::One()
96 template <class Arc, class ArcFilter,
97  typename std::enable_if<(Arc::Weight::Properties() & kPath) ==
98  kPath>::type * = nullptr>
101  using StateId = typename Arc::StateId;
102  using Weight = typename Arc::Weight;
104  auto ns = fst->NumStates();
105  if (ns < 1) return;
106  std::vector<Weight> idistance(ns, Weight::Zero());
107  std::vector<Weight> tmp;
108  if (!opts.distance) {
109  tmp.reserve(ns);
110  ShortestDistance(*fst, &tmp, true, opts.delta);
111  }
112  const auto *fdistance = opts.distance ? opts.distance : &tmp;
113  if ((opts.state_threshold == 0) || (fdistance->size() <= fst->Start()) ||
114  ((*fdistance)[fst->Start()] == Weight::Zero())) {
115  fst->DeleteStates();
116  return;
117  }
118  internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
119  StateHeap heap(compare);
120  std::vector<bool> visited(ns, false);
121  std::vector<size_t> enqueued(ns, StateHeap::kNoKey);
122  std::vector<StateId> dead;
123  dead.push_back(fst->AddState());
124  NaturalLess<Weight> less;
125  auto s = fst->Start();
126  const auto limit = opts.threshold_initial ?
127  Times(opts.weight_threshold, (*fdistance)[s]) :
128  Times((*fdistance)[s], opts.weight_threshold);
129  StateId num_visited = 0;
130 
131  if (!less(limit, (*fdistance)[s])) {
132  idistance[s] = Weight::One();
133  enqueued[s] = heap.Insert(s);
134  ++num_visited;
135  }
136  while (!heap.Empty()) {
137  s = heap.Top();
138  heap.Pop();
139  enqueued[s] = StateHeap::kNoKey;
140  visited[s] = true;
141  if (less(limit, Times(idistance[s], fst->Final(s)))) {
142  fst->SetFinal(s, Weight::Zero());
143  }
144  for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
145  aiter.Next()) {
146  auto arc = aiter.Value(); // Copy intended.
147  if (!opts.filter(arc)) continue;
148  const auto weight = Times(Times(idistance[s], arc.weight),
149  arc.nextstate < fdistance->size() ?
150  (*fdistance)[arc.nextstate] : Weight::Zero());
151  if (less(limit, weight)) {
152  arc.nextstate = dead[0];
153  aiter.SetValue(arc);
154  continue;
155  }
156  if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
157  idistance[arc.nextstate] = Times(idistance[s], arc.weight);
158  }
159  if (visited[arc.nextstate]) continue;
160  if ((opts.state_threshold != kNoStateId) &&
161  (num_visited >= opts.state_threshold)) {
162  continue;
163  }
164  if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
165  enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
166  ++num_visited;
167  } else {
168  heap.Update(enqueued[arc.nextstate], arc.nextstate);
169  }
170  }
171  }
172  for (StateId i = 0; i < visited.size(); ++i) {
173  if (!visited[i]) dead.push_back(i);
174  }
175  fst->DeleteStates(dead);
176 }
177 
178 template <class Arc, class ArcFilter,
179  typename std::enable_if<(Arc::Weight::Properties() & kPath) !=
180  kPath>::type * = nullptr>
181 void Prune(MutableFst<Arc> *fst, const PruneOptions<Arc, ArcFilter> &opts =
183  FSTERROR() << "Prune: Weight needs to have the path property: "
184  << Arc::Weight::Type();
185  fst->SetProperties(kError, kError);
186 }
187 
188 // Pruning algorithm: this version modifies its input and takes the
189 // pruning threshold as an argument. It deletes states and arcs in the
190 // FST that do not belong to a successful path whose weight is more
191 // than the weight of the shortest path Times() the provided weight
192 // threshold. When the state threshold is not kNoStateId, the output
193 // FST is further restricted to have no more than the number of states
194 // in opts.state_threshold. Weights must have the path property. The
195 // weight of any cycle needs to be bounded; i.e.,
196 //
197 // Plus(weight, Weight::One()) == Weight::One()
198 template <class Arc>
199 void Prune(MutableFst<Arc> *fst, typename Arc::Weight weight_threshold,
200  typename Arc::StateId state_threshold = kNoStateId,
201  float delta = kDelta) {
203  weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
204  Prune(fst, opts);
205 }
206 
207 // Pruning algorithm: this version writes the pruned input FST to an
208 // output MutableFst and it takes an options class as an argument. The
209 // output FST contains states and arcs that belong to a successful
210 // path in the input FST whose weight is more than the weight of the
211 // shortest path Times() the provided weight threshold. When the state
212 // threshold is not kNoStateId, the output FST is further restricted
213 // to have no more than the number of states in
214 // opts.state_threshold. Weights have the path property. The weight
215 // of any cycle needs to be bounded; i.e.,
216 //
217 // Plus(weight, Weight::One()) == Weight::One()
218 template <class Arc, class ArcFilter,
219  typename std::enable_if<IsPath<typename Arc::Weight>::value>::type * =
220  nullptr>
221 void Prune(
222  const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
224  using StateId = typename Arc::StateId;
225  using Weight = typename Arc::Weight;
227  ofst->DeleteStates();
228  ofst->SetInputSymbols(ifst.InputSymbols());
229  ofst->SetOutputSymbols(ifst.OutputSymbols());
230  if (ifst.Start() == kNoStateId) return;
231  NaturalLess<Weight> less;
232  if (less(opts.weight_threshold, Weight::One()) ||
233  (opts.state_threshold == 0)) {
234  return;
235  }
236  std::vector<Weight> idistance;
237  std::vector<Weight> tmp;
238  if (!opts.distance) ShortestDistance(ifst, &tmp, true, opts.delta);
239  const auto *fdistance = opts.distance ? opts.distance : &tmp;
240  if ((fdistance->size() <= ifst.Start()) ||
241  ((*fdistance)[ifst.Start()] == Weight::Zero())) {
242  return;
243  }
244  internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
245  StateHeap heap(compare);
246  std::vector<StateId> copy;
247  std::vector<size_t> enqueued;
248  std::vector<bool> visited;
249  auto s = ifst.Start();
250  const auto limit = opts.threshold_initial ?
251  Times(opts.weight_threshold, (*fdistance)[s]) :
252  Times((*fdistance)[s], opts.weight_threshold);
253  while (copy.size() <= s) copy.push_back(kNoStateId);
254  copy[s] = ofst->AddState();
255  ofst->SetStart(copy[s]);
256  while (idistance.size() <= s) idistance.push_back(Weight::Zero());
257  idistance[s] = Weight::One();
258  while (enqueued.size() <= s) {
259  enqueued.push_back(StateHeap::kNoKey);
260  visited.push_back(false);
261  }
262  enqueued[s] = heap.Insert(s);
263  while (!heap.Empty()) {
264  s = heap.Top();
265  heap.Pop();
266  enqueued[s] = StateHeap::kNoKey;
267  visited[s] = true;
268  if (!less(limit, Times(idistance[s], ifst.Final(s)))) {
269  ofst->SetFinal(copy[s], ifst.Final(s));
270  }
271  for (ArcIterator<Fst<Arc>> aiter(ifst, s); !aiter.Done(); aiter.Next()) {
272  const auto &arc = aiter.Value();
273  if (!opts.filter(arc)) continue;
274  const auto weight = Times(Times(idistance[s], arc.weight),
275  arc.nextstate < fdistance->size() ?
276  (*fdistance)[arc.nextstate] : Weight::Zero());
277  if (less(limit, weight)) continue;
278  if ((opts.state_threshold != kNoStateId) &&
279  (ofst->NumStates() >= opts.state_threshold)) {
280  continue;
281  }
282  while (idistance.size() <= arc.nextstate) {
283  idistance.push_back(Weight::Zero());
284  }
285  if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
286  idistance[arc.nextstate] = Times(idistance[s], arc.weight);
287  }
288  while (copy.size() <= arc.nextstate) copy.push_back(kNoStateId);
289  if (copy[arc.nextstate] == kNoStateId) {
290  copy[arc.nextstate] = ofst->AddState();
291  }
292  ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight,
293  copy[arc.nextstate]));
294  while (enqueued.size() <= arc.nextstate) {
295  enqueued.push_back(StateHeap::kNoKey);
296  visited.push_back(false);
297  }
298  if (visited[arc.nextstate]) continue;
299  if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
300  enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
301  } else {
302  heap.Update(enqueued[arc.nextstate], arc.nextstate);
303  }
304  }
305  }
306 }
307 
308 template <class Arc, class ArcFilter,
309  typename std::enable_if<!IsPath<typename Arc::Weight>::value>::type
310  * = nullptr>
311 void Prune(const Fst<Arc> &, MutableFst<Arc> *ofst,
313  FSTERROR() << "Prune: Weight needs to have the path property: "
314  << Arc::Weight::Type();
315  ofst->SetProperties(kError, kError);
316 }
317 
318 // Pruning algorithm: this version writes the pruned input FST to an
319 // output MutableFst and simply takes the pruning threshold as an
320 // argument. The output FST contains states and arcs that belong to a
321 // successful path in the input FST whose weight is no more than the
322 // weight of the shortest path Times() the provided weight
323 // threshold. When the state threshold is not kNoStateId, the output
324 // FST is further restricted to have no more than the number of states
325 // in opts.state_threshold. Weights must have the path property. The
326 // weight of any cycle needs to be bounded; i.e.,
327 //
328 // Plus(weight, Weight::One()) = Weight::One();
329 template <class Arc>
330 void Prune(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
331  typename Arc::Weight weight_threshold,
332  typename Arc::StateId state_threshold = kNoStateId,
333  float delta = kDelta) {
335  weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
336  Prune(ifst, ofst, opts);
337 }
338 
339 } // namespace fst
340 
341 #endif // FST_PRUNE_H_
virtual void AddArc(StateId, const Arc &arc)=0
bool operator()(const StateId x, const StateId y) const
Definition: prune.h:30
typename Arc::StateId StateId
Definition: prune.h:54
virtual void SetInputSymbols(const SymbolTable *isyms)=0
virtual Weight Final(StateId) const =0
ExpectationWeight< X1, X2 > Times(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
virtual void SetStart(StateId)=0
Weight weight_threshold
Definition: prune.h:70
typename Arc::Weight Weight
Definition: prune.h:55
constexpr int kNoStateId
Definition: fst.h:180
#define FSTERROR()
Definition: util.h:35
virtual void SetFinal(StateId, Weight)=0
virtual StateId Start() const =0
float delta
Definition: prune.h:79
void Prune(MutableFst< Arc > *fst, const PruneOptions< Arc, ArcFilter > &opts=PruneOptions< Arc, ArcFilter >())
Definition: prune.h:99
virtual const SymbolTable * InputSymbols() const =0
StateId state_threshold
Definition: prune.h:72
ArcFilter filter
Definition: prune.h:74
const std::vector< Weight > * distance
Definition: prune.h:76
constexpr uint64 kPath
Definition: weight.h:126
constexpr uint64 kError
Definition: properties.h:33
virtual StateId AddState()=0
virtual void DeleteStates(const std::vector< StateId > &)=0
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
virtual StateId NumStates() const =0
bool threshold_initial
Definition: prune.h:84
PruneCompare(const std::vector< Weight > &idistance, const std::vector< Weight > &fdistance)
Definition: prune.h:26
constexpr float kDelta
Definition: weight.h:109
PruneOptions(const Weight &weight_threshold=Weight::Zero(), StateId state_threshold=kNoStateId, ArcFilter filter=ArcFilter(), std::vector< Weight > *distance=nullptr, float delta=kDelta, bool threshold_initial=false)
Definition: prune.h:57
virtual void SetProperties(uint64 props, uint64 mask)=0
virtual const SymbolTable * OutputSymbols() const =0