FST  openfst-1.8.3
OpenFst Library
shortest-distance.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Functions and classes to find shortest distance in an FST.
19 
20 #ifndef FST_SHORTEST_DISTANCE_H_
21 #define FST_SHORTEST_DISTANCE_H_
22 
23 #include <cstddef>
24 #include <vector>
25 
26 #include <fst/log.h>
27 #include <fst/arc.h>
28 #include <fst/arcfilter.h>
29 #include <fst/cache.h>
30 #include <fst/equal.h>
31 #include <fst/expanded-fst.h>
32 #include <fst/fst.h>
33 #include <fst/properties.h>
34 #include <fst/queue.h>
35 #include <fst/reverse.h>
36 #include <fst/util.h>
37 #include <fst/vector-fst.h>
38 #include <fst/weight.h>
39 
40 namespace fst {
41 
42 // Default convergence delta, a representable float, for shortest distance and shortest path algorithms.
43 inline constexpr float kShortestDelta = 1e-6;
44 
45 template <class Arc, class Queue, class ArcFilter>
47  using StateId = typename Arc::StateId;
48 
49  Queue *state_queue; // Queue discipline used; owned by caller.
50  ArcFilter arc_filter; // Arc filter (e.g., limit to only epsilon graph).
51  StateId source; // If kNoStateId, use the FST's initial state.
52  float delta; // Determines the degree of convergence required; defaults to kShortestDelta.
53  bool first_path; // For a semiring with the path property (o.w.
54  // undefined), compute the shortest-distances
55  // along the first path to a final state found
56  // by the algorithm. That path is the shortest-path
57  // only if the FST has a unique final state (or all
58  // the final states have the same final weight), the
59  // queue discipline is shortest-first and all the
60  // weights in the FST are between One() and Zero()
61  // according to NaturalLess.
62 
63  ShortestDistanceOptions(Queue *state_queue, ArcFilter arc_filter,
64  StateId source = kNoStateId,
65  float delta = kShortestDelta, bool first_path = false)
66  : state_queue(state_queue),
67  arc_filter(arc_filter),
68  source(source),
69  delta(delta),
70  first_path(first_path) {} // Only copies each argument into the field of the same name.
71 };
72 
73 namespace internal {
74 
75 // Computation state of the shortest-distance algorithm. Reusable information
76 // is maintained across calls to member function ShortestDistance(source) when
77 // retain is true for improved efficiency when calling multiple times from
78 // different source states (e.g., in epsilon removal). Contrary to the usual
79 // conventions, fst may not be freed before this class. Vector distance
80 // should not be modified by the user between these calls. The Error() method
81 // returns true iff an error was encountered.
82 template <class Arc, class Queue, class ArcFilter,
83  class WeightEqual = WeightApproxEqual>
85  public:
86  using StateId = typename Arc::StateId;
87  using Weight = typename Arc::Weight;
88 
90  const Fst<Arc> &fst, std::vector<Weight> *distance,
91  const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts, bool retain)
92  : fst_(fst), // Stored as a reference, not a copy.
93  distance_(distance), // Caller's vector; results are written into it.
94  state_queue_(opts.state_queue),
95  arc_filter_(opts.arc_filter),
96  weight_equal_(opts.delta), // Comparator built from the convergence delta.
97  first_path_(opts.first_path),
98  retain_(retain),
99  source_id_(0),
100  error_(false) {
101  distance_->clear(); // Discard whatever the caller's vector held.
102  if (std::optional<StateId> num_states = fst.NumStatesIfKnown()) { // Pre-size only when the count is known.
103  distance_->reserve(*num_states);
104  adder_.reserve(*num_states);
105  radder_.reserve(*num_states);
106  enqueued_.reserve(*num_states);
107  }
108  }
109 
111 
112  bool Error() const { return error_; } // True once an error has been recorded.
113 
114  private:
115  void EnsureDistanceIndexIsValid(std::size_t index) { // Grows the four per-state vectors in lockstep.
116  while (distance_->size() <= index) {
117  distance_->push_back(Weight::Zero()); // New entries start at distance Zero().
118  adder_.push_back(Adder<Weight>());
119  radder_.push_back(Adder<Weight>());
120  enqueued_.push_back(false);
121  }
122  DCHECK_LT(index, distance_->size());
123  }
124 
125  void EnsureSourcesIndexIsValid(std::size_t index) { // Grows sources_ until index is valid.
126  while (sources_.size() <= index) {
127  sources_.push_back(kNoStateId); // kNoStateId marks an entry not yet tagged with a source ID.
128  }
129  DCHECK_LT(index, sources_.size());
130  }
131 
132  const Fst<Arc> &fst_;
133  std::vector<Weight> *distance_;
134  Queue *state_queue_;
135  ArcFilter arc_filter_;
136  WeightEqual weight_equal_; // Determines when relaxation stops.
137  const bool first_path_;
138  const bool retain_; // Retain and reuse information across calls.
139 
140  std::vector<Adder<Weight>> adder_; // Sums distance_ accurately.
141  std::vector<Adder<Weight>> radder_; // Relaxation distance.
142  std::vector<bool> enqueued_; // Is state enqueued?
143  std::vector<StateId> sources_; // Source ID for ith state in distance_,
144  // (r)adder_, and enqueued_ if retained.
145  StateId source_id_; // Unique ID characterizing each call.
146  bool error_;
147 };
148 
149 // Compute the shortest distance; if source is kNoStateId, uses the initial
150 // state of the FST.
151 template <class Arc, class Queue, class ArcFilter, class WeightEqual>
152 void ShortestDistanceState<Arc, Queue, ArcFilter,
154  if (fst_.Start() == kNoStateId) { // No start state: nothing to compute.
155  if (fst_.Properties(kError, false)) error_ = true; // Still report an FST flagged with kError.
156  return;
157  }
158  if (!(Weight::Properties() & kRightSemiring)) { // The weight type must be a right semiring.
159  FSTERROR() << "ShortestDistance: Weight needs to be right distributive: "
160  << Weight::Type();
161  error_ = true;
162  return;
163  }
164  if (first_path_ && !(Weight::Properties() & kPath)) { // first_path requires the path property.
165  FSTERROR() << "ShortestDistance: The first_path option is disallowed when "
166  << "Weight does not have the path property: " << Weight::Type();
167  error_ = true;
168  return;
169  }
170  state_queue_->Clear();
171  if (!retain_) { // Without retain, each call starts from empty per-state vectors.
172  distance_->clear();
173  adder_.clear();
174  radder_.clear();
175  enqueued_.clear();
176  }
177  if (source == kNoStateId) source = fst_.Start();
178  EnsureDistanceIndexIsValid(source);
179  if (retain_) { // Tag the source as initialized for this call.
180  EnsureSourcesIndexIsValid(source);
181  sources_[source] = source_id_;
182  }
183  (*distance_)[source] = Weight::One(); // The source is at distance One() from itself.
184  adder_[source].Reset(Weight::One());
185  radder_[source].Reset(Weight::One());
186  enqueued_[source] = true;
187  state_queue_->Enqueue(source);
188  while (!state_queue_->Empty()) {
189  const auto state = state_queue_->Head();
190  state_queue_->Dequeue();
191  EnsureDistanceIndexIsValid(state);
192  if (first_path_ && (fst_.Final(state) != Weight::Zero())) break; // Stop at the first final state dequeued.
193  enqueued_[state] = false;
194  const auto r = radder_[state].Sum(); // Weight added to this state since its radder_ was last reset.
195  radder_[state].Reset();
196  for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
197  aiter.Next()) {
198  const auto &arc = aiter.Value();
199  const auto nextstate = arc.nextstate;
200  if (!arc_filter_(arc)) continue; // Skip arcs rejected by the filter.
201  EnsureDistanceIndexIsValid(nextstate);
202  if (retain_) {
203  EnsureSourcesIndexIsValid(nextstate);
204  if (sources_[nextstate] != source_id_) { // Entry not yet tagged for this call: reinitialize it.
205  (*distance_)[nextstate] = Weight::Zero();
206  adder_[nextstate].Reset();
207  radder_[nextstate].Reset();
208  enqueued_[nextstate] = false;
209  sources_[nextstate] = source_id_;
210  }
211  }
212  auto &nd = (*distance_)[nextstate];
213  auto &na = adder_[nextstate];
214  auto &nr = radder_[nextstate];
215  auto weight = Times(r, arc.weight); // Contribution reaching nextstate through this arc.
216  if (!weight_equal_(nd, Plus(nd, weight))) { // Relax only if adding weight changes nd per weight_equal_.
217  nd = na.Add(weight);
218  nr.Add(weight);
219  if (!nd.Member() || !nr.Sum().Member()) { // A non-member weight is treated as an error.
220  error_ = true;
221  return;
222  }
223  if (!enqueued_[nextstate]) {
224  state_queue_->Enqueue(nextstate);
225  enqueued_[nextstate] = true;
226  } else {
227  state_queue_->Update(nextstate); // Already queued: notify the queue instead of re-enqueueing.
228  }
229  }
230  }
231  }
232  ++source_id_; // The next call uses a new ID, so retained entries fail the sources_ check above.
233  if (fst_.Properties(kError, false)) error_ = true;
234 }
235 
236 } // namespace internal
237 
238 // Shortest-distance algorithm: this version allows fine control
239 // via the options argument. See below for a simpler interface.
240 //
241 // This computes the shortest distance from the opts.source state to each
242 // visited state S and stores the value in the distance vector. An
243 // unvisited state S has distance Zero(), which will be stored in the
244 // distance vector if S is less than the maximum visited state. The state
245 // queue discipline, arc filter, and convergence delta are taken in the
246 // options argument. The distance vector will contain a unique element for
247 // which Member() is false if an error was encountered.
248 //
249 // The weights must be right distributive and k-closed (i.e., 1 +
250 // x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
251 //
252 // Complexity:
253 //
254 // Depends on properties of the semiring and the queue discipline.
255 //
256 // For more information, see:
257 //
258 // Mohri, M. 2002. Semiring framework and algorithms for shortest-distance
259 // problems, Journal of Automata, Languages and
260 // Combinatorics 7(3): 321-350, 2002.
261 template <class Arc, class Queue, class ArcFilter>
263  const Fst<Arc> &fst, std::vector<typename Arc::Weight> *distance,
266  opts, false); // NOTE(review): false is presumably the retain flag of ShortestDistanceState - confirm.
267  sd_state.ShortestDistance(opts.source); // Run from the source given in the options.
268  if (sd_state.Error()) {
269  distance->assign(1, Arc::Weight::NoWeight()); // On error the result is a single NoWeight() element.
270  }
271 }
272 
273 // Shortest-distance algorithm: simplified interface. See above for a version
274 // that permits finer control.
275 //
276 // If reverse is false, this computes the shortest distance from the initial
277 // state to each state S and stores the value in the distance vector. If
278 // reverse is true, this computes the shortest distance from each state to the
279 // final states. An unvisited state S has distance Zero(), which will be stored
280 // in the distance vector if S is less than the maximum visited state. The
281 // state queue discipline is automatically-selected. The distance vector will
282 // contain a unique element for which Member() is false if an error was
283 // encountered.
284 //
285 // The weights must be right (left) distributive if reverse is false (true)
286 // and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
287 //
288 // Arc weights must satisfy the property that the sum of the weights of one or
289 // more paths from some state S to T is never Zero(). In particular, arc weights
290 // are never Zero().
291 //
292 // Complexity:
293 //
294 // Depends on properties of the semiring and the queue discipline.
295 //
296 // For more information, see:
297 //
298 // Mohri, M. 2002. Semiring framework and algorithms for
299 // shortest-distance problems, Journal of Automata, Languages and
300 // Combinatorics 7(3): 321-350, 2002.
301 template <class Arc>
303  std::vector<typename Arc::Weight> *distance,
304  bool reverse = false, float delta = kShortestDelta) {
305  using StateId = typename Arc::StateId;
306  if (!reverse) { // Forward case: run directly on fst.
308  AutoQueue<StateId> state_queue(fst, distance, arc_filter);
310  opts(&state_queue, arc_filter, kNoStateId, delta); // kNoStateId: use the FST's initial state.
311  ShortestDistance(fst, distance, opts);
312  } else { // Reverse case: run the same computation on the reversed FST.
313  using ReverseArc = ReverseArc<Arc>;
314  using ReverseWeight = typename ReverseArc::Weight;
315  AnyArcFilter<ReverseArc> rarc_filter;
317  Reverse(fst, &rfst);
318  std::vector<ReverseWeight> rdistance;
319  AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter);
322  ropts(&state_queue, rarc_filter, kNoStateId, delta);
323  ShortestDistance(rfst, &rdistance, ropts);
324  distance->clear();
325  if (rdistance.size() == 1 && !rdistance[0].Member()) { // Single non-member element: error marker.
326  distance->assign(1, Arc::Weight::NoWeight()); // Propagate the error in the caller's weight type.
327  return;
328  }
329  DCHECK_GE(rdistance.size(), 1); // reversing added one state
330  distance->reserve(rdistance.size() - 1);
331  while (distance->size() < rdistance.size() - 1) { // Entry i of distance comes from entry i + 1 of rdistance.
332  distance->push_back(rdistance[distance->size() + 1].Reverse()); // Index 0 of rdistance is skipped.
333  }
334  }
335 }
336 
337 // Return the sum of the weight of all successful paths in an FST, i.e., the
338 // shortest-distance from the initial state to the final states. Returns a
339 // weight such that Member() is false if an error was encountered.
340 template <class Arc>
341 typename Arc::Weight ShortestDistance(const Fst<Arc> &fst,
342  float delta = kShortestDelta) {
343  using StateId = typename Arc::StateId;
344  using Weight = typename Arc::Weight;
345  std::vector<Weight> distance;
346  if (Weight::Properties() & kRightSemiring) { // Right semiring: use the forward (reverse = false) computation.
347  ShortestDistance(fst, &distance, false, delta);
348  if (distance.size() == 1 && !distance[0].Member()) { // Single non-member element: error marker.
349  return Arc::Weight::NoWeight();
350  }
351  Adder<Weight> adder; // maintains cumulative sum accurately
352  for (StateId state = 0; state < distance.size(); ++state) {
353  adder.Add(Times(distance[state], fst.Final(state))); // Distance to the state times its final weight.
354  }
355  return adder.Sum();
356  } else { // Otherwise use the reverse (reverse = true) computation.
357  ShortestDistance(fst, &distance, true, delta);
358  const auto state = fst.Start();
359  if (distance.size() == 1 && !distance[0].Member()) { // Single non-member element: error marker.
360  return Arc::Weight::NoWeight();
361  }
362  return state != kNoStateId && state < distance.size() ? distance[state] // The start state's entry is the result.
363  : Weight::Zero(); // No start state, or it has no entry in distance.
364  }
365 }
366 
367 } // namespace fst
368 
369 #endif // FST_SHORTEST_DISTANCE_H_
#define DCHECK_LT(x, y)
Definition: log.h:76
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:61
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:64
Weight Sum() const
Definition: weight.h:224
constexpr uint64_t kError
Definition: properties.h:52
virtual Weight Final(StateId) const =0
constexpr int kNoStateId
Definition: fst.h:196
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
Definition: reverse.h:36
constexpr uint64_t kRightSemiring
Definition: weight.h:139
#define FSTERROR()
Definition: util.h:56
Weight Add(const Weight &w)
Definition: weight.h:219
virtual StateId Start() const =0
constexpr float kShortestDelta
virtual std::optional< StateId > NumStatesIfKnown() const
Definition: fst.h:228
constexpr uint64_t kPath
Definition: weight.h:150
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
ShortestDistanceState(const Fst< Arc > &fst, std::vector< Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts, bool retain)
ShortestDistanceOptions(Queue *state_queue, ArcFilter arc_filter, StateId source=kNoStateId, float delta=kShortestDelta, bool first_path=false)
#define DCHECK_GE(x, y)
Definition: log.h:79