FST  openfst-1.8.2
OpenFst Library
prune.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Functions implementing pruning.
19 
20 #ifndef FST_PRUNE_H_
21 #define FST_PRUNE_H_
22 
23 #include <type_traits>
24 #include <utility>
25 #include <vector>
26 
27 #include <fst/log.h>
28 
29 #include <fst/arcfilter.h>
30 #include <fst/heap.h>
31 #include <fst/shortest-distance.h>
32 
33 
34 namespace fst {
35 namespace internal {
36 
37 template <class StateId, class Weight>
38 class PruneCompare {
39  public:
40  PruneCompare(const std::vector<Weight> &idistance,
41  const std::vector<Weight> &fdistance)
42  : idistance_(idistance), fdistance_(fdistance) {}
43 
44  bool operator()(const StateId x, const StateId y) const {
45  const auto wx = Times(IDistance(x), FDistance(x));
46  const auto wy = Times(IDistance(y), FDistance(y));
47  return less_(wx, wy);
48  }
49 
50  private:
51  Weight IDistance(const StateId s) const {
52  return s < idistance_.size() ? idistance_[s] : Weight::Zero();
53  }
54 
55  Weight FDistance(const StateId s) const {
56  return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
57  }
58 
59  const std::vector<Weight> &idistance_;
60  const std::vector<Weight> &fdistance_;
61  NaturalLess<Weight> less_;
62 };
63 
64 } // namespace internal
65 
66 template <class Arc, class ArcFilter>
67 struct PruneOptions {
68  using StateId = typename Arc::StateId;
69  using Weight = typename Arc::Weight;
70 
71  explicit PruneOptions(const Weight &weight_threshold = Weight::Zero(),
72  StateId state_threshold = kNoStateId,
73  ArcFilter filter = ArcFilter(),
74  std::vector<Weight> *distance = nullptr,
75  float delta = kDelta, bool threshold_initial = false)
76  : weight_threshold(std::move(weight_threshold)),
77  state_threshold(state_threshold),
78  filter(std::move(filter)),
79  distance(distance),
80  delta(delta),
81  threshold_initial(threshold_initial) {}
82 
83  // Pruning weight threshold.
85  // Pruning state threshold.
87  // Arc filter.
88  ArcFilter filter;
89  // If non-zero, passes in pre-computed shortest distance to final states.
90  const std::vector<Weight> *distance;
91  // Determines the degree of convergence required when computing shortest
92  // distances.
93  float delta;
94  // Determines if the shortest path weight is left (true) or right
95  // (false) multiplied by the threshold to get the limit for
96  // keeping a state or arc (matters if the semiring is not
97  // commutative).
99 };
100 
101 // Pruning algorithm: this version modifies its input and it takes an options
102 // class as an argument. After pruning the FST contains states and arcs that
103 // belong to a successful path in the FST whose weight is no more than the
104 // weight of the shortest path Times() the provided weight threshold. When the
105 // state threshold is not kNoStateId, the output FST is further restricted to
106 // have no more than the number of states in opts.state_threshold. Weights must
107 // have the path property. The weight of any cycle needs to be bounded; i.e.,
108 //
109 // Plus(weight, Weight::One()) == Weight::One()
110 template <class Arc, class ArcFilter>
113  using StateId = typename Arc::StateId;
114  using Weight = typename Arc::Weight;
115  static_assert(IsPath<Weight>::value, "Weight must have path property.");
117  auto ns = fst->NumStates();
118  if (ns < 1) return;
119  std::vector<Weight> idistance(ns, Weight::Zero());
120  std::vector<Weight> tmp;
121  if (!opts.distance) {
122  tmp.reserve(ns);
123  ShortestDistance(*fst, &tmp, true, opts.delta);
124  }
125  const auto *fdistance = opts.distance ? opts.distance : &tmp;
126  if ((opts.state_threshold == 0) || (fdistance->size() <= fst->Start()) ||
127  ((*fdistance)[fst->Start()] == Weight::Zero())) {
128  fst->DeleteStates();
129  return;
130  }
131  internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
132  StateHeap heap(compare);
133  std::vector<bool> visited(ns, false);
134  std::vector<size_t> enqueued(ns, StateHeap::kNoKey);
135  std::vector<StateId> dead;
136  dead.push_back(fst->AddState());
137  NaturalLess<Weight> less;
138  auto s = fst->Start();
139  const auto limit = opts.threshold_initial
140  ? Times(opts.weight_threshold, (*fdistance)[s])
141  : Times((*fdistance)[s], opts.weight_threshold);
142  StateId num_visited = 0;
143 
144  if (!less(limit, (*fdistance)[s])) {
145  idistance[s] = Weight::One();
146  enqueued[s] = heap.Insert(s);
147  ++num_visited;
148  }
149  while (!heap.Empty()) {
150  s = heap.Top();
151  heap.Pop();
152  enqueued[s] = StateHeap::kNoKey;
153  visited[s] = true;
154  if (less(limit, Times(idistance[s], fst->Final(s)))) {
155  fst->SetFinal(s, Weight::Zero());
156  }
157  for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
158  aiter.Next()) {
159  auto arc = aiter.Value(); // Copy intended.
160  if (!opts.filter(arc)) continue;
161  const auto weight =
162  Times(Times(idistance[s], arc.weight),
163  arc.nextstate < fdistance->size() ? (*fdistance)[arc.nextstate]
164  : Weight::Zero());
165  if (less(limit, weight)) {
166  arc.nextstate = dead[0];
167  aiter.SetValue(arc);
168  continue;
169  }
170  if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
171  idistance[arc.nextstate] = Times(idistance[s], arc.weight);
172  }
173  if (visited[arc.nextstate]) continue;
174  if ((opts.state_threshold != kNoStateId) &&
175  (num_visited >= opts.state_threshold)) {
176  continue;
177  }
178  if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
179  enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
180  ++num_visited;
181  } else {
182  heap.Update(enqueued[arc.nextstate], arc.nextstate);
183  }
184  }
185  }
186  for (StateId i = 0; i < visited.size(); ++i) {
187  if (!visited[i]) dead.push_back(i);
188  }
189  fst->DeleteStates(dead);
190 }
191 
192 // Pruning algorithm: this version modifies its input and takes the
193 // pruning threshold as an argument. It deletes states and arcs in the
194 // FST that do not belong to a successful path whose weight is more
195 // than the weight of the shortest path Times() the provided weight
196 // threshold. When the state threshold is not kNoStateId, the output
197 // FST is further restricted to have no more than the number of states
198 // in opts.state_threshold. Weights must have the path property. The
199 // weight of any cycle needs to be bounded; i.e.,
200 //
201 // Plus(weight, Weight::One()) == Weight::One()
202 template <class Arc>
203 void Prune(MutableFst<Arc> *fst, typename Arc::Weight weight_threshold,
204  typename Arc::StateId state_threshold = kNoStateId,
205  float delta = kDelta) {
207  weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
208  Prune(fst, opts);
209 }
210 
211 // Pruning algorithm: this version writes the pruned input FST to an
212 // output MutableFst and it takes an options class as an argument. The
213 // output FST contains states and arcs that belong to a successful
214 // path in the input FST whose weight is more than the weight of the
215 // shortest path Times() the provided weight threshold. When the state
216 // threshold is not kNoStateId, the output FST is further restricted
217 // to have no more than the number of states in
218 // opts.state_threshold. Weights have the path property. The weight
219 // of any cycle needs to be bounded; i.e.,
220 //
221 // Plus(weight, Weight::One()) == Weight::One()
222 template <class Arc, class ArcFilter>
223 void Prune(
224  const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
226  using StateId = typename Arc::StateId;
227  using Weight = typename Arc::Weight;
228  static_assert(IsPath<Weight>::value, "Weight must have path property.");
230  ofst->DeleteStates();
231  ofst->SetInputSymbols(ifst.InputSymbols());
232  ofst->SetOutputSymbols(ifst.OutputSymbols());
233  if (ifst.Start() == kNoStateId) return;
234  NaturalLess<Weight> less;
235  if (less(opts.weight_threshold, Weight::One()) ||
236  (opts.state_threshold == 0)) {
237  return;
238  }
239  std::vector<Weight> idistance;
240  std::vector<Weight> tmp;
241  if (!opts.distance) ShortestDistance(ifst, &tmp, true, opts.delta);
242  const auto *fdistance = opts.distance ? opts.distance : &tmp;
243  if ((fdistance->size() <= ifst.Start()) ||
244  ((*fdistance)[ifst.Start()] == Weight::Zero())) {
245  return;
246  }
247  internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
248  StateHeap heap(compare);
249  std::vector<StateId> copy;
250  std::vector<size_t> enqueued;
251  std::vector<bool> visited;
252  auto s = ifst.Start();
253  const auto limit = opts.threshold_initial
254  ? Times(opts.weight_threshold, (*fdistance)[s])
255  : Times((*fdistance)[s], opts.weight_threshold);
256  while (copy.size() <= s) copy.push_back(kNoStateId);
257  copy[s] = ofst->AddState();
258  ofst->SetStart(copy[s]);
259  while (idistance.size() <= s) idistance.push_back(Weight::Zero());
260  idistance[s] = Weight::One();
261  while (enqueued.size() <= s) {
262  enqueued.push_back(StateHeap::kNoKey);
263  visited.push_back(false);
264  }
265  enqueued[s] = heap.Insert(s);
266  while (!heap.Empty()) {
267  s = heap.Top();
268  heap.Pop();
269  enqueued[s] = StateHeap::kNoKey;
270  visited[s] = true;
271  if (!less(limit, Times(idistance[s], ifst.Final(s)))) {
272  ofst->SetFinal(copy[s], ifst.Final(s));
273  }
274  for (ArcIterator<Fst<Arc>> aiter(ifst, s); !aiter.Done(); aiter.Next()) {
275  const auto &arc = aiter.Value();
276  if (!opts.filter(arc)) continue;
277  const auto weight =
278  Times(Times(idistance[s], arc.weight),
279  arc.nextstate < fdistance->size() ? (*fdistance)[arc.nextstate]
280  : Weight::Zero());
281  if (less(limit, weight)) continue;
282  if ((opts.state_threshold != kNoStateId) &&
283  (ofst->NumStates() >= opts.state_threshold)) {
284  continue;
285  }
286  while (idistance.size() <= arc.nextstate) {
287  idistance.push_back(Weight::Zero());
288  }
289  if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
290  idistance[arc.nextstate] = Times(idistance[s], arc.weight);
291  }
292  while (copy.size() <= arc.nextstate) copy.push_back(kNoStateId);
293  if (copy[arc.nextstate] == kNoStateId) {
294  copy[arc.nextstate] = ofst->AddState();
295  }
296  ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight,
297  copy[arc.nextstate]));
298  while (enqueued.size() <= arc.nextstate) {
299  enqueued.push_back(StateHeap::kNoKey);
300  visited.push_back(false);
301  }
302  if (visited[arc.nextstate]) continue;
303  if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
304  enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
305  } else {
306  heap.Update(enqueued[arc.nextstate], arc.nextstate);
307  }
308  }
309  }
310 }
311 
312 // Pruning algorithm: this version writes the pruned input FST to an
313 // output MutableFst and simply takes the pruning threshold as an
314 // argument. The output FST contains states and arcs that belong to a
315 // successful path in the input FST whose weight is no more than the
316 // weight of the shortest path Times() the provided weight
317 // threshold. When the state threshold is not kNoStateId, the output
318 // FST is further restricted to have no more than the number of states
319 // in opts.state_threshold. Weights must have the path property. The
320 // weight of any cycle needs to be bounded; i.e.,
321 //
322 // Plus(weight, Weight::One()) = Weight::One();
323 template <class Arc>
324 void Prune(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
325  typename Arc::Weight weight_threshold,
326  typename Arc::StateId state_threshold = kNoStateId,
327  float delta = kDelta) {
329  weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
330  Prune(ifst, ofst, opts);
331 }
332 
333 } // namespace fst
334 
335 #endif // FST_PRUNE_H_
bool operator()(const StateId x, const StateId y) const
Definition: prune.h:44
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:63
typename Arc::StateId StateId
Definition: prune.h:68
virtual void SetInputSymbols(const SymbolTable *isyms)=0
virtual Weight Final(StateId) const =0
virtual void SetStart(StateId)=0
Weight weight_threshold
Definition: prune.h:84
typename Arc::Weight Weight
Definition: prune.h:69
constexpr int kNoStateId
Definition: fst.h:202
void Prune(MutableFst< Arc > *fst, const PruneOptions< Arc, ArcFilter > &opts=PruneOptions< Arc, ArcFilter >())
Definition: prune.h:111
virtual StateId Start() const =0
float delta
Definition: prune.h:93
virtual void AddArc(StateId, const Arc &)=0
virtual const SymbolTable * InputSymbols() const =0
StateId state_threshold
Definition: prune.h:86
ArcFilter filter
Definition: prune.h:88
const std::vector< Weight > * distance
Definition: prune.h:90
virtual StateId AddState()=0
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
virtual void DeleteStates(const std::vector< StateId > &)=0
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
virtual StateId NumStates() const =0
bool threshold_initial
Definition: prune.h:98
PruneCompare(const std::vector< Weight > &idistance, const std::vector< Weight > &fdistance)
Definition: prune.h:40
constexpr float kDelta
Definition: weight.h:130
PruneOptions(const Weight &weight_threshold=Weight::Zero(), StateId state_threshold=kNoStateId, ArcFilter filter=ArcFilter(), std::vector< Weight > *distance=nullptr, float delta=kDelta, bool threshold_initial=false)
Definition: prune.h:71
std::bool_constant<(W::Properties()&kPath)!=0 > IsPath
Definition: weight.h:159
virtual const SymbolTable * OutputSymbols() const =0