FST  openfst-1.7.2
OpenFst Library
queue.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions and classes for various FST state queues with a unified interface.
5 
6 #ifndef FST_QUEUE_H_
7 #define FST_QUEUE_H_
8 
9 #include <deque>
10 #include <memory>
11 #include <type_traits>
12 #include <utility>
13 #include <vector>
14 
15 #include <fst/log.h>
16 
17 #include <fst/arcfilter.h>
18 #include <fst/connect.h>
19 #include <fst/heap.h>
20 #include <fst/topsort.h>
21 #include <fst/weight.h>
22 
23 
24 namespace fst {
25 
26 // The Queue interface is:
27 //
28 // template <class S>
29 // class Queue {
30 // public:
31 // using StateId = S;
32 //
33 // // Constructor: may need args (e.g., FST, comparator) for some queues.
34 // Queue(...) override;
35 //
36 // // Returns the head of the queue.
37 // StateId Head() const override;
38 //
39 // // Inserts a state.
40 // void Enqueue(StateId s) override;
41 //
42 // // Removes the head of the queue.
43 // void Dequeue() override;
44 //
45 // // Updates ordering of state s when weight changes, if necessary.
46 // void Update(StateId s) override;
47 //
48 // // Is the queue empty?
49 // bool Empty() const override;
50 //
51 // // Removes all states from the queue.
52 // void Clear() override;
53 // };
54 
// State queue types. Every queue passes one of these tags to the QueueBase
// constructor; it can be read back with QueueBase::Type().
enum QueueType {
  TRIVIAL_QUEUE = 0,         // Single state queue.
  FIFO_QUEUE = 1,            // First-in, first-out queue.
  LIFO_QUEUE = 2,            // Last-in, first-out queue.
  SHORTEST_FIRST_QUEUE = 3,  // Shortest-first queue.
  TOP_ORDER_QUEUE = 4,       // Topologically-ordered queue.
  STATE_ORDER_QUEUE = 5,     // State ID-ordered queue.
  SCC_QUEUE = 6,             // Component graph top-ordered meta-queue.
  AUTO_QUEUE = 7,            // Auto-selected queue.
  OTHER_QUEUE = 8            // Any other queue (used by the wrapper queues).
};
67 
68 // QueueBase, templated on the StateId, is a virtual base class shared by all
69 // queues considered by AutoQueue.
70 template <class S>
71 class QueueBase {
72  public:
73  using StateId = S;
74 
75  virtual ~QueueBase() {}
76 
77  // Concrete implementation.
78 
79  explicit QueueBase(QueueType type) : queue_type_(type), error_(false) {}
80 
81  void SetError(bool error) { error_ = error; }
82 
83  bool Error() const { return error_; }
84 
85  QueueType Type() const { return queue_type_; }
86 
87  // Virtual interface.
88 
89  virtual StateId Head() const = 0;
90  virtual void Enqueue(StateId) = 0;
91  virtual void Dequeue() = 0;
92  virtual void Update(StateId) = 0;
93  virtual bool Empty() const = 0;
94  virtual void Clear() = 0;
95 
96  private:
97  QueueType queue_type_;
98  bool error_;
99 };
100 
101 // Trivial queue discipline; one may enqueue at most one state at a time. It
102 // can be used for strongly connected components with only one state and no
103 // self-loops.
104 template <class S>
105 class TrivialQueue : public QueueBase<S> {
106  public:
107  using StateId = S;
108 
110 
111  virtual ~TrivialQueue() = default;
112 
113  StateId Head() const final { return front_; }
114 
115  void Enqueue(StateId s) final { front_ = s; }
116 
117  void Dequeue() final { front_ = kNoStateId; }
118 
119  void Update(StateId) final {}
120 
121  bool Empty() const final { return front_ == kNoStateId; }
122 
123  void Clear() final { front_ = kNoStateId; }
124 
125  private:
126  StateId front_;
127 };
128 
129 // First-in, first-out queue discipline.
130 //
131 // This is not a final class.
132 template <class S>
133 class FifoQueue : public QueueBase<S> {
134  public:
135  using StateId = S;
136 
138 
139  virtual ~FifoQueue() = default;
140 
141  StateId Head() const override { return queue_.back(); }
142 
143  void Enqueue(StateId s) override { queue_.push_front(s); }
144 
145  void Dequeue() override { queue_.pop_back(); }
146 
147  void Update(StateId) override {}
148 
149  bool Empty() const override { return queue_.empty(); }
150 
151  void Clear() override { queue_.clear(); }
152 
153  private:
154  std::deque<StateId> queue_;
155 };
156 
157 // Last-in, first-out queue discipline.
158 template <class S>
159 class LifoQueue : public QueueBase<S> {
160  public:
161  using StateId = S;
162 
164 
165  virtual ~LifoQueue() = default;
166 
167  StateId Head() const final { return queue_.front(); }
168 
169  void Enqueue(StateId s) final { queue_.push_front(s); }
170 
171  void Dequeue() final { queue_.pop_front(); }
172 
173  void Update(StateId) final {}
174 
175  bool Empty() const final { return queue_.empty(); }
176 
177  void Clear() final { queue_.clear(); }
178 
179  private:
180  std::deque<StateId> queue_;
181 };
182 
183 // Shortest-first queue discipline, templated on the StateId and as well as a
184 // comparison functor used to compare two StateIds. If a (single) state's order
185 // changes, it can be reordered in the queue with a call to Update(). If update
186 // is false, call to Update() does not reorder the queue.
187 //
188 // This is not a final class.
189 template <typename S, typename Compare, bool update = true>
190 class ShortestFirstQueue : public QueueBase<S> {
191  public:
192  using StateId = S;
193 
194  explicit ShortestFirstQueue(Compare comp)
195  : QueueBase<StateId>(SHORTEST_FIRST_QUEUE), heap_(comp) {}
196 
197  virtual ~ShortestFirstQueue() = default;
198 
199  StateId Head() const override { return heap_.Top(); }
200 
201  void Enqueue(StateId s) override {
202  if (update) {
203  for (StateId i = key_.size(); i <= s; ++i) key_.push_back(kNoStateId);
204  key_[s] = heap_.Insert(s);
205  } else {
206  heap_.Insert(s);
207  }
208  }
209 
210  void Dequeue() override {
211  if (update) {
212  key_[heap_.Pop()] = kNoStateId;
213  } else {
214  heap_.Pop();
215  }
216  }
217 
218  void Update(StateId s) override {
219  if (!update) return;
220  if (s >= key_.size() || key_[s] == kNoStateId) {
221  Enqueue(s);
222  } else {
223  heap_.Update(key_[s], s);
224  }
225  }
226 
227  bool Empty() const override { return heap_.Empty(); }
228 
229  void Clear() override {
230  heap_.Clear();
231  if (update) key_.clear();
232  }
233 
234  const Compare &GetCompare() const { return heap_.GetCompare(); }
235 
236  private:
238  std::vector<ssize_t> key_;
239 };
240 
namespace internal {

// Given a vector that maps from states to weights, and a comparison functor
// for weights, this class defines a comparison function object between states.
// Both constructor arguments are held by reference, so they must outlive this
// object and any copy of it.
template <typename StateId, typename Less>
class StateWeightCompare {
 public:
  using Weight = typename Less::Weight;

  StateWeightCompare(const std::vector<Weight> &weights, const Less &less)
      : weights_(weights), less_(less) {}

  // Returns the result of applying less to the weights of s1 and s2.
  bool operator()(const StateId s1, const StateId s2) const {
    return less_(weights_[s1], weights_[s2]);
  }

 private:
  // Borrowed references.
  const std::vector<Weight> &weights_;
  const Less &less_;
};

}  // namespace internal
264 
265 // Shortest-first queue discipline, templated on the StateId and Weight, is
266 // specialized to use the weight's natural order for the comparison function.
267 template <typename S, typename Weight>
269  : public ShortestFirstQueue<
270  S, internal::StateWeightCompare<S, NaturalLess<Weight>>> {
271  public:
272  using StateId = S;
274 
275  explicit NaturalShortestFirstQueue(const std::vector<Weight> &distance)
276  : ShortestFirstQueue<StateId, Compare>(Compare(distance, less_)) {}
277 
278  virtual ~NaturalShortestFirstQueue() = default;
279 
280  private:
281  // This is non-static because the constructor for non-idempotent weights will
282  // result in an error.
283  const NaturalLess<Weight> less_{};
284 };
285 
286 // In a shortest path computation on a lattice-like FST, we may keep many old
287 // nonviable paths as a part of the search. Since the search process always
288 // expands the lowest cost path next, that lowest cost path may be a very old
289 // nonviable path instead of one we expect to lead to a shortest path.
290 //
291 // For instance, suppose that the current best path in an alignment has
292 // traversed 500 arcs with a cost of 10. We may also have a bad path in
293 // the queue that has traversed only 40 arcs but also has a cost of 10.
294 // This path is very unlikely to lead to a reasonable alignment, so this queue
295 // can prune it from the search space.
296 //
297 // This queue relies on the caller using a shortest-first exploration order
298 // like this:
299 // while (true) {
300 // StateId head = queue.Head();
301 // queue.Dequeue();
302 // for (const auto& arc : GetArcs(fst, head)) {
303 // queue.Enqueue(arc.nextstate);
304 // }
305 // }
306 // We use this assumption to guess that there is an arc between Head and the
307 // Enqueued state; this is how the number of path steps is measured.
308 template <typename S, typename Weight>
310  : public NaturalShortestFirstQueue<S, Weight> {
311  public:
312  using StateId = S;
314 
315  explicit PruneNaturalShortestFirstQueue(const std::vector<Weight> &distance,
316  int threshold)
317  : Base(distance),
318  threshold_(threshold),
319  head_steps_(0),
320  max_head_steps_(0) {}
321 
322  ~PruneNaturalShortestFirstQueue() override = default;
323 
324  StateId Head() const override {
325  const auto head = Base::Head();
326  // Stores the number of steps from the start of the graph to this state
327  // along the shortest-weight path.
328  if (head < steps_.size()) {
329  max_head_steps_ = std::max(steps_[head], max_head_steps_);
330  head_steps_ = steps_[head];
331  }
332  return head;
333  }
334 
335  void Enqueue(StateId s) override {
336  // We assume that there is an arc between the Head() state and this
337  // Enqueued state.
338  const ssize_t state_steps = head_steps_ + 1;
339  if (s >= steps_.size()) {
340  steps_.resize(s + 1, state_steps);
341  }
342  // This is the number of arcs in the minimum cost path from Start to s.
343  steps_[s] = state_steps;
344  if (state_steps > (max_head_steps_ - threshold_) || threshold_ < 0) {
345  Base::Enqueue(s);
346  }
347  }
348 
349  private:
350  // A dense map from StateId to the number of arcs in the minimum weight
351  // path from Start to this state.
352  std::vector<ssize_t> steps_;
353  // We only keep paths that are within this number of arcs (not weight!)
354  // of the longest path.
355  const ssize_t threshold_;
356 
357  // The following are mutable because Head() is const.
358  // The number of arcs traversed in the minimum cost path from the start
359  // state to the current Head() state.
360  mutable ssize_t head_steps_;
361  // The maximum number of arcs traversed by any low-cost path so far.
362  mutable ssize_t max_head_steps_;
363 };
364 
365 // Topological-order queue discipline, templated on the StateId. States are
366 // ordered in the queue topologically. The FST must be acyclic.
367 template <class S>
368 class TopOrderQueue : public QueueBase<S> {
369  public:
370  using StateId = S;
371 
372  // This constructor computes the topological order. It accepts an arc filter
373  // to limit the transitions considered in that computation (e.g., only the
374  // epsilon graph).
375  template <class Arc, class ArcFilter>
376  TopOrderQueue(const Fst<Arc> &fst, ArcFilter filter)
378  front_(0),
379  back_(kNoStateId),
380  order_(0),
381  state_(0) {
382  bool acyclic;
383  TopOrderVisitor<Arc> top_order_visitor(&order_, &acyclic);
384  DfsVisit(fst, &top_order_visitor, filter);
385  if (!acyclic) {
386  FSTERROR() << "TopOrderQueue: FST is not acyclic";
388  }
389  state_.resize(order_.size(), kNoStateId);
390  }
391 
392  // This constructor is passed the pre-computed topological order.
393  explicit TopOrderQueue(const std::vector<StateId> &order)
395  front_(0),
396  back_(kNoStateId),
397  order_(order),
398  state_(order.size(), kNoStateId) {}
399 
400  virtual ~TopOrderQueue() = default;
401 
402  StateId Head() const final { return state_[front_]; }
403 
404  void Enqueue(StateId s) final {
405  if (front_ > back_) {
406  front_ = back_ = order_[s];
407  } else if (order_[s] > back_) {
408  back_ = order_[s];
409  } else if (order_[s] < front_) {
410  front_ = order_[s];
411  }
412  state_[order_[s]] = s;
413  }
414 
415  void Dequeue() final {
416  state_[front_] = kNoStateId;
417  while ((front_ <= back_) && (state_[front_] == kNoStateId)) ++front_;
418  }
419 
420  void Update(StateId) final {}
421 
422  bool Empty() const final { return front_ > back_; }
423 
424  void Clear() final {
425  for (StateId s = front_; s <= back_; ++s) state_[s] = kNoStateId;
426  back_ = kNoStateId;
427  front_ = 0;
428  }
429 
430  private:
431  StateId front_;
432  StateId back_;
433  std::vector<StateId> order_;
434  std::vector<StateId> state_;
435 };
436 
437 // State order queue discipline, templated on the StateId. States are ordered in
438 // the queue by state ID.
439 template <class S>
440 class StateOrderQueue : public QueueBase<S> {
441  public:
442  using StateId = S;
443 
445  : QueueBase<StateId>(STATE_ORDER_QUEUE), front_(0), back_(kNoStateId) {}
446 
447  virtual ~StateOrderQueue() = default;
448 
449  StateId Head() const final { return front_; }
450 
451  void Enqueue(StateId s) final {
452  if (front_ > back_) {
453  front_ = back_ = s;
454  } else if (s > back_) {
455  back_ = s;
456  } else if (s < front_) {
457  front_ = s;
458  }
459  while (enqueued_.size() <= s) enqueued_.push_back(false);
460  enqueued_[s] = true;
461  }
462 
463  void Dequeue() final {
464  enqueued_[front_] = false;
465  while ((front_ <= back_) && (enqueued_[front_] == false)) ++front_;
466  }
467 
468  void Update(StateId) final {}
469 
470  bool Empty() const final { return front_ > back_; }
471 
472  void Clear() final {
473  for (StateId i = front_; i <= back_; ++i) enqueued_[i] = false;
474  front_ = 0;
475  back_ = kNoStateId;
476  }
477 
478  private:
479  StateId front_;
480  StateId back_;
481  std::vector<bool> enqueued_;
482 };
483 
484 // SCC topological-order meta-queue discipline, templated on the StateId and a
485 // queue used inside each SCC. It visits the SCCs of an FST in topological
486 // order. Its constructor is passed the queues to to use within an SCC.
487 template <class S, class Queue>
488 class SccQueue : public QueueBase<S> {
489  public:
490  using StateId = S;
491 
492  // Constructor takes a vector specifying the SCC number per state and a
493  // vector giving the queue to use per SCC number.
494  SccQueue(const std::vector<StateId> &scc,
495  std::vector<std::unique_ptr<Queue>> *queue)
497  queue_(queue),
498  scc_(scc),
499  front_(0),
500  back_(kNoStateId) {}
501 
502  virtual ~SccQueue() = default;
503 
504  StateId Head() const final {
505  while ((front_ <= back_) &&
506  (((*queue_)[front_] && (*queue_)[front_]->Empty()) ||
507  (((*queue_)[front_] == nullptr) &&
508  ((front_ >= trivial_queue_.size()) ||
509  (trivial_queue_[front_] == kNoStateId))))) {
510  ++front_;
511  }
512  if ((*queue_)[front_]) {
513  return (*queue_)[front_]->Head();
514  } else {
515  return trivial_queue_[front_];
516  }
517  }
518 
519  void Enqueue(StateId s) final {
520  if (front_ > back_) {
521  front_ = back_ = scc_[s];
522  } else if (scc_[s] > back_) {
523  back_ = scc_[s];
524  } else if (scc_[s] < front_) {
525  front_ = scc_[s];
526  }
527  if ((*queue_)[scc_[s]]) {
528  (*queue_)[scc_[s]]->Enqueue(s);
529  } else {
530  while (trivial_queue_.size() <= scc_[s]) {
531  trivial_queue_.push_back(kNoStateId);
532  }
533  trivial_queue_[scc_[s]] = s;
534  }
535  }
536 
537  void Dequeue() final {
538  if ((*queue_)[front_]) {
539  (*queue_)[front_]->Dequeue();
540  } else if (front_ < trivial_queue_.size()) {
541  trivial_queue_[front_] = kNoStateId;
542  }
543  }
544 
545  void Update(StateId s) final {
546  if ((*queue_)[scc_[s]]) (*queue_)[scc_[s]]->Update(s);
547  }
548 
549  bool Empty() const final {
550  // Queues SCC number back_ is not empty unless back_ == front_.
551  if (front_ < back_) {
552  return false;
553  } else if (front_ > back_) {
554  return true;
555  } else if ((*queue_)[front_]) {
556  return (*queue_)[front_]->Empty();
557  } else {
558  return (front_ >= trivial_queue_.size()) ||
559  (trivial_queue_[front_] == kNoStateId);
560  }
561  }
562 
563  void Clear() final {
564  for (StateId i = front_; i <= back_; ++i) {
565  if ((*queue_)[i]) {
566  (*queue_)[i]->Clear();
567  } else if (i < trivial_queue_.size()) {
568  trivial_queue_[i] = kNoStateId;
569  }
570  }
571  front_ = 0;
572  back_ = kNoStateId;
573  }
574 
575  private:
576  std::vector<std::unique_ptr<Queue>> *queue_;
577  const std::vector<StateId> &scc_;
578  mutable StateId front_;
579  StateId back_;
580  std::vector<StateId> trivial_queue_;
581 };
582 
583 // Automatic queue discipline. It selects a queue discipline for a given FST
584 // based on its properties.
585 template <class S>
586 class AutoQueue : public QueueBase<S> {
587  public:
588  using StateId = S;
589 
590  // This constructor takes a state distance vector that, if non-null and if
591  // the Weight type has the path property, will entertain the shortest-first
592  // queue using the natural order w.r.t to the distance.
593  template <class Arc, class ArcFilter>
595  const std::vector<typename Arc::Weight> *distance, ArcFilter filter)
597  using Weight = typename Arc::Weight;
598  using Less = NaturalLess<Weight>;
600  // First checks if the FST is known to have these properties.
601  const auto props =
603  if ((props & kTopSorted) || fst.Start() == kNoStateId) {
604  queue_.reset(new StateOrderQueue<StateId>());
605  VLOG(2) << "AutoQueue: using state-order discipline";
606  } else if (props & kAcyclic) {
607  queue_.reset(new TopOrderQueue<StateId>(fst, filter));
608  VLOG(2) << "AutoQueue: using top-order discipline";
609  } else if ((props & kUnweighted) && (Weight::Properties() & kIdempotent)) {
610  queue_.reset(new LifoQueue<StateId>());
611  VLOG(2) << "AutoQueue: using LIFO discipline";
612  } else {
613  uint64 properties;
614  // Decomposes into strongly-connected components.
615  SccVisitor<Arc> scc_visitor(&scc_, nullptr, nullptr, &properties);
616  DfsVisit(fst, &scc_visitor, filter);
617  auto nscc = *std::max_element(scc_.begin(), scc_.end()) + 1;
618  std::vector<QueueType> queue_types(nscc);
619  std::unique_ptr<Less> less;
620  std::unique_ptr<Compare> comp;
621  if (distance && (Weight::Properties() & kPath) == kPath) {
622  less.reset(new Less);
623  comp.reset(new Compare(*distance, *less));
624  }
625  // Finds the queue type to use per SCC.
626  bool unweighted;
627  bool all_trivial;
628  SccQueueType(fst, scc_, &queue_types, filter, less.get(), &all_trivial,
629  &unweighted);
630  // If unweighted and semiring is idempotent, uses LIFO queue.
631  if (unweighted) {
632  queue_.reset(new LifoQueue<StateId>());
633  VLOG(2) << "AutoQueue: using LIFO discipline";
634  return;
635  }
636  // If all the SCC are trivial, the FST is acyclic and the scc number gives
637  // the topological order.
638  if (all_trivial) {
639  queue_.reset(new TopOrderQueue<StateId>(scc_));
640  VLOG(2) << "AutoQueue: using top-order discipline";
641  return;
642  }
643  VLOG(2) << "AutoQueue: using SCC meta-discipline";
644  queues_.resize(nscc);
645  for (StateId i = 0; i < nscc; ++i) {
646  switch (queue_types[i]) {
647  case TRIVIAL_QUEUE:
648  queues_[i].reset();
649  VLOG(3) << "AutoQueue: SCC #" << i << ": using trivial discipline";
650  break;
652  queues_[i].reset(
654  VLOG(3) << "AutoQueue: SCC #" << i
655  << ": using shortest-first discipline";
656  break;
657  case LIFO_QUEUE:
658  queues_[i].reset(new LifoQueue<StateId>());
659  VLOG(3) << "AutoQueue: SCC #" << i << ": using LIFO discipline";
660  break;
661  case FIFO_QUEUE:
662  default:
663  queues_[i].reset(new FifoQueue<StateId>());
664  VLOG(3) << "AutoQueue: SCC #" << i << ": using FIFO discipine";
665  break;
666  }
667  }
668  queue_.reset(new SccQueue<StateId, QueueBase<StateId>>(scc_, &queues_));
669  }
670  }
671 
672  virtual ~AutoQueue() = default;
673 
674  StateId Head() const final { return queue_->Head(); }
675 
676  void Enqueue(StateId s) final { queue_->Enqueue(s); }
677 
678  void Dequeue() final { queue_->Dequeue(); }
679 
680  void Update(StateId s) final { queue_->Update(s); }
681 
682  bool Empty() const final { return queue_->Empty(); }
683 
684  void Clear() final { queue_->Clear(); }
685 
686  private:
687  template <class Arc, class ArcFilter, class Less>
688  static void SccQueueType(const Fst<Arc> &fst, const std::vector<StateId> &scc,
689  std::vector<QueueType> *queue_types,
690  ArcFilter filter, Less *less, bool *all_trivial,
691  bool *unweighted);
692 
693  std::unique_ptr<QueueBase<StateId>> queue_;
694  std::vector<std::unique_ptr<QueueBase<StateId>>> queues_;
695  std::vector<StateId> scc_;
696 };
697 
698 // Examines the states in an FST's strongly connected components and determines
699 // which type of queue to use per SCC. Stores result as a vector of QueueTypes
700 // which is assumed to have length equal to the number of SCCs. An arc filter
701 // is used to limit the transitions considered (e.g., only the epsilon graph).
702 // The argument all_trivial is set to true if every queue is the trivial queue.
703 // The argument unweighted is set to true if the semiring is idempotent and all
704 // the arc weights are equal to Zero() or One().
705 template <class StateId>
706 template <class Arc, class ArcFilter, class Less>
708  const std::vector<StateId> &scc,
709  std::vector<QueueType> *queue_type,
710  ArcFilter filter, Less *less,
711  bool *all_trivial, bool *unweighted) {
712  using StateId = typename Arc::StateId;
713  using Weight = typename Arc::Weight;
714  *all_trivial = true;
715  *unweighted = true;
716  for (StateId i = 0; i < queue_type->size(); ++i) {
717  (*queue_type)[i] = TRIVIAL_QUEUE;
718  }
719  for (StateIterator<Fst<Arc>> sit(fst); !sit.Done(); sit.Next()) {
720  const auto state = sit.Value();
721  for (ArcIterator<Fst<Arc>> ait(fst, state); !ait.Done(); ait.Next()) {
722  const auto &arc = ait.Value();
723  if (!filter(arc)) continue;
724  if (scc[state] == scc[arc.nextstate]) {
725  auto &type = (*queue_type)[scc[state]];
726  if (!less || ((*less)(arc.weight, Weight::One()))) {
727  type = FIFO_QUEUE;
728  } else if ((type == TRIVIAL_QUEUE) || (type == LIFO_QUEUE)) {
729  if (!(Weight::Properties() & kIdempotent) ||
730  (arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
731  type = SHORTEST_FIRST_QUEUE;
732  } else {
733  type = LIFO_QUEUE;
734  }
735  }
736  if (type != TRIVIAL_QUEUE) *all_trivial = false;
737  }
738  if (!(Weight::Properties() & kIdempotent) ||
739  (arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
740  *unweighted = false;
741  }
742  }
743  }
744 }
745 
746 // An A* estimate is a function object that maps from a state ID to an
747 // estimate of the shortest distance to the final states.
748 
// A trivial A* estimate, yielding a queue which behaves the same in Dijkstra's
// algorithm. It returns Weight::One() for every state.
template <typename StateId, typename Weight>
struct TrivialAStarEstimate {
  constexpr Weight operator()(StateId) const { return Weight::One(); }
};
755 
// A non-trivial A* estimate using a vector of the estimated future costs.
// The vector is held by reference and must outlive this object.
template <typename StateId, typename Weight>
class NaturalAStarEstimate {
 public:
  NaturalAStarEstimate(const std::vector<Weight> &beta) : beta_(beta) {}

  // Returns the estimate stored for s, or Weight::Zero() when s lies outside
  // the vector.
  const Weight &operator()(StateId s) const {
    return (s < beta_.size()) ? beta_[s] : kZero;
  }

 private:
  static constexpr Weight kZero = Weight::Zero();

  const std::vector<Weight> &beta_;
};

// Namespace-scope definition of the static member, needed because
// operator() returns a reference to it.
template <typename Arc, typename Weight>
constexpr Weight NaturalAStarEstimate<Arc, Weight>::kZero;
774 
// Given a vector that maps from states to weights representing the shortest
// distance from the initial state, a comparison function object between
// weights, and an estimate of the shortest distance to the final states, this
// class defines a comparison function object between states. All three
// constructor arguments are held by reference and must outlive this object
// and any copy of it.
template <typename S, typename Less, typename Estimate>
class AStarWeightCompare {
 public:
  using StateId = S;
  using Weight = typename Less::Weight;

  AStarWeightCompare(const std::vector<Weight> &weights, const Less &less,
                     const Estimate &estimate)
      : weights_(weights), less_(less), estimate_(estimate) {}

  // Compares the two states by Times(distance so far, estimate).
  bool operator()(StateId s1, StateId s2) const {
    const auto w1 = Times(weights_[s1], estimate_(s1));
    const auto w2 = Times(weights_[s2], estimate_(s2));
    return less_(w1, w2);
  }

  const Estimate &GetEstimate() const { return estimate_; }

 private:
  // Borrowed references.
  const std::vector<Weight> &weights_;
  const Less &less_;
  const Estimate &estimate_;
};
802 
803 // A* queue discipline templated on StateId, Weight, and Estimate.
804 template <typename S, typename Weight, typename Estimate>
806  S, AStarWeightCompare<S, NaturalLess<Weight>, Estimate>> {
807  public:
808  using StateId = S;
810 
811  NaturalAStarQueue(const std::vector<Weight> &distance,
812  const Estimate &estimate)
814  Compare(distance, less_, estimate)) {}
815 
816  ~NaturalAStarQueue() = default;
817 
818  private:
819  // This is non-static because the constructor for non-idempotent weights will
820  // result in an error.
821  const NaturalLess<Weight> less_{};
822 };
823 
// A state equivalence class is a function object that maps from a state ID to
// an equivalence class (state) ID. The trivial equivalence class maps a state
// ID to itself.
template <typename StateId>
struct TrivialStateEquivClass {
  StateId operator()(StateId s) const { return s; }
};
831 
832 // Distance-based pruning queue discipline: Enqueues a state only when its
833 // shortest distance (so far), as specified by distance, is less than (as
834 // specified by comp) the shortest distance Times() the threshold to any state
835 // in the same equivalence class, as specified by the functor class_func. The
836 // underlying queue discipline is specified by queue. The ownership of queue is
837 // given to this class.
838 //
839 // This is not a final class.
840 template <typename Queue, typename Less, typename ClassFnc>
841 class PruneQueue : public QueueBase<typename Queue::StateId> {
842  public:
843  using StateId = typename Queue::StateId;
844  using Weight = typename Less::Weight;
845 
846  PruneQueue(const std::vector<Weight> &distance, Queue *queue,
847  const Less &less, const ClassFnc &class_fnc, Weight threshold)
849  distance_(distance),
850  queue_(queue),
851  less_(less),
852  class_fnc_(class_fnc),
853  threshold_(std::move(threshold)) {}
854 
855  virtual ~PruneQueue() = default;
856 
857  StateId Head() const override { return queue_->Head(); }
858 
859  void Enqueue(StateId s) override {
860  const auto c = class_fnc_(s);
861  if (c >= class_distance_.size()) {
862  class_distance_.resize(c + 1, Weight::Zero());
863  }
864  if (less_(distance_[s], class_distance_[c])) {
865  class_distance_[c] = distance_[s];
866  }
867  // Enqueues only if below threshold limit.
868  const auto limit = Times(class_distance_[c], threshold_);
869  if (less_(distance_[s], limit)) queue_->Enqueue(s);
870  }
871 
872  void Dequeue() override { queue_->Dequeue(); }
873 
874  void Update(StateId s) override {
875  const auto c = class_fnc_(s);
876  if (less_(distance_[s], class_distance_[c])) {
877  class_distance_[c] = distance_[s];
878  }
879  queue_->Update(s);
880  }
881 
882  bool Empty() const override { return queue_->Empty(); }
883 
884  void Clear() override { queue_->Clear(); }
885 
886  private:
887  const std::vector<Weight> &distance_; // Shortest distance to state.
888  std::unique_ptr<Queue> queue_;
889  const Less &less_; // Borrowed reference.
890  const ClassFnc &class_fnc_; // Equivalence class functor.
891  Weight threshold_; // Pruning weight threshold.
892  std::vector<Weight> class_distance_; // Shortest distance to class.
893 };
894 
895 // Pruning queue discipline (see above) using the weight's natural order for the
896 // comparison function. The ownership of the queue argument is given to this
897 // class.
898 template <typename Queue, typename Weight, typename ClassFnc>
899 class NaturalPruneQueue final
900  : public PruneQueue<Queue, NaturalLess<Weight>, ClassFnc> {
901  public:
902  using StateId = typename Queue::StateId;
903 
904  NaturalPruneQueue(const std::vector<Weight> &distance, Queue *queue,
905  const ClassFnc &class_fnc, Weight threshold)
906  : PruneQueue<Queue, NaturalLess<Weight>, ClassFnc>(
907  distance, queue, NaturalLess<Weight>(), class_fnc, threshold) {}
908 
909  virtual ~NaturalPruneQueue() = default;
910 };
911 
912 // Filter-based pruning queue discipline: enqueues a state only if allowed by
913 // the filter, specified by the state filter functor argument. The underlying
914 // queue discipline is specified by the queue argument. The ownership of the
915 // queue is given to this class.
916 template <typename Queue, typename Filter>
917 class FilterQueue : public QueueBase<typename Queue::StateId> {
918  public:
919  using StateId = typename Queue::StateId;
920 
921  FilterQueue(Queue *queue, const Filter &filter)
922  : QueueBase<StateId>(OTHER_QUEUE), queue_(queue), filter_(filter) {}
923 
924  virtual ~FilterQueue() = default;
925 
926  StateId Head() const final { return queue_->Head(); }
927 
928  // Enqueues only if allowed by state filter.
929  void Enqueue(StateId s) final {
930  if (filter_(s)) queue_->Enqueue(s);
931  }
932 
933  void Dequeue() final { queue_->Dequeue(); }
934 
935  void Update(StateId s) final {}
936 
937  bool Empty() const final { return queue_->Empty(); }
938 
939  void Clear() final { queue_->Clear(); }
940 
941  private:
942  std::unique_ptr<Queue> queue_;
943  const Filter &filter_;
944 };
945 
946 } // namespace fst
947 
948 #endif // FST_QUEUE_H_
virtual bool Empty() const =0
void Clear() override
Definition: queue.h:151
bool Empty() const final
Definition: queue.h:175
bool Empty() const final
Definition: queue.h:121
void Enqueue(StateId s) override
Definition: queue.h:143
StateId Head() const final
Definition: queue.h:402
void Update(StateId s) override
Definition: queue.h:874
void Enqueue(StateId s) override
Definition: queue.h:201
typename NaturalLess< Weight >::Weight Weight
Definition: queue.h:248
void Clear() final
Definition: queue.h:123
StateId Head() const final
Definition: queue.h:167
void SetError(bool error)
Definition: queue.h:81
void Clear() override
Definition: queue.h:229
uint64_t uint64
Definition: types.h:32
void Update(StateId) final
Definition: queue.h:119
QueueType
Definition: queue.h:56
NaturalAStarQueue(const std::vector< Weight > &distance, const Estimate &estimate)
Definition: queue.h:811
PruneQueue(const std::vector< Weight > &distance, Queue *queue, const Less &less, const ClassFnc &class_fnc, Weight threshold)
Definition: queue.h:846
NaturalShortestFirstQueue(const std::vector< Weight > &distance)
Definition: queue.h:275
StateWeightCompare(const std::vector< Weight > &weights, const Less &less)
Definition: queue.h:250
bool Empty() const final
Definition: queue.h:470
StateId Head() const final
Definition: queue.h:113
typename Queue::StateId StateId
Definition: queue.h:919
void Enqueue(StateId s) final
Definition: queue.h:519
void Clear() final
Definition: queue.h:472
void Enqueue(StateId s) final
Definition: queue.h:115
constexpr uint64 kTopSorted
Definition: properties.h:100
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
Definition: dfs-visit.h:94
ShortestFirstQueue(Compare comp)
Definition: queue.h:194
void Enqueue(StateId s) final
Definition: queue.h:676
void Enqueue(StateId s) final
Definition: queue.h:451
ExpectationWeight< X1, X2 > Times(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
StateId Head() const override
Definition: queue.h:199
void Clear() final
Definition: queue.h:563
StateId Head() const override
Definition: queue.h:324
void Update(StateId) override
Definition: queue.h:147
bool operator()(StateId s1, StateId s2) const
Definition: queue.h:789
typename NaturalLess< Weight >::Weight Weight
Definition: queue.h:844
void Update(StateId) final
Definition: queue.h:173
bool Empty() const override
Definition: queue.h:149
TopOrderQueue(const Fst< Arc > &fst, ArcFilter filter)
Definition: queue.h:376
const Weight & operator()(StateId s) const
Definition: queue.h:762
NaturalAStarEstimate(const std::vector< Weight > &beta)
Definition: queue.h:760
void Dequeue() final
Definition: queue.h:117
bool Error() const
Definition: queue.h:83
StateId Head() const override
Definition: queue.h:857
constexpr int kNoStateId
Definition: fst.h:180
virtual void Dequeue()=0
void Dequeue() final
Definition: queue.h:463
QueueBase(QueueType type)
Definition: queue.h:79
virtual uint64 Properties(uint64 mask, bool test) const =0
void Update(StateId) final
Definition: queue.h:420
StateId Head() const final
Definition: queue.h:504
#define FSTERROR()
Definition: util.h:35
void Enqueue(StateId s) override
Definition: queue.h:859
constexpr uint64 kUnweighted
Definition: properties.h:87
void Update(StateId s) final
Definition: queue.h:935
StateId operator()(StateId s) const
Definition: queue.h:829
bool Empty() const final
Definition: queue.h:549
virtual void Update(StateId)=0
void Update(StateId s) final
Definition: queue.h:680
void Enqueue(StateId s) final
Definition: queue.h:169
virtual void Clear()=0
virtual void Enqueue(StateId)=0
constexpr Weight operator()(StateId) const
Definition: queue.h:753
typename NaturalLess< Weight >::Weight Weight
Definition: queue.h:783
constexpr uint64 kIdempotent
Definition: weight.h:123
AutoQueue(const Fst< Arc > &fst, const std::vector< typename Arc::Weight > *distance, ArcFilter filter)
Definition: queue.h:594
#define VLOG(level)
Definition: log.h:49
virtual StateId Start() const =0
void Dequeue() final
Definition: queue.h:415
NaturalPruneQueue(const std::vector< Weight > &distance, Queue *queue, const ClassFnc &class_fnc, Weight threshold)
Definition: queue.h:904
void Dequeue() final
Definition: queue.h:933
void Dequeue() override
Definition: queue.h:210
const Estimate & GetEstimate() const
Definition: queue.h:795
void Clear() final
Definition: queue.h:424
bool Empty() const override
Definition: queue.h:882
void Dequeue() override
Definition: queue.h:145
TopOrderQueue(const std::vector< StateId > &order)
Definition: queue.h:393
AStarWeightCompare(const std::vector< Weight > &weights, const Less &less, const Estimate &estimate)
Definition: queue.h:785
StateId Head() const final
Definition: queue.h:926
virtual ~QueueBase()
Definition: queue.h:75
void Dequeue() final
Definition: queue.h:537
StateId Head() const final
Definition: queue.h:449
StateId Head() const final
Definition: queue.h:674
void Enqueue(StateId s) final
Definition: queue.h:404
void Dequeue() final
Definition: queue.h:171
void Update(StateId s) override
Definition: queue.h:218
void Clear() final
Definition: queue.h:177
void Update(StateId) final
Definition: queue.h:468
constexpr uint64 kAcyclic
Definition: properties.h:92
void Enqueue(StateId s) override
Definition: queue.h:335
SccQueue(const std::vector< StateId > &scc, std::vector< std::unique_ptr< Queue >> *queue)
Definition: queue.h:494
constexpr uint64 kPath
Definition: weight.h:126
bool Empty() const override
Definition: queue.h:227
void Update(StateId s) final
Definition: queue.h:545
FilterQueue(Queue *queue, const Filter &filter)
Definition: queue.h:921
PruneNaturalShortestFirstQueue(const std::vector< Weight > &distance, int threshold)
Definition: queue.h:315
QueueType Type() const
Definition: queue.h:85
Queue::StateId StateId
Definition: queue.h:73
bool operator()(const StateId s1, const StateId s2) const
Definition: queue.h:253
void Enqueue(StateId s) final
Definition: queue.h:929
void Dequeue() override
Definition: queue.h:872
bool Empty() const final
Definition: queue.h:937
void Clear() final
Definition: queue.h:684
virtual StateId Head() const =0
typename Queue::StateId StateId
Definition: queue.h:902
void Clear() override
Definition: queue.h:884
void Clear() final
Definition: queue.h:939
void Dequeue() final
Definition: queue.h:678
StateId Head() const override
Definition: queue.h:141
bool Empty() const final
Definition: queue.h:422
bool Empty() const final
Definition: queue.h:682
constexpr uint64 kCyclic
Definition: properties.h:90
const Compare & GetCompare() const
Definition: queue.h:234