FST  openfst-1.8.3
OpenFst Library
prune.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Functions implementing pruning.
19 
20 #ifndef FST_PRUNE_H_
21 #define FST_PRUNE_H_
22 
23 #include <cstddef>
24 #include <cstdlib>
25 #include <type_traits>
26 #include <utility>
27 #include <vector>
28 
29 #include <fst/log.h>
30 #include <fst/arcfilter.h>
31 #include <fst/fst.h>
32 #include <fst/heap.h>
33 #include <fst/mutable-fst.h>
34 #include <fst/shortest-distance.h>
35 #include <fst/weight.h>
36 
37 namespace fst {
38 namespace internal {
39 
40 template <class StateId, class Weight>
41 class PruneCompare {
42  public:
43  PruneCompare(const std::vector<Weight> &idistance,
44  const std::vector<Weight> &fdistance)
45  : idistance_(idistance), fdistance_(fdistance) {}
46 
47  bool operator()(const StateId x, const StateId y) const {
48  const auto wx = Times(IDistance(x), FDistance(x));
49  const auto wy = Times(IDistance(y), FDistance(y));
50  return less_(wx, wy);
51  }
52 
53  private:
54  Weight IDistance(const StateId s) const {
55  return s < idistance_.size() ? idistance_[s] : Weight::Zero();
56  }
57 
58  Weight FDistance(const StateId s) const {
59  return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
60  }
61 
62  const std::vector<Weight> &idistance_;
63  const std::vector<Weight> &fdistance_;
64  NaturalLess<Weight> less_;
65 };
66 
67 } // namespace internal
68 
69 template <class Arc, class ArcFilter>
70 struct PruneOptions {
71  using StateId = typename Arc::StateId;
72  using Weight = typename Arc::Weight;
73 
74  explicit PruneOptions(const Weight &weight_threshold = Weight::Zero(),
75  StateId state_threshold = kNoStateId,
76  ArcFilter filter = ArcFilter(),
77  std::vector<Weight> *distance = nullptr,
78  float delta = kDelta, bool threshold_initial = false)
79  : weight_threshold(std::move(weight_threshold)),
80  state_threshold(state_threshold),
81  filter(std::move(filter)),
82  distance(distance),
83  delta(delta),
84  threshold_initial(threshold_initial) {}
85 
86  // Pruning weight threshold.
88  // Pruning state threshold.
90  // Arc filter.
91  ArcFilter filter;
92  // If non-zero, passes in pre-computed shortest distance to final states.
93  const std::vector<Weight> *distance;
94  // Determines the degree of convergence required when computing shortest
95  // distances.
96  float delta;
97  // Determines if the shortest path weight is left (true) or right
98  // (false) multiplied by the threshold to get the limit for
99  // keeping a state or arc (matters if the semiring is not
100  // commutative).
102 };
103 
104 // Pruning algorithm: this version modifies its input and it takes an options
105 // class as an argument. After pruning the FST contains states and arcs that
106 // belong to a successful path in the FST whose weight is no more than the
107 // weight of the shortest path Times() the provided weight threshold. When the
108 // state threshold is not kNoStateId, the output FST is further restricted to
109 // have no more than the number of states in opts.state_threshold. Weights must
110 // have the path property. The weight of any cycle needs to be bounded; i.e.,
111 //
112 // Plus(weight, Weight::One()) == Weight::One()
113 template <class Arc, class ArcFilter>
116  using StateId = typename Arc::StateId;
117  using Weight = typename Arc::Weight;
118  static_assert(IsPath<Weight>::value, "Weight must have path property.");
120  auto ns = fst->NumStates();
121  if (ns < 1) return;
122  std::vector<Weight> idistance(ns, Weight::Zero());
123  std::vector<Weight> tmp;
124  if (!opts.distance) {
125  tmp.reserve(ns);
126  ShortestDistance(*fst, &tmp, true, opts.delta);
127  }
128  const auto *fdistance = opts.distance ? opts.distance : &tmp;
129  if ((opts.state_threshold == 0) || (fdistance->size() <= fst->Start()) ||
130  ((*fdistance)[fst->Start()] == Weight::Zero())) {
131  fst->DeleteStates();
132  return;
133  }
134  internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
135  StateHeap heap(compare);
136  std::vector<bool> visited(ns, false);
137  std::vector<size_t> enqueued(ns, StateHeap::kNoKey);
138  std::vector<StateId> dead;
139  dead.push_back(fst->AddState());
140  NaturalLess<Weight> less;
141  auto s = fst->Start();
142  const auto limit = opts.threshold_initial
143  ? Times(opts.weight_threshold, (*fdistance)[s])
144  : Times((*fdistance)[s], opts.weight_threshold);
145  StateId num_visited = 0;
146 
147  if (!less(limit, (*fdistance)[s])) {
148  idistance[s] = Weight::One();
149  enqueued[s] = heap.Insert(s);
150  ++num_visited;
151  }
152  while (!heap.Empty()) {
153  s = heap.Top();
154  heap.Pop();
155  enqueued[s] = StateHeap::kNoKey;
156  visited[s] = true;
157  if (less(limit, Times(idistance[s], fst->Final(s)))) {
158  fst->SetFinal(s, Weight::Zero());
159  }
160  for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
161  aiter.Next()) {
162  auto arc = aiter.Value(); // Copy intended.
163  if (!opts.filter(arc)) continue;
164  const auto weight =
165  Times(Times(idistance[s], arc.weight),
166  arc.nextstate < fdistance->size() ? (*fdistance)[arc.nextstate]
167  : Weight::Zero());
168  if (less(limit, weight)) {
169  arc.nextstate = dead[0];
170  aiter.SetValue(arc);
171  continue;
172  }
173  if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
174  idistance[arc.nextstate] = Times(idistance[s], arc.weight);
175  }
176  if (visited[arc.nextstate]) continue;
177  if ((opts.state_threshold != kNoStateId) &&
178  (num_visited >= opts.state_threshold)) {
179  continue;
180  }
181  if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
182  enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
183  ++num_visited;
184  } else {
185  heap.Update(enqueued[arc.nextstate], arc.nextstate);
186  }
187  }
188  }
189  for (StateId i = 0; i < visited.size(); ++i) {
190  if (!visited[i]) dead.push_back(i);
191  }
192  fst->DeleteStates(dead);
193 }
194 
195 // Pruning algorithm: this version modifies its input and takes the
196 // pruning threshold as an argument. It deletes states and arcs in the
197 // FST that do not belong to a successful path whose weight is more
198 // than the weight of the shortest path Times() the provided weight
199 // threshold. When the state threshold is not kNoStateId, the output
200 // FST is further restricted to have no more than the number of states
201 // in opts.state_threshold. Weights must have the path property. The
202 // weight of any cycle needs to be bounded; i.e.,
203 //
204 // Plus(weight, Weight::One()) == Weight::One()
205 template <class Arc>
206 void Prune(MutableFst<Arc> *fst, typename Arc::Weight weight_threshold,
207  typename Arc::StateId state_threshold = kNoStateId,
208  float delta = kDelta) {
210  weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
211  Prune(fst, opts);
212 }
213 
214 // Pruning algorithm: this version writes the pruned input FST to an
215 // output MutableFst and it takes an options class as an argument. The
216 // output FST contains states and arcs that belong to a successful
217 // path in the input FST whose weight is more than the weight of the
218 // shortest path Times() the provided weight threshold. When the state
219 // threshold is not kNoStateId, the output FST is further restricted
220 // to have no more than the number of states in
221 // opts.state_threshold. Weights have the path property. The weight
222 // of any cycle needs to be bounded; i.e.,
223 //
224 // Plus(weight, Weight::One()) == Weight::One()
225 template <class Arc, class ArcFilter>
226 void Prune(
227  const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
229  using StateId = typename Arc::StateId;
230  using Weight = typename Arc::Weight;
231  static_assert(IsPath<Weight>::value, "Weight must have path property.");
233  ofst->DeleteStates();
234  ofst->SetInputSymbols(ifst.InputSymbols());
235  ofst->SetOutputSymbols(ifst.OutputSymbols());
236  if (ifst.Start() == kNoStateId) return;
237  NaturalLess<Weight> less;
238  if (less(opts.weight_threshold, Weight::One()) ||
239  (opts.state_threshold == 0)) {
240  return;
241  }
242  std::vector<Weight> idistance;
243  std::vector<Weight> tmp;
244  if (!opts.distance) ShortestDistance(ifst, &tmp, true, opts.delta);
245  const auto *fdistance = opts.distance ? opts.distance : &tmp;
246  if ((fdistance->size() <= ifst.Start()) ||
247  ((*fdistance)[ifst.Start()] == Weight::Zero())) {
248  return;
249  }
250  internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
251  StateHeap heap(compare);
252  std::vector<StateId> copy;
253  std::vector<size_t> enqueued;
254  std::vector<bool> visited;
255  auto s = ifst.Start();
256  const auto limit = opts.threshold_initial
257  ? Times(opts.weight_threshold, (*fdistance)[s])
258  : Times((*fdistance)[s], opts.weight_threshold);
259  while (copy.size() <= s) copy.push_back(kNoStateId);
260  copy[s] = ofst->AddState();
261  ofst->SetStart(copy[s]);
262  while (idistance.size() <= s) idistance.push_back(Weight::Zero());
263  idistance[s] = Weight::One();
264  while (enqueued.size() <= s) {
265  enqueued.push_back(StateHeap::kNoKey);
266  visited.push_back(false);
267  }
268  enqueued[s] = heap.Insert(s);
269  while (!heap.Empty()) {
270  s = heap.Top();
271  heap.Pop();
272  enqueued[s] = StateHeap::kNoKey;
273  visited[s] = true;
274  if (!less(limit, Times(idistance[s], ifst.Final(s)))) {
275  ofst->SetFinal(copy[s], ifst.Final(s));
276  }
277  for (ArcIterator<Fst<Arc>> aiter(ifst, s); !aiter.Done(); aiter.Next()) {
278  const auto &arc = aiter.Value();
279  if (!opts.filter(arc)) continue;
280  const auto weight =
281  Times(Times(idistance[s], arc.weight),
282  arc.nextstate < fdistance->size() ? (*fdistance)[arc.nextstate]
283  : Weight::Zero());
284  if (less(limit, weight)) continue;
285  if ((opts.state_threshold != kNoStateId) &&
286  (ofst->NumStates() >= opts.state_threshold)) {
287  continue;
288  }
289  while (idistance.size() <= arc.nextstate) {
290  idistance.push_back(Weight::Zero());
291  }
292  if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
293  idistance[arc.nextstate] = Times(idistance[s], arc.weight);
294  }
295  while (copy.size() <= arc.nextstate) copy.push_back(kNoStateId);
296  if (copy[arc.nextstate] == kNoStateId) {
297  copy[arc.nextstate] = ofst->AddState();
298  }
299  ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight,
300  copy[arc.nextstate]));
301  while (enqueued.size() <= arc.nextstate) {
302  enqueued.push_back(StateHeap::kNoKey);
303  visited.push_back(false);
304  }
305  if (visited[arc.nextstate]) continue;
306  if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
307  enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
308  } else {
309  heap.Update(enqueued[arc.nextstate], arc.nextstate);
310  }
311  }
312  }
313 }
314 
315 // Pruning algorithm: this version writes the pruned input FST to an
316 // output MutableFst and simply takes the pruning threshold as an
317 // argument. The output FST contains states and arcs that belong to a
318 // successful path in the input FST whose weight is no more than the
319 // weight of the shortest path Times() the provided weight
320 // threshold. When the state threshold is not kNoStateId, the output
321 // FST is further restricted to have no more than the number of states
322 // in opts.state_threshold. Weights must have the path property. The
323 // weight of any cycle needs to be bounded; i.e.,
324 //
325 // Plus(weight, Weight::One()) = Weight::One();
326 template <class Arc>
327 void Prune(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
328  typename Arc::Weight weight_threshold,
329  typename Arc::StateId state_threshold = kNoStateId,
330  float delta = kDelta) {
332  weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
333  Prune(ifst, ofst, opts);
334 }
335 
336 } // namespace fst
337 
338 #endif // FST_PRUNE_H_
bool operator()(const StateId x, const StateId y) const
Definition: prune.h:47
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:64
typename Arc::StateId StateId
Definition: prune.h:71
virtual void SetInputSymbols(const SymbolTable *isyms)=0
virtual Weight Final(StateId) const =0
virtual void SetStart(StateId)=0
Weight weight_threshold
Definition: prune.h:87
typename Arc::Weight Weight
Definition: prune.h:72
constexpr int kNoStateId
Definition: fst.h:196
void Prune(MutableFst< Arc > *fst, const PruneOptions< Arc, ArcFilter > &opts=PruneOptions< Arc, ArcFilter >())
Definition: prune.h:114
virtual StateId Start() const =0
float delta
Definition: prune.h:96
virtual void AddArc(StateId, const Arc &)=0
virtual const SymbolTable * InputSymbols() const =0
StateId state_threshold
Definition: prune.h:89
ArcFilter filter
Definition: prune.h:91
const std::vector< Weight > * distance
Definition: prune.h:93
virtual StateId AddState()=0
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
virtual void DeleteStates(const std::vector< StateId > &)=0
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
virtual StateId NumStates() const =0
bool threshold_initial
Definition: prune.h:101
PruneCompare(const std::vector< Weight > &idistance, const std::vector< Weight > &fdistance)
Definition: prune.h:43
constexpr float kDelta
Definition: weight.h:133
PruneOptions(const Weight &weight_threshold=Weight::Zero(), StateId state_threshold=kNoStateId, ArcFilter filter=ArcFilter(), std::vector< Weight > *distance=nullptr, float delta=kDelta, bool threshold_initial=false)
Definition: prune.h:74
std::bool_constant<(W::Properties()&kPath)!=0 > IsPath
Definition: weight.h:162
virtual const SymbolTable * OutputSymbols() const =0