FST  openfst-1.8.2
OpenFst Library
shortest-distance.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Functions and classes to find shortest distance in an FST.
19 
20 #ifndef FST_SHORTEST_DISTANCE_H_
21 #define FST_SHORTEST_DISTANCE_H_
22 
23 #include <cstddef>
24 #include <vector>
25 
26 #include <fst/log.h>
27 
28 #include <fst/arcfilter.h>
29 #include <fst/cache.h>
30 #include <fst/equal.h>
31 #include <fst/queue.h>
32 #include <fst/reverse.h>
33 #include <fst/test-properties.h>
34 
35 
36 namespace fst {
37 
// Default convergence threshold, representable as a float, shared by the
// shortest-distance and shortest-path algorithms.
inline constexpr float kShortestDelta = 1e-6F;
40 
41 template <class Arc, class Queue, class ArcFilter>
43  using StateId = typename Arc::StateId;
44 
45  Queue *state_queue; // Queue discipline used; owned by caller.
46  ArcFilter arc_filter; // Arc filter (e.g., limit to only epsilon graph).
47  StateId source; // If kNoStateId, use the FST's initial state.
48  float delta; // Determines the degree of convergence required
49  bool first_path; // For a semiring with the path property (o.w.
50  // undefined), compute the shortest-distances along
51  // along the first path to a final state found
52  // by the algorithm. That path is the shortest-path
53  // only if the FST has a unique final state (or all
54  // the final states have the same final weight), the
55  // queue discipline is shortest-first and all the
56  // weights in the FST are between One() and Zero()
57  // according to NaturalLess.
58 
59  ShortestDistanceOptions(Queue *state_queue, ArcFilter arc_filter,
60  StateId source = kNoStateId,
61  float delta = kShortestDelta, bool first_path = false)
62  : state_queue(state_queue),
63  arc_filter(arc_filter),
64  source(source),
65  delta(delta),
66  first_path(first_path) {}
67 };
68 
69 namespace internal {
70 
71 // Computation state of the shortest-distance algorithm. Reusable information
72 // is maintained across calls to member function ShortestDistance(source) when
73 // retain is true for improved efficiency when calling multiple times from
74 // different source states (e.g., in epsilon removal). Contrary to the usual
75 // conventions, fst may not be freed before this class. Vector distance
76 // should not be modified by the user between these calls. The Error() method
77 // returns true iff an error was encountered.
78 template <class Arc, class Queue, class ArcFilter,
79  class WeightEqual = WeightApproxEqual>
81  public:
82  using StateId = typename Arc::StateId;
83  using Weight = typename Arc::Weight;
84 
86  const Fst<Arc> &fst, std::vector<Weight> *distance,
87  const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts, bool retain)
88  : fst_(fst),
89  distance_(distance),
90  state_queue_(opts.state_queue),
91  arc_filter_(opts.arc_filter),
92  weight_equal_(opts.delta),
93  first_path_(opts.first_path),
94  retain_(retain),
95  source_id_(0),
96  error_(false) {
97  distance_->clear();
98  if (fst.Properties(kExpanded, false) == kExpanded) {
99  const auto num_states = CountStates(fst);
100  distance_->reserve(num_states);
101  adder_.reserve(num_states);
102  radder_.reserve(num_states);
103  enqueued_.reserve(num_states);
104  }
105  }
106 
108 
109  bool Error() const { return error_; }
110 
111  private:
112  void EnsureDistanceIndexIsValid(std::size_t index) {
113  while (distance_->size() <= index) {
114  distance_->push_back(Weight::Zero());
115  adder_.push_back(Adder<Weight>());
116  radder_.push_back(Adder<Weight>());
117  enqueued_.push_back(false);
118  }
119  DCHECK_LT(index, distance_->size());
120  }
121 
122  void EnsureSourcesIndexIsValid(std::size_t index) {
123  while (sources_.size() <= index) {
124  sources_.push_back(kNoStateId);
125  }
126  DCHECK_LT(index, sources_.size());
127  }
128 
129  const Fst<Arc> &fst_;
130  std::vector<Weight> *distance_;
131  Queue *state_queue_;
132  ArcFilter arc_filter_;
133  WeightEqual weight_equal_; // Determines when relaxation stops.
134  const bool first_path_;
135  const bool retain_; // Retain and reuse information across calls.
136 
137  std::vector<Adder<Weight>> adder_; // Sums distance_ accurately.
138  std::vector<Adder<Weight>> radder_; // Relaxation distance.
139  std::vector<bool> enqueued_; // Is state enqueued?
140  std::vector<StateId> sources_; // Source ID for ith state in distance_,
141  // (r)adder_, and enqueued_ if retained.
142  StateId source_id_; // Unique ID characterizing each call.
143  bool error_;
144 };
145 
146 // Compute the shortest distance; if source is kNoStateId, uses the initial
147 // state of the FST.
148 template <class Arc, class Queue, class ArcFilter, class WeightEqual>
149 void ShortestDistanceState<Arc, Queue, ArcFilter,
151  if (fst_.Start() == kNoStateId) {
152  if (fst_.Properties(kError, false)) error_ = true;
153  return;
154  }
155  if (!(Weight::Properties() & kRightSemiring)) {
156  FSTERROR() << "ShortestDistance: Weight needs to be right distributive: "
157  << Weight::Type();
158  error_ = true;
159  return;
160  }
161  if (first_path_ && !(Weight::Properties() & kPath)) {
162  FSTERROR() << "ShortestDistance: The first_path option is disallowed when "
163  << "Weight does not have the path property: " << Weight::Type();
164  error_ = true;
165  return;
166  }
167  state_queue_->Clear();
168  if (!retain_) {
169  distance_->clear();
170  adder_.clear();
171  radder_.clear();
172  enqueued_.clear();
173  }
174  if (source == kNoStateId) source = fst_.Start();
175  EnsureDistanceIndexIsValid(source);
176  if (retain_) {
177  EnsureSourcesIndexIsValid(source);
178  sources_[source] = source_id_;
179  }
180  (*distance_)[source] = Weight::One();
181  adder_[source].Reset(Weight::One());
182  radder_[source].Reset(Weight::One());
183  enqueued_[source] = true;
184  state_queue_->Enqueue(source);
185  while (!state_queue_->Empty()) {
186  const auto state = state_queue_->Head();
187  state_queue_->Dequeue();
188  EnsureDistanceIndexIsValid(state);
189  if (first_path_ && (fst_.Final(state) != Weight::Zero())) break;
190  enqueued_[state] = false;
191  const auto r = radder_[state].Sum();
192  radder_[state].Reset();
193  for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
194  aiter.Next()) {
195  const auto &arc = aiter.Value();
196  const auto nextstate = arc.nextstate;
197  if (!arc_filter_(arc)) continue;
198  EnsureDistanceIndexIsValid(nextstate);
199  if (retain_) {
200  EnsureSourcesIndexIsValid(nextstate);
201  if (sources_[nextstate] != source_id_) {
202  (*distance_)[nextstate] = Weight::Zero();
203  adder_[nextstate].Reset();
204  radder_[nextstate].Reset();
205  enqueued_[nextstate] = false;
206  sources_[nextstate] = source_id_;
207  }
208  }
209  auto &nd = (*distance_)[nextstate];
210  auto &na = adder_[nextstate];
211  auto &nr = radder_[nextstate];
212  auto weight = Times(r, arc.weight);
213  if (!weight_equal_(nd, Plus(nd, weight))) {
214  nd = na.Add(weight);
215  nr.Add(weight);
216  if (!nd.Member() || !nr.Sum().Member()) {
217  error_ = true;
218  return;
219  }
220  if (!enqueued_[nextstate]) {
221  state_queue_->Enqueue(nextstate);
222  enqueued_[nextstate] = true;
223  } else {
224  state_queue_->Update(nextstate);
225  }
226  }
227  }
228  }
229  ++source_id_;
230  if (fst_.Properties(kError, false)) error_ = true;
231 }
232 
233 } // namespace internal
234 
235 // Shortest-distance algorithm: this version allows fine control
236 // via the options argument. See below for a simpler interface.
237 //
238 // This computes the shortest distance from the opts.source state to each
239 // visited state S and stores the value in the distance vector. An
240 // unvisited state S has distance Zero(), which will be stored in the
241 // distance vector if S is less than the maximum visited state. The state
242 // queue discipline, arc filter, and convergence delta are taken in the
243 // options argument. The distance vector will contain a unique element for
244 // which Member() is false if an error was encountered.
245 //
246 // The weights must must be right distributive and k-closed (i.e., 1 +
247 // x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
248 //
249 // Complexity:
250 //
251 // Depends on properties of the semiring and the queue discipline.
252 //
253 // For more information, see:
254 //
255 // Mohri, M. 2002. Semiring framework and algorithms for shortest-distance
256 // problems, Journal of Automata, Languages and
257 // Combinatorics 7(3): 321-350, 2002.
258 template <class Arc, class Queue, class ArcFilter>
260  const Fst<Arc> &fst, std::vector<typename Arc::Weight> *distance,
263  opts, false);
264  sd_state.ShortestDistance(opts.source);
265  if (sd_state.Error()) {
266  distance->assign(1, Arc::Weight::NoWeight());
267  }
268 }
269 
270 // Shortest-distance algorithm: simplified interface. See above for a version
271 // that permits finer control.
272 //
273 // If reverse is false, this computes the shortest distance from the initial
274 // state to each state S and stores the value in the distance vector. If
275 // reverse is true, this computes the shortest distance from each state to the
276 // final states. An unvisited state S has distance Zero(), which will be stored
277 // in the distance vector if S is less than the maximum visited state. The
278 // state queue discipline is automatically-selected. The distance vector will
279 // contain a unique element for which Member() is false if an error was
280 // encountered.
281 //
282 // The weights must must be right (left) distributive if reverse is false (true)
283 // and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
284 //
285 // Arc weights must satisfy the property that the sum of the weights of one or
286 // more paths from some state S to T is never Zero(). In particular, arc weights
287 // are never Zero().
288 //
289 // Complexity:
290 //
291 // Depends on properties of the semiring and the queue discipline.
292 //
293 // For more information, see:
294 //
295 // Mohri, M. 2002. Semiring framework and algorithms for
296 // shortest-distance problems, Journal of Automata, Languages and
297 // Combinatorics 7(3): 321-350, 2002.
298 template <class Arc>
300  std::vector<typename Arc::Weight> *distance,
301  bool reverse = false, float delta = kShortestDelta) {
302  using StateId = typename Arc::StateId;
303  if (!reverse) {
305  AutoQueue<StateId> state_queue(fst, distance, arc_filter);
307  opts(&state_queue, arc_filter, kNoStateId, delta);
308  ShortestDistance(fst, distance, opts);
309  } else {
310  using ReverseArc = ReverseArc<Arc>;
311  using ReverseWeight = typename ReverseArc::Weight;
312  AnyArcFilter<ReverseArc> rarc_filter;
314  Reverse(fst, &rfst);
315  std::vector<ReverseWeight> rdistance;
316  AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter);
319  ropts(&state_queue, rarc_filter, kNoStateId, delta);
320  ShortestDistance(rfst, &rdistance, ropts);
321  distance->clear();
322  if (rdistance.size() == 1 && !rdistance[0].Member()) {
323  distance->assign(1, Arc::Weight::NoWeight());
324  return;
325  }
326  DCHECK_GE(rdistance.size(), 1); // reversing added one state
327  distance->reserve(rdistance.size() - 1);
328  while (distance->size() < rdistance.size() - 1) {
329  distance->push_back(rdistance[distance->size() + 1].Reverse());
330  }
331  }
332 }
333 
334 // Return the sum of the weight of all successful paths in an FST, i.e., the
335 // shortest-distance from the initial state to the final states. Returns a
336 // weight such that Member() is false if an error was encountered.
337 template <class Arc>
338 typename Arc::Weight ShortestDistance(const Fst<Arc> &fst,
339  float delta = kShortestDelta) {
340  using StateId = typename Arc::StateId;
341  using Weight = typename Arc::Weight;
342  std::vector<Weight> distance;
343  if (Weight::Properties() & kRightSemiring) {
344  ShortestDistance(fst, &distance, false, delta);
345  if (distance.size() == 1 && !distance[0].Member()) {
346  return Arc::Weight::NoWeight();
347  }
348  Adder<Weight> adder; // maintains cumulative sum accurately
349  for (StateId state = 0; state < distance.size(); ++state) {
350  adder.Add(Times(distance[state], fst.Final(state)));
351  }
352  return adder.Sum();
353  } else {
354  ShortestDistance(fst, &distance, true, delta);
355  const auto state = fst.Start();
356  if (distance.size() == 1 && !distance[0].Member()) {
357  return Arc::Weight::NoWeight();
358  }
359  return state != kNoStateId && state < distance.size() ? distance[state]
360  : Weight::Zero();
361  }
362 }
363 
364 } // namespace fst
365 
366 #endif // FST_SHORTEST_DISTANCE_H_
virtual uint64_t Properties(uint64_t mask, bool test) const =0
#define DCHECK_LT(x, y)
Definition: log.h:72
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:60
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:63
Weight Sum() const
Definition: weight.h:223
constexpr uint64_t kError
Definition: properties.h:51
virtual Weight Final(StateId) const =0
constexpr int kNoStateId
Definition: fst.h:202
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
Definition: reverse.h:34
constexpr uint64_t kRightSemiring
Definition: weight.h:136
#define FSTERROR()
Definition: util.h:53
Weight Add(const Weight &w)
Definition: weight.h:218
typename Arc::Weight::ReverseWeight Weight
Definition: arc.h:173
virtual StateId Start() const =0
constexpr float kShortestDelta
constexpr uint64_t kPath
Definition: weight.h:147
Arc::StateId CountStates(const Fst< Arc > &fst)
Definition: expanded-fst.h:169
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
ShortestDistanceState(const Fst< Arc > &fst, std::vector< Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts, bool retain)
ShortestDistanceOptions(Queue *state_queue, ArcFilter arc_filter, StateId source=kNoStateId, float delta=kShortestDelta, bool first_path=false)
#define DCHECK_GE(x, y)
Definition: log.h:75
constexpr uint64_t kExpanded
Definition: properties.h:45