FST  openfst-1.7.2
OpenFst Library
shortest-distance.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions and classes to find shortest distance in an FST.
5 
6 #ifndef FST_SHORTEST_DISTANCE_H_
7 #define FST_SHORTEST_DISTANCE_H_
8 
9 #include <cstddef>
10 #include <vector>
11 
12 #include <fst/log.h>
13 
14 #include <fst/arcfilter.h>
15 #include <fst/cache.h>
16 #include <fst/queue.h>
17 #include <fst/reverse.h>
18 #include <fst/test-properties.h>
19 
20 
21 namespace fst {
22 
// Default convergence tolerance for the shortest-distance and shortest-path
// algorithms: the `delta` handed to ApproxEqual() when deciding whether a
// relaxation changed a distance. Small, yet still representable as a float.
constexpr float kShortestDelta = static_cast<float>(1.0e-6);
25 
// Options for the ShortestDistance() algorithm. They select the queue
// discipline, an arc filter, the source state, the convergence delta and the
// first-path mode. The queue is only pointed to; the caller owns it.
26 template <class Arc, class Queue, class ArcFilter>
28  using StateId = typename Arc::StateId;
29 
30  Queue *state_queue; // Queue discipline used; owned by caller.
31  ArcFilter arc_filter; // Arc filter (e.g., limit to only epsilon graph).
32  StateId source; // If kNoStateId, use the FST's initial state.
33  float delta; // Determines the degree of convergence required
  // (it is the tolerance passed to ApproxEqual() when relaxing arcs).
34  bool first_path; // For a semiring with the path property (o.w.
35  // undefined), compute the shortest-distances along
36  // the first path to a final state found
37  // by the algorithm. That path is the shortest-path
38  // only if the FST has a unique final state (or all
39  // the final states have the same final weight), the
40  // queue discipline is shortest-first and all the
41  // weights in the FST are between One() and Zero()
42  // according to NaturalLess.
43 
  // Only the queue and the arc filter are required. The other fields default
  // to: the FST's start state (kNoStateId), kShortestDelta, and first_path
  // disabled.
44  ShortestDistanceOptions(Queue *state_queue, ArcFilter arc_filter,
45  StateId source = kNoStateId,
46  float delta = kShortestDelta,
47  bool first_path = false)
48  : state_queue(state_queue),
49  arc_filter(arc_filter),
50  source(source),
51  delta(delta),
52  first_path(first_path) {}
53 };
54 
55 namespace internal {
56 
57 // Computation state of the shortest-distance algorithm. Reusable information
58 // is maintained across calls to member function ShortestDistance(source) when
59 // retain is true for improved efficiency when calling multiple times from
60 // different source states (e.g., in epsilon removal). Contrary to the usual
61 // conventions, fst may not be freed before this class. Vector distance
62 // should not be modified by the user between these calls. The Error() method
63 // returns true iff an error was encountered.
// The FST is held by reference, and the distance vector and the caller-owned
// queue (opts.state_queue) are held by pointer. The queue is cleared at the
// start of every ShortestDistance() call.
64 template <class Arc, class Queue, class ArcFilter>
66  public:
67  using StateId = typename Arc::StateId;
68  using Weight = typename Arc::Weight;
69 
  // Clears *distance. If the FST reports kExpanded, reserves CountStates(fst)
  // entries in *distance, adder_, radder_ and enqueued_. sources_ is not
  // pre-reserved.
71  const Fst<Arc> &fst, std::vector<Weight> *distance,
72  const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts, bool retain)
73  : fst_(fst),
74  distance_(distance),
75  state_queue_(opts.state_queue),
76  arc_filter_(opts.arc_filter),
77  delta_(opts.delta),
78  first_path_(opts.first_path),
79  retain_(retain),
80  source_id_(0),
81  error_(false) {
82  distance_->clear();
83  if (fst.Properties(kExpanded, false) == kExpanded) {
84  const auto num_states = CountStates(fst);
85  distance_->reserve(num_states);
86  adder_.reserve(num_states);
87  radder_.reserve(num_states);
88  enqueued_.reserve(num_states);
89  }
90  }
91 
93 
  // error_ starts false and is only ever set to true, never cleared. So this
  // reports whether any call made so far on this object has failed.
94  bool Error() const { return error_; }
95 
96  private:
  // Grows distance_, adder_, radder_ and enqueued_ together until index is a
  // valid position in all of them. New entries are Zero(), a default Adder,
  // a default Adder and false, respectively.
97  void EnsureDistanceIndexIsValid(std::size_t index) {
98  while (distance_->size() <= index) {
99  distance_->push_back(Weight::Zero());
100  adder_.push_back(Adder<Weight>());
101  radder_.push_back(Adder<Weight>());
102  enqueued_.push_back(false);
103  }
104  DCHECK_LT(index, distance_->size());
105  }
106 
  // Grows sources_ with kNoStateId until index is valid. It is only called
  // when retain_ is true.
107  void EnsureSourcesIndexIsValid(std::size_t index) {
108  while (sources_.size() <= index) {
109  sources_.push_back(kNoStateId);
110  }
111  DCHECK_LT(index, sources_.size());
112  }
113 
114  const Fst<Arc> &fst_;
115  std::vector<Weight> *distance_;
116  Queue *state_queue_;
117  ArcFilter arc_filter_;
118  const float delta_;
119  const bool first_path_;
120  const bool retain_; // Retain and reuse information across calls.
121 
122  std::vector<Adder<Weight>> adder_; // Sums distance_ accurately.
123  std::vector<Adder<Weight>> radder_; // Relaxation distance.
  // radder_[s] holds the weight that reached s since s was last dequeued.
  // It is read and then reset when s is processed.
124  std::vector<bool> enqueued_; // Is state enqueued?
125  std::vector<StateId> sources_; // Source ID for ith state in distance_,
126  // (r)adder_, and enqueued_ if retained.
127  StateId source_id_; // Unique ID characterizing each call.
128  bool error_;
129 };
130 
131 // Compute the shortest distance; if source is kNoStateId, uses the initial
132 // state of the FST.
//
// Outline:
//   1. Validate the FST and the weight properties.
//   2. Seed the source with One().
//   3. Repeatedly dequeue a state and relax its filtered out-arcs. Any
//      successor whose distance changes by more than delta_ is (re)queued.
// On failure error_ is set and the method returns early. error_ is never
// cleared here.
133 template <class Arc, class Queue, class ArcFilter>
135  StateId source) {
  // With no start state, return immediately, even if an explicit source was
  // given. Only an error already recorded on the FST (kError) is propagated.
136  if (fst_.Start() == kNoStateId) {
137  if (fst_.Properties(kError, false)) error_ = true;
138  return;
139  }
  // Accumulated weights r are extended as Times(r, arc.weight) below, so the
  // semiring must be right distributive.
140  if (!(Weight::Properties() & kRightSemiring)) {
141  FSTERROR() << "ShortestDistance: Weight needs to be right distributive: "
142  << Weight::Type();
143  error_ = true;
144  return;
145  }
  // The first_path option is rejected unless the weight has kPath.
146  if (first_path_ && !(Weight::Properties() & kPath)) {
147  FSTERROR() << "ShortestDistance: The first_path option is disallowed when "
148  << "Weight does not have the path property: " << Weight::Type();
149  error_ = true;
150  return;
151  }
  // Reset the per-call structures. With retain_ they are kept, and stale
  // entries are reset lazily via sources_ when a state is reached again.
152  state_queue_->Clear();
153  if (!retain_) {
154  distance_->clear();
155  adder_.clear();
156  radder_.clear();
157  enqueued_.clear();
158  }
  // Seed the search: the source gets distance One() and is enqueued.
159  if (source == kNoStateId) source = fst_.Start();
160  EnsureDistanceIndexIsValid(source);
161  if (retain_) {
162  EnsureSourcesIndexIsValid(source);
163  sources_[source] = source_id_;
164  }
165  (*distance_)[source] = Weight::One();
166  adder_[source].Reset(Weight::One());
167  radder_[source].Reset(Weight::One());
168  enqueued_[source] = true;
169  state_queue_->Enqueue(source);
170  while (!state_queue_->Empty()) {
171  const auto state = state_queue_->Head();
172  state_queue_->Dequeue();
173  EnsureDistanceIndexIsValid(state);
  // In first_path mode, stop at the first final state dequeued. That state
  // is left marked as enqueued.
174  if (first_path_ && (fst_.Final(state) != Weight::Zero())) break;
175  enqueued_[state] = false;
  // r is the weight that reached this state since it was last dequeued.
  // Only that increment is propagated along the arcs.
176  const auto r = radder_[state].Sum();
177  radder_[state].Reset();
178  for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
179  aiter.Next()) {
180  const auto &arc = aiter.Value();
181  const auto nextstate = arc.nextstate;
182  if (!arc_filter_(arc)) continue;
183  EnsureDistanceIndexIsValid(nextstate);
  // With retain_, an entry last touched by an earlier call (a different
  // source_id_) is reset to its initial values before it is used.
184  if (retain_) {
185  EnsureSourcesIndexIsValid(nextstate);
186  if (sources_[nextstate] != source_id_) {
187  (*distance_)[nextstate] = Weight::Zero();
188  adder_[nextstate].Reset();
189  radder_[nextstate].Reset();
190  enqueued_[nextstate] = false;
191  sources_[nextstate] = source_id_;
192  }
193  }
194  auto &nd = (*distance_)[nextstate];
195  auto &na = adder_[nextstate];
196  auto &nr = radder_[nextstate];
197  auto weight = Times(r, arc.weight);
  // Relax only if adding weight changes nd by more than delta_.
198  if (!ApproxEqual(nd, Plus(nd, weight), delta_)) {
199  nd = na.Add(weight);
200  nr.Add(weight);
  // NOTE(review): this error return skips the ++source_id_ at the end.
  // With retain_, a later call would reuse the same ID for states already
  // tagged here. Confirm that callers stop after Error().
201  if (!nd.Member() || !nr.Sum().Member()) {
202  error_ = true;
203  return;
204  }
  // Enqueue the successor if it is not queued; otherwise notify the queue
  // through Update().
205  if (!enqueued_[nextstate]) {
206  state_queue_->Enqueue(nextstate);
207  enqueued_[nextstate] = true;
208  } else {
209  state_queue_->Update(nextstate);
210  }
211  }
212  }
213  }
  // A fresh ID distinguishes the next call's entries from this call's.
214  ++source_id_;
215  if (fst_.Properties(kError, false)) error_ = true;
216 }
217 
218 } // namespace internal
219 
220 // Shortest-distance algorithm: this version allows fine control
221 // via the options argument. See below for a simpler interface.
222 //
223 // This computes the shortest distance from the opts.source state to each
224 // visited state S and stores the value in the distance vector. An
225 // unvisited state S has distance Zero(), which will be stored in the
226 // distance vector if S is less than the maximum visited state. The state
227 // queue discipline, arc filter, and convergence delta are taken in the
228 // options argument. The distance vector will contain a unique element for
229 // which Member() is false if an error was encountered.
// The opts.first_path mode is also taken from the options argument.
230 //
231 // The weights must be right distributive and k-closed (i.e., 1 +
232 // x + x^2 + ... + x^(k + 1) = 1 + x + x^2 + ... + x^k).
233 //
234 // Complexity:
235 //
236 // Depends on properties of the semiring and the queue discipline.
237 //
238 // For more information, see:
239 //
240 // Mohri, M. 2002. Semiring framework and algorithms for shortest-distance
241 // problems, Journal of Automata, Languages and
242 // Combinatorics 7(3): 321-350, 2002.
243 template <class Arc, class Queue, class ArcFilter>
245  const Fst<Arc> &fst, std::vector<typename Arc::Weight> *distance,
248  opts, false);
  // sd_state was built with retain == false: a single search from
  // opts.source (kNoStateId selects the FST's start state).
249  sd_state.ShortestDistance(opts.source);
  // On error, replace whatever was computed with a single NoWeight() entry.
250  if (sd_state.Error()) {
251  distance->assign(1, Arc::Weight::NoWeight());
252  }
253 }
254 
255 // Shortest-distance algorithm: simplified interface. See above for a version
256 // that permits finer control.
257 //
258 // If reverse is false, this computes the shortest distance from the initial
259 // state to each state S and stores the value in the distance vector. If
260 // reverse is true, this computes the shortest distance from each state to the
261 // final states. An unvisited state S has distance Zero(), which will be stored
262 // in the distance vector if S is less than the maximum visited state. The
263 // state queue discipline is automatically selected. The distance vector will
264 // contain a unique element for which Member() is false if an error was
265 // encountered.
266 //
267 // The weights must be right (left) distributive if reverse is false (true)
268 // and k-closed (i.e., 1 + x + x^2 + ... + x^(k + 1) = 1 + x + x^2 + ... + x^k).
269 //
270 // Arc weights must satisfy the property that the sum of the weights of one or
271 // more paths from some state S to T is never Zero(). In particular, arc weights
272 // are never Zero().
273 //
274 // Complexity:
275 //
276 // Depends on properties of the semiring and the queue discipline.
277 //
278 // For more information, see:
279 //
280 // Mohri, M. 2002. Semiring framework and algorithms for
281 // shortest-distance problems, Journal of Automata, Languages and
282 // Combinatorics 7(3): 321-350, 2002.
283 template <class Arc>
285  std::vector<typename Arc::Weight> *distance,
286  bool reverse = false, float delta = kShortestDelta) {
287  using StateId = typename Arc::StateId;
  // Forward case: run the options-based version from the start state with
  // an AutoQueue.
288  if (!reverse) {
290  AutoQueue<StateId> state_queue(fst, distance, arc_filter);
292  opts(&state_queue, arc_filter, kNoStateId, delta);
293  ShortestDistance(fst, distance, opts);
294  } else {
  // Reverse case: compute forward distances on the reversed FST.
  // rdistance[i + 1] then corresponds to original state i, because
  // reversing added one state. Each entry is mapped back with Reverse().
295  using ReverseArc = ReverseArc<Arc>;
296  using ReverseWeight = typename ReverseArc::Weight;
297  AnyArcFilter<ReverseArc> rarc_filter;
299  Reverse(fst, &rfst);
300  std::vector<ReverseWeight> rdistance;
301  AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter);
304  ropts(&state_queue, rarc_filter, kNoStateId, delta);
305  ShortestDistance(rfst, &rdistance, ropts);
306  distance->clear();
307  if (rdistance.size() == 1 && !rdistance[0].Member()) {
308  distance->assign(1, Arc::Weight::NoWeight());
309  return;
310  }
  // NOTE(review): DCHECK_GE may compile to nothing in release builds. If
  // rdistance were ever empty, rdistance.size() - 1 below would wrap around.
  // Confirm that Reverse() always gives rfst a start state.
311  DCHECK_GE(rdistance.size(), 1); // reversing added one state
312  distance->reserve(rdistance.size() - 1);
313  while (distance->size() < rdistance.size() - 1) {
314  distance->push_back(rdistance[distance->size() + 1].Reverse());
315  }
316  }
317 }
318 
319 // Return the sum of the weight of all successful paths in an FST, i.e., the
320 // shortest-distance from the initial state to the final states. Returns a
321 // weight such that Member() is false if an error was encountered.
322 template <class Arc>
323 typename Arc::Weight ShortestDistance(const Fst<Arc> &fst,
324  float delta = kShortestDelta) {
325  using StateId = typename Arc::StateId;
326  using Weight = typename Arc::Weight;
327  std::vector<Weight> distance;
328  if (Weight::Properties() & kRightSemiring) {
329  ShortestDistance(fst, &distance, false, delta);
330  if (distance.size() == 1 && !distance[0].Member()) {
331  return Arc::Weight::NoWeight();
332  }
333  Adder<Weight> adder; // maintains cumulative sum accurately
334  for (StateId state = 0; state < distance.size(); ++state) {
335  adder.Add(Times(distance[state], fst.Final(state)));
336  }
337  return adder.Sum();
338  } else {
339  ShortestDistance(fst, &distance, true, delta);
340  const auto state = fst.Start();
341  if (distance.size() == 1 && !distance[0].Member()) {
342  return Arc::Weight::NoWeight();
343  }
344  return state != kNoStateId && state < distance.size() ? distance[state]
345  : Weight::Zero();
346  }
347 }
348 
349 } // namespace fst
350 
351 #endif // FST_SHORTEST_DISTANCE_H_
#define DCHECK_LT(x, y)
Definition: log.h:72
constexpr uint64 kRightSemiring
Definition: weight.h:115
virtual Weight Final(StateId) const =0
ExpectationWeight< X1, X2 > Times(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
Arc::Weight ShortestDistance(const Fst< Arc > &fst, float delta=kShortestDelta)
constexpr int kNoStateId
Definition: fst.h:180
constexpr uint64 kExpanded
Definition: properties.h:27
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
Definition: reverse.h:20
virtual uint64 Properties(uint64 mask, bool test) const =0
#define FSTERROR()
Definition: util.h:35
Weight Add(const Weight &w)
Definition: weight.h:216
typename AWeight::ReverseWeight Weight
Definition: arc.h:141
Weight Sum()
Definition: weight.h:221
ExpectationWeight< X1, X2 > Plus(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
virtual StateId Start() const =0
constexpr float kShortestDelta
ShortestDistanceState(const Fst< Arc > &fst, std::vector< Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts, bool retain)
Arc::StateId CountStates(const Fst< Arc > &fst)
Definition: expanded-fst.h:154
constexpr uint64 kPath
Definition: weight.h:126
constexpr bool ApproxEqual(const FloatWeightTpl< T > &w1, const FloatWeightTpl< T > &w2, float delta=kDelta)
Definition: float-weight.h:140
constexpr uint64 kError
Definition: properties.h:33
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
ShortestDistanceOptions(Queue *state_queue, ArcFilter arc_filter, StateId source=kNoStateId, float delta=kShortestDelta, bool first_path=false)
#define DCHECK_GE(x, y)
Definition: log.h:75