FST  openfst-1.8.4
OpenFst Library
queue.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Functions and classes for various FST state queues with a unified interface.
19 
20 #ifndef FST_QUEUE_H_
21 #define FST_QUEUE_H_
22 
23 #include <sys/types.h>
24 
25 #include <algorithm>
26 #include <cstdint>
27 #include <memory>
28 #include <queue>
29 #include <type_traits>
30 #include <utility>
31 #include <vector>
32 
33 #include <fst/log.h>
34 #include <fst/arcfilter.h>
35 #include <fst/cc-visitors.h>
36 #include <fst/dfs-visit.h>
37 #include <fst/fst.h>
38 #include <fst/heap.h>
39 #include <fst/properties.h>
40 #include <fst/topsort.h>
41 #include <fst/util.h>
42 #include <fst/weight.h>
43 
44 namespace fst {
45 
46 // The Queue interface is:
47 //
48 // template <class S>
49 // class Queue {
50 // public:
51 // using StateId = S;
52 //
53 // // Constructor: may need args (e.g., FST, comparator) for some queues.
54 // Queue(...);
55 //
56 // // Returns the head of the queue.
57 // StateId Head() const override;
58 //
59 // // Inserts a state.
60 // void Enqueue(StateId s) override;
61 //
62 // // Removes the head of the queue.
63 // void Dequeue() override;
64 //
65 // // Updates ordering of state s when weight changes, if necessary.
66 // void Update(StateId s) override;
67 //
68 // // Is the queue empty?
69 // bool Empty() const override;
70 //
71 // // Removes all states from the queue.
72 // void Clear() override;
73 // };
74 
// State queue types. Each concrete queue passes its type to the QueueBase
// constructor, and QueueBase::Type() reports it back.
enum QueueType {
  TRIVIAL_QUEUE = 0,         // Single state queue.
  FIFO_QUEUE = 1,            // First-in, first-out queue.
  LIFO_QUEUE = 2,            // Last-in, first-out queue.
  SHORTEST_FIRST_QUEUE = 3,  // Shortest-first queue.
  TOP_ORDER_QUEUE = 4,       // Topologically-ordered queue.
  STATE_ORDER_QUEUE = 5,     // State ID-ordered queue.
  SCC_QUEUE = 6,             // Component graph top-ordered meta-queue.
  AUTO_QUEUE = 7,            // Auto-selected queue.
  OTHER_QUEUE = 8            // Other-specified queue.
};
87 
88 // QueueBase, templated on the StateId, is a virtual base class shared by all
89 // queues considered by AutoQueue.
90 template <class S>
91 class QueueBase {
92  public:
93  using StateId = S;
94 
95  virtual ~QueueBase() = default;
96 
97  // Concrete implementation.
98 
99  explicit QueueBase(QueueType type) : queue_type_(type), error_(false) {}
100 
101  void SetError(bool error) { error_ = error; }
102 
103  bool Error() const { return error_; }
104 
105  QueueType Type() const { return queue_type_; }
106 
107  // Virtual interface.
108 
109  virtual StateId Head() const = 0;
110  virtual void Enqueue(StateId) = 0;
111  virtual void Dequeue() = 0;
112  virtual void Update(StateId) = 0;
113  virtual bool Empty() const = 0;
114  virtual void Clear() = 0;
115 
116  private:
117  QueueType queue_type_;
118  bool error_;
119 };
120 
121 // Trivial queue discipline; one may enqueue at most one state at a time. It
122 // can be used for strongly connected components with only one state and no
123 // self-loops.
124 template <class S>
125 class TrivialQueue : public QueueBase<S> {
126  public:
127  using StateId = S;
128 
130 
131  ~TrivialQueue() override = default;
132 
133  StateId Head() const final { return front_; }
134 
135  void Enqueue(StateId s) final { front_ = s; }
136 
137  void Dequeue() final { front_ = kNoStateId; }
138 
139  void Update(StateId) final {}
140 
141  bool Empty() const final { return front_ == kNoStateId; }
142 
143  void Clear() final { front_ = kNoStateId; }
144 
145  private:
146  StateId front_;
147 };
148 
149 // First-in, first-out queue discipline.
150 //
151 // This is not a final class.
152 template <class S>
153 class FifoQueue : public QueueBase<S> {
154  public:
155  using StateId = S;
156 
158 
159  ~FifoQueue() override = default;
160 
161  StateId Head() const override { return queue_.front(); }
162 
163  void Enqueue(StateId s) override { queue_.push(s); }
164 
165  void Dequeue() override { queue_.pop(); }
166 
167  void Update(StateId) override {}
168 
169  bool Empty() const override { return queue_.empty(); }
170 
171  void Clear() override { queue_ = std::queue<StateId>(); }
172 
173  private:
174  std::queue<StateId> queue_;
175 };
176 
177 // Last-in, first-out queue discipline.
178 template <class S>
179 class LifoQueue : public QueueBase<S> {
180  public:
181  using StateId = S;
182 
184 
185  ~LifoQueue() override = default;
186 
187  StateId Head() const final { return stack_.back(); }
188 
189  void Enqueue(StateId s) final { stack_.push_back(s); }
190 
191  void Dequeue() final { stack_.pop_back(); }
192 
193  void Update(StateId) final {}
194 
195  bool Empty() const final { return stack_.empty(); }
196 
197  void Clear() final { stack_.clear(); }
198 
199  private:
200  std::vector<StateId> stack_;
201 };
202 
203 // Shortest-first queue discipline, templated on the StateId and as well as a
204 // comparison functor used to compare two StateIds. If a (single) state's order
205 // changes, it can be reordered in the queue with a call to Update(). If update
206 // is false, call to Update() does not reorder the queue.
207 //
208 // This is not a final class.
209 template <typename S, typename Compare, bool update = true>
210 class ShortestFirstQueue : public QueueBase<S> {
211  public:
212  using StateId = S;
213 
214  explicit ShortestFirstQueue(Compare comp)
215  : QueueBase<StateId>(SHORTEST_FIRST_QUEUE), heap_(comp) {}
216 
217  ~ShortestFirstQueue() override = default;
218 
219  StateId Head() const override { return heap_.Top(); }
220 
221  void Enqueue(StateId s) override {
222  if (update) {
223  for (StateId i = key_.size(); i <= s; ++i) key_.push_back(kNoStateId);
224  key_[s] = heap_.Insert(s);
225  } else {
226  heap_.Insert(s);
227  }
228  }
229 
230  void Dequeue() override {
231  if (update) {
232  key_[heap_.Pop()] = kNoStateId;
233  } else {
234  heap_.Pop();
235  }
236  }
237 
238  void Update(StateId s) override {
239  if (!update) return;
240  if (s >= key_.size() || key_[s] == kNoStateId) {
241  Enqueue(s);
242  } else {
243  heap_.Update(key_[s], s);
244  }
245  }
246 
247  bool Empty() const override { return heap_.Empty(); }
248 
249  void Clear() override {
250  heap_.Clear();
251  if (update) key_.clear();
252  }
253 
254  ssize_t Size() const { return heap_.Size(); }
255 
256  const Compare &GetCompare() const { return heap_.GetCompare(); }
257 
258  private:
260  std::vector<ssize_t> key_;
261 };
262 
263 namespace internal {
264 
265 // Given a vector that maps from states to weights, and a comparison functor
266 // for weights, this class defines a comparison function object between states.
267 template <typename StateId, typename Less>
269  public:
270  using Weight = typename Less::Weight;
271 
272  StateWeightCompare(const std::vector<Weight> &weights, const Less &less)
273  : weights_(weights), less_(less) {}
274 
275  bool operator()(const StateId s1, const StateId s2) const {
276  return less_(weights_[s1], weights_[s2]);
277  }
278 
279  private:
280  // Borrowed references.
281  const std::vector<Weight> &weights_;
282  const Less &less_;
283 };
284 
285 // Comparison that can never be instantiated. Useful only to pass a pointer to
286 // this to a function that needs a comparison when it is known that the pointer
287 // will always be null.
288 template <class W>
289 struct ErrorLess {
290  using Weight = W;
292  FSTERROR() << "ErrorLess: instantiated for Weight " << Weight::Type();
293  }
294  bool operator()(const Weight &, const Weight &) const { return false; }
295 };
296 
297 } // namespace internal
298 
299 // Shortest-first queue discipline, templated on the StateId and Weight, is
300 // specialized to use the weight's provided order for the comparison function.
301 // See NaturalShortestFirstQueue for the specialization using natural order.
302 template <typename S, typename Weight, typename Less>
304  : public ShortestFirstQueue<S, internal::StateWeightCompare<S, Less>> {
305  public:
306  using StateId = S;
308 
309  explicit CustomShortestFirstQueue(const std::vector<Weight> &distance)
310  : ShortestFirstQueue<StateId, Compare>(Compare(distance, Less())) {}
311 
312  ~CustomShortestFirstQueue() override = default;
313 };
314 
315 // Shortest-first queue discipline using the weight's natural order.
316 // Requires Weight is idempotent (due to use of NaturalLess).
317 template <typename S, typename Weight>
320 
321 // In a shortest path computation on a lattice-like FST, we may keep many old
322 // nonviable paths as a part of the search. Since the search process always
323 // expands the lowest cost path next, that lowest cost path may be a very old
324 // nonviable path instead of one we expect to lead to a shortest path.
325 //
326 // For instance, suppose that the current best path in an alignment has
327 // traversed 500 arcs with a cost of 10. We may also have a bad path in
328 // the queue that has traversed only 40 arcs but also has a cost of 10.
329 // This path is very unlikely to lead to a reasonable alignment, so this queue
330 // can prune it from the search space.
331 //
332 // This queue relies on the caller using a shortest-first exploration order
333 // like this:
334 // while (true) {
335 // StateId head = queue.Head();
336 // queue.Dequeue();
337 // for (const auto& arc : GetArcs(fst, head)) {
338 // queue.Enqueue(arc.nextstate);
339 // }
340 // }
341 // We use this assumption to guess that there is an arc between Head and the
342 // Enqueued state; this is how the number of path steps is measured.
343 template <typename S, typename Weight>
345  : public NaturalShortestFirstQueue<S, Weight> {
346  public:
347  using StateId = S;
349 
350  PruneNaturalShortestFirstQueue(const std::vector<Weight> &distance,
351  ssize_t arc_threshold, ssize_t state_limit = 0)
352  : Base(distance),
353  arc_threshold_(arc_threshold),
354  state_limit_(state_limit),
355  head_steps_(0),
356  max_head_steps_(0) {}
357 
358  ~PruneNaturalShortestFirstQueue() override = default;
359 
360  StateId Head() const override {
361  const auto head = Base::Head();
362  // Stores the number of steps from the start of the graph to this state
363  // along the shortest-weight path.
364  if (head < steps_.size()) {
365  max_head_steps_ = std::max(steps_[head], max_head_steps_);
366  head_steps_ = steps_[head];
367  }
368  return head;
369  }
370 
371  void Enqueue(StateId s) override {
372  // We assume that there is an arc between the Head() state and this
373  // Enqueued state.
374  const ssize_t state_steps = head_steps_ + 1;
375  if (s >= steps_.size()) {
376  steps_.resize(s + 1, state_steps);
377  }
378  // This is the number of arcs in the minimum cost path from Start to s.
379  steps_[s] = state_steps;
380 
381  // Adjust the threshold in cases where path step thresholding wasn't
382  // enough to keep the queue small.
383  ssize_t adjusted_threshold = arc_threshold_;
384  if (Base::Size() > state_limit_ && state_limit_ > 0) {
385  adjusted_threshold = std::max<ssize_t>(
386  0, arc_threshold_ - (Base::Size() / state_limit_) - 1);
387  }
388 
389  if (state_steps > (max_head_steps_ - adjusted_threshold) ||
390  arc_threshold_ < 0) {
391  if (adjusted_threshold == 0 && state_limit_ > 0) {
392  // If the queue is continuing to grow without bound, we follow any
393  // path that makes progress and clear the rest.
394  Base::Clear();
395  }
396  Base::Enqueue(s);
397  }
398  }
399 
400  private:
401  // A dense map from StateId to the number of arcs in the minimum weight
402  // path from Start to this state.
403  std::vector<ssize_t> steps_;
404  // We only keep paths that are within this number of arcs (not weight!)
405  // of the longest path.
406  const ssize_t arc_threshold_;
407  // If the size of the queue climbs above this number, we increase the
408  // threshold to reduce the amount of work we have to do.
409  const ssize_t state_limit_;
410 
411  // The following are mutable because Head() is const.
412  // The number of arcs traversed in the minimum cost path from the start
413  // state to the current Head() state.
414  mutable ssize_t head_steps_;
415  // The maximum number of arcs traversed by any low-cost path so far.
416  mutable ssize_t max_head_steps_;
417 };
418 
419 // Topological-order queue discipline, templated on the StateId. States are
420 // ordered in the queue topologically. The FST must be acyclic.
421 template <class S>
422 class TopOrderQueue : public QueueBase<S> {
423  public:
424  using StateId = S;
425 
426  // This constructor computes the topological order. It accepts an arc filter
427  // to limit the transitions considered in that computation (e.g., only the
428  // epsilon graph).
429  template <class Arc, class ArcFilter>
430  TopOrderQueue(const Fst<Arc> &fst, ArcFilter filter)
432  front_(0),
433  back_(kNoStateId),
434  order_(0),
435  state_(0) {
436  bool acyclic;
437  TopOrderVisitor<Arc> top_order_visitor(&order_, &acyclic);
438  DfsVisit(fst, &top_order_visitor, filter);
439  if (!acyclic) {
440  FSTERROR() << "TopOrderQueue: FST is not acyclic";
442  }
443  state_.resize(order_.size(), kNoStateId);
444  }
445 
446  // This constructor is passed the pre-computed topological order.
447  explicit TopOrderQueue(const std::vector<StateId> &order)
449  front_(0),
450  back_(kNoStateId),
451  order_(order),
452  state_(order.size(), kNoStateId) {}
453 
454  ~TopOrderQueue() override = default;
455 
456  StateId Head() const final { return state_[front_]; }
457 
458  void Enqueue(StateId s) final {
459  if (front_ > back_) {
460  front_ = back_ = order_[s];
461  } else if (order_[s] > back_) {
462  back_ = order_[s];
463  } else if (order_[s] < front_) {
464  front_ = order_[s];
465  }
466  state_[order_[s]] = s;
467  }
468 
469  void Dequeue() final {
470  state_[front_] = kNoStateId;
471  while ((front_ <= back_) && (state_[front_] == kNoStateId)) ++front_;
472  }
473 
474  void Update(StateId) final {}
475 
476  bool Empty() const final { return front_ > back_; }
477 
478  void Clear() final {
479  for (StateId s = front_; s <= back_; ++s) state_[s] = kNoStateId;
480  back_ = kNoStateId;
481  front_ = 0;
482  }
483 
484  private:
485  StateId front_;
486  StateId back_;
487  std::vector<StateId> order_;
488  std::vector<StateId> state_;
489 };
490 
491 // State order queue discipline, templated on the StateId. States are ordered in
492 // the queue by state ID.
493 template <class S>
494 class StateOrderQueue : public QueueBase<S> {
495  public:
496  using StateId = S;
497 
499  : QueueBase<StateId>(STATE_ORDER_QUEUE), front_(0), back_(kNoStateId) {}
500 
501  ~StateOrderQueue() override = default;
502 
503  StateId Head() const final { return front_; }
504 
505  void Enqueue(StateId s) final {
506  if (front_ > back_) {
507  front_ = back_ = s;
508  } else if (s > back_) {
509  back_ = s;
510  } else if (s < front_) {
511  front_ = s;
512  }
513  while (enqueued_.size() <= s) enqueued_.push_back(false);
514  enqueued_[s] = true;
515  }
516 
517  void Dequeue() final {
518  enqueued_[front_] = false;
519  while ((front_ <= back_) && (enqueued_[front_] == false)) ++front_;
520  }
521 
522  void Update(StateId) final {}
523 
524  bool Empty() const final { return front_ > back_; }
525 
526  void Clear() final {
527  for (StateId i = front_; i <= back_; ++i) enqueued_[i] = false;
528  front_ = 0;
529  back_ = kNoStateId;
530  }
531 
532  private:
533  StateId front_;
534  StateId back_;
535  std::vector<bool> enqueued_;
536 };
537 
538 // SCC topological-order meta-queue discipline, templated on the StateId and a
539 // queue used inside each SCC. It visits the SCCs of an FST in topological
540 // order. Its constructor is passed the queues to to use within an SCC.
541 template <class S, class Queue>
542 class SccQueue : public QueueBase<S> {
543  public:
544  using StateId = S;
545 
546  // Constructor takes a vector specifying the SCC number per state and a
547  // vector giving the queue to use per SCC number.
548  SccQueue(const std::vector<StateId> &scc,
549  std::vector<std::unique_ptr<Queue>> *queue)
551  queue_(queue),
552  scc_(scc),
553  front_(0),
554  back_(kNoStateId) {}
555 
556  ~SccQueue() override = default;
557 
558  StateId Head() const final {
559  while ((front_ <= back_) &&
560  (((*queue_)[front_] && (*queue_)[front_]->Empty()) ||
561  (((*queue_)[front_] == nullptr) &&
562  ((front_ >= trivial_queue_.size()) ||
563  (trivial_queue_[front_] == kNoStateId))))) {
564  ++front_;
565  }
566  if ((*queue_)[front_]) {
567  return (*queue_)[front_]->Head();
568  } else {
569  return trivial_queue_[front_];
570  }
571  }
572 
573  void Enqueue(StateId s) final {
574  if (front_ > back_) {
575  front_ = back_ = scc_[s];
576  } else if (scc_[s] > back_) {
577  back_ = scc_[s];
578  } else if (scc_[s] < front_) {
579  front_ = scc_[s];
580  }
581  if ((*queue_)[scc_[s]]) {
582  (*queue_)[scc_[s]]->Enqueue(s);
583  } else {
584  while (trivial_queue_.size() <= scc_[s]) {
585  trivial_queue_.push_back(kNoStateId);
586  }
587  trivial_queue_[scc_[s]] = s;
588  }
589  }
590 
591  void Dequeue() final {
592  if ((*queue_)[front_]) {
593  (*queue_)[front_]->Dequeue();
594  } else if (front_ < trivial_queue_.size()) {
595  trivial_queue_[front_] = kNoStateId;
596  }
597  }
598 
599  void Update(StateId s) final {
600  if ((*queue_)[scc_[s]]) (*queue_)[scc_[s]]->Update(s);
601  }
602 
603  bool Empty() const final {
604  // Queues SCC number back_ is not empty unless back_ == front_.
605  if (front_ < back_) {
606  return false;
607  } else if (front_ > back_) {
608  return true;
609  } else if ((*queue_)[front_]) {
610  return (*queue_)[front_]->Empty();
611  } else {
612  return (front_ >= trivial_queue_.size()) ||
613  (trivial_queue_[front_] == kNoStateId);
614  }
615  }
616 
617  void Clear() final {
618  for (StateId i = front_; i <= back_; ++i) {
619  if ((*queue_)[i]) {
620  (*queue_)[i]->Clear();
621  } else if (i < trivial_queue_.size()) {
622  trivial_queue_[i] = kNoStateId;
623  }
624  }
625  front_ = 0;
626  back_ = kNoStateId;
627  }
628 
629  private:
630  std::vector<std::unique_ptr<Queue>> *queue_;
631  const std::vector<StateId> &scc_;
632  mutable StateId front_;
633  StateId back_;
634  std::vector<StateId> trivial_queue_;
635 };
636 
637 // Automatic queue discipline. It selects a queue discipline for a given FST
638 // based on its properties.
639 template <class S>
640 class AutoQueue : public QueueBase<S> {
641  public:
642  using StateId = S;
643 
644  // This constructor takes a state distance vector that, if non-null and if
645  // the Weight type has the path property, will entertain the shortest-first
646  // queue using the natural order w.r.t to the distance.
647  template <class Arc, class ArcFilter>
649  const std::vector<typename Arc::Weight> *distance, ArcFilter filter)
651  using Weight = typename Arc::Weight;
652  // We need to have variables of type Less and Compare, so we use
653  // ErrorLess if the type NaturalLess<Weight> cannot be instantiated due
654  // to lack of path property.
655  using Less = std::conditional_t<IsPath<Weight>::value, NaturalLess<Weight>,
658  // First checks if the FST is known to have these properties.
659  const auto props =
661  if ((props & kTopSorted) || fst.Start() == kNoStateId) {
662  queue_ = std::make_unique<StateOrderQueue<StateId>>();
663  VLOG(2) << "AutoQueue: using state-order discipline";
664  } else if (props & kAcyclic) {
665  queue_ = std::make_unique<TopOrderQueue<StateId>>(fst, filter);
666  VLOG(2) << "AutoQueue: using top-order discipline";
667  } else if ((props & kUnweighted) && IsIdempotent<Weight>::value) {
668  queue_ = std::make_unique<LifoQueue<StateId>>();
669  VLOG(2) << "AutoQueue: using LIFO discipline";
670  } else {
671  uint64_t properties;
672  // Decomposes into strongly-connected components.
673  SccVisitor<Arc> scc_visitor(&scc_, nullptr, nullptr, &properties);
674  DfsVisit(fst, &scc_visitor, filter);
675  auto nscc = *std::max_element(scc_.begin(), scc_.end()) + 1;
676  std::vector<QueueType> queue_types(nscc);
677  std::unique_ptr<Less> less;
678  std::unique_ptr<Compare> comp;
679  if constexpr (IsPath<Weight>::value) {
680  if (distance) {
681  less = std::make_unique<Less>();
682  comp = std::make_unique<Compare>(*distance, *less);
683  }
684  }
685  // Finds the queue type to use per SCC.
686  bool unweighted;
687  bool all_trivial;
688  SccQueueType(fst, scc_, &queue_types, filter, less.get(), &all_trivial,
689  &unweighted);
690  // If unweighted and semiring is idempotent, uses LIFO queue.
691  if (unweighted) {
692  queue_ = std::make_unique<LifoQueue<StateId>>();
693  VLOG(2) << "AutoQueue: using LIFO discipline";
694  return;
695  }
696  // If all the SCC are trivial, the FST is acyclic and the scc number gives
697  // the topological order.
698  if (all_trivial) {
699  queue_ = std::make_unique<TopOrderQueue<StateId>>(scc_);
700  VLOG(2) << "AutoQueue: using top-order discipline";
701  return;
702  }
703  VLOG(2) << "AutoQueue: using SCC meta-discipline";
704  queues_.resize(nscc);
705  for (StateId i = 0; i < nscc; ++i) {
706  switch (queue_types[i]) {
707  case TRIVIAL_QUEUE:
708  queues_[i].reset();
709  VLOG(3) << "AutoQueue: SCC #" << i << ": using trivial discipline";
710  break;
712  // The IsPath test is not needed for correctness. It just saves
713  // instantiating a ShortestFirstQueue that can never be called.
714  if constexpr (IsPath<Weight>::value) {
715  queues_[i] =
716  std::make_unique<ShortestFirstQueue<StateId, Compare, false>>(
717  *comp);
718  VLOG(3) << "AutoQueue: SCC #" << i
719  << ": using shortest-first discipline";
720  } else {
721  // SccQueueType should ensure this can never happen.
722  FSTERROR() << "Got SHORTEST_FIRST_QUEUE for non-Path Weight "
723  << Weight::Type();
724  queues_[i].reset();
725  }
726  break;
727  case LIFO_QUEUE:
728  queues_[i] = std::make_unique<LifoQueue<StateId>>();
729  VLOG(3) << "AutoQueue: SCC #" << i << ": using LIFO discipline";
730  break;
731  case FIFO_QUEUE:
732  default:
733  queues_[i] = std::make_unique<FifoQueue<StateId>>();
734  VLOG(3) << "AutoQueue: SCC #" << i << ": using FIFO discipine";
735  break;
736  }
737  }
738  queue_ = std::make_unique<SccQueue<StateId, QueueBase<StateId>>>(
739  scc_, &queues_);
740  }
741  }
742 
743  ~AutoQueue() override = default;
744 
745  StateId Head() const final { return queue_->Head(); }
746 
747  void Enqueue(StateId s) final { queue_->Enqueue(s); }
748 
749  void Dequeue() final { queue_->Dequeue(); }
750 
751  void Update(StateId s) final { queue_->Update(s); }
752 
753  bool Empty() const final { return queue_->Empty(); }
754 
755  void Clear() final { queue_->Clear(); }
756 
757  private:
758  template <class Arc, class ArcFilter, class Less>
759  static void SccQueueType(const Fst<Arc> &fst, const std::vector<StateId> &scc,
760  std::vector<QueueType> *queue_types,
761  ArcFilter filter, Less *less, bool *all_trivial,
762  bool *unweighted);
763 
764  std::unique_ptr<QueueBase<StateId>> queue_;
765  std::vector<std::unique_ptr<QueueBase<StateId>>> queues_;
766  std::vector<StateId> scc_;
767 };
768 
769 // Examines the states in an FST's strongly connected components and determines
770 // which type of queue to use per SCC. Stores result as a vector of QueueTypes
771 // which is assumed to have length equal to the number of SCCs. An arc filter
772 // is used to limit the transitions considered (e.g., only the epsilon graph).
773 // The argument all_trivial is set to true if every queue is the trivial queue.
774 // The argument unweighted is set to true if the semiring is idempotent and all
775 // the arc weights are equal to Zero() or One().
776 template <class StateId>
777 template <class Arc, class ArcFilter, class Less>
779  const std::vector<StateId> &scc,
780  std::vector<QueueType> *queue_type,
781  ArcFilter filter, Less *less,
782  bool *all_trivial, bool *unweighted) {
783  using StateId = typename Arc::StateId;
784  using Weight = typename Arc::Weight;
785  *all_trivial = true;
786  *unweighted = true;
787  for (StateId i = 0; i < queue_type->size(); ++i) {
788  (*queue_type)[i] = TRIVIAL_QUEUE;
789  }
790  for (StateIterator<Fst<Arc>> sit(fst); !sit.Done(); sit.Next()) {
791  const auto state = sit.Value();
792  for (ArcIterator<Fst<Arc>> ait(fst, state); !ait.Done(); ait.Next()) {
793  const auto &arc = ait.Value();
794  if (!filter(arc)) continue;
795  if (scc[state] == scc[arc.nextstate]) {
796  auto &type = (*queue_type)[scc[state]];
797  if constexpr (!IsPath<Weight>::value) {
798  type = FIFO_QUEUE;
799  } else if (!less || (*less)(arc.weight, Weight::One())) {
800  type = FIFO_QUEUE;
801  } else if ((type == TRIVIAL_QUEUE) || (type == LIFO_QUEUE)) {
803  (arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
804  type = SHORTEST_FIRST_QUEUE;
805  } else {
806  type = LIFO_QUEUE;
807  }
808  }
809  if (type != TRIVIAL_QUEUE) *all_trivial = false;
810  }
812  (arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
813  *unweighted = false;
814  }
815  }
816  }
817 }
818 
// An A* estimate is a function object that maps from a state ID to an
// estimate of the shortest distance to the final states.

// A trivial A* estimate, yielding a queue which behaves the same as in
// Dijkstra's algorithm: every state is estimated as Weight::One().
template <typename StateId, typename Weight>
struct TrivialAStarEstimate {
  constexpr Weight operator()(StateId) const { return Weight::One(); }
};
828 
// A non-trivial A* estimate using a vector of the estimated future costs.
// The vector is borrowed and must outlive the estimate. States past its end
// are estimated as Weight::Zero().
template <typename StateId, typename Weight>
class NaturalAStarEstimate {
 public:
  NaturalAStarEstimate(const std::vector<Weight> &beta) : beta_(beta) {}

  const Weight &operator()(StateId s) const {
    return (static_cast<size_t>(s) < beta_.size()) ? beta_[s] : kZero;
  }

 private:
  static constexpr Weight kZero = Weight::Zero();

  const std::vector<Weight> &beta_;
};
844 
// Given a vector that maps from states to weights representing the shortest
// distance from the initial state, a comparison function object between
// weights, and an estimate of the shortest distance to the final states, this
// class defines a comparison function object between states: it returns true
// when Times(weights[s1], estimate(s1)) is less than
// Times(weights[s2], estimate(s2)). All three constructor arguments are
// borrowed references and must outlive this object.
template <typename S, typename Less, typename Estimate>
class AStarWeightCompare {
 public:
  using StateId = S;
  using Weight = typename Less::Weight;

  AStarWeightCompare(const std::vector<Weight> &weights, const Less &less,
                     const Estimate &estimate)
      : weights_(weights), less_(less), estimate_(estimate) {}

  bool operator()(StateId s1, StateId s2) const {
    const auto w1 = Times(weights_[s1], estimate_(s1));
    const auto w2 = Times(weights_[s2], estimate_(s2));
    return less_(w1, w2);
  }

  const Estimate &GetEstimate() const { return estimate_; }

 private:
  const std::vector<Weight> &weights_;
  const Less &less_;
  const Estimate &estimate_;
};
872 
873 // A* queue discipline templated on StateId, Weight, and Estimate.
874 template <typename S, typename Weight, typename Estimate>
876  : public ShortestFirstQueue<
877  S, AStarWeightCompare<S, NaturalLess<Weight>, Estimate>> {
878  public:
879  using StateId = S;
881 
882  NaturalAStarQueue(const std::vector<Weight> &distance,
883  const Estimate &estimate)
885  Compare(distance, less_, estimate)) {}
886 
887  ~NaturalAStarQueue() override = default;
888 
889  private:
890  // This is non-static because the constructor for non-idempotent weights will
891  // result in an error.
892  const NaturalLess<Weight> less_{};
893 };
894 
// A state equivalence class is a function object that maps from a state ID to
// an equivalence class (state) ID. The trivial equivalence class maps a state
// ID to itself.
template <typename StateId>
struct TrivialStateEquivClass {
  StateId operator()(StateId s) const { return s; }
};
902 
903 // Distance-based pruning queue discipline: Enqueues a state only when its
904 // shortest distance (so far), as specified by distance, is less than (as
905 // specified by comp) the shortest distance Times() the threshold to any state
906 // in the same equivalence class, as specified by the functor class_func. The
907 // underlying queue discipline is specified by queue.
908 //
909 // This is not a final class.
910 template <typename Queue, typename Less, typename ClassFnc>
911 class PruneQueue : public QueueBase<typename Queue::StateId> {
912  public:
913  using StateId = typename Queue::StateId;
914  using Weight = typename Less::Weight;
915 
916  PruneQueue(const std::vector<Weight> &distance, std::unique_ptr<Queue> queue,
917  const Less &less, const ClassFnc &class_fnc, Weight threshold)
919  distance_(distance),
920  queue_(std::move(queue)),
921  less_(less),
922  class_fnc_(class_fnc),
923  threshold_(std::move(threshold)) {}
924 
925  ~PruneQueue() override = default;
926 
927  StateId Head() const override { return queue_->Head(); }
928 
929  void Enqueue(StateId s) override {
930  const auto c = class_fnc_(s);
931  if (c >= class_distance_.size()) {
932  class_distance_.resize(c + 1, Weight::Zero());
933  }
934  if (less_(distance_[s], class_distance_[c])) {
935  class_distance_[c] = distance_[s];
936  }
937  // Enqueues only if below threshold limit.
938  const auto limit = Times(class_distance_[c], threshold_);
939  if (less_(distance_[s], limit)) queue_->Enqueue(s);
940  }
941 
942  void Dequeue() override { queue_->Dequeue(); }
943 
944  void Update(StateId s) override {
945  const auto c = class_fnc_(s);
946  if (less_(distance_[s], class_distance_[c])) {
947  class_distance_[c] = distance_[s];
948  }
949  queue_->Update(s);
950  }
951 
952  bool Empty() const override { return queue_->Empty(); }
953 
954  void Clear() override { queue_->Clear(); }
955 
956  private:
957  const std::vector<Weight> &distance_; // Shortest distance to state.
958  std::unique_ptr<Queue> queue_;
959  const Less &less_; // Borrowed reference.
960  const ClassFnc &class_fnc_; // Equivalence class functor.
961  Weight threshold_; // Pruning weight threshold.
962  std::vector<Weight> class_distance_; // Shortest distance to class.
963 };
964 
965 // Pruning queue discipline (see above) using the weight's natural order for the
966 // comparison function. The ownership of the queue argument is given to this
967 // class.
968 template <typename Queue, typename Weight, typename ClassFnc>
969 class NaturalPruneQueue final
970  : public PruneQueue<Queue, NaturalLess<Weight>, ClassFnc> {
971  public:
972  using StateId = typename Queue::StateId;
973 
974  NaturalPruneQueue(const std::vector<Weight> &distance,
975  std::unique_ptr<Queue> queue, const ClassFnc &class_fnc,
976  Weight threshold)
977  : PruneQueue<Queue, NaturalLess<Weight>, ClassFnc>(
978  distance, std::move(queue), NaturalLess<Weight>(), class_fnc,
979  threshold) {}
980 
981  ~NaturalPruneQueue() override = default;
982 };
983 
984 // Filter-based pruning queue discipline: enqueues a state only if allowed by
985 // the filter, specified by the state filter functor argument. The underlying
986 // queue discipline is specified by the queue argument.
987 template <typename Queue, typename Filter>
988 class FilterQueue : public QueueBase<typename Queue::StateId> {
989  public:
990  using StateId = typename Queue::StateId;
991 
992  FilterQueue(std::unique_ptr<Queue> queue, const Filter &filter)
994  queue_(std::move(queue)),
995  filter_(filter) {}
996 
997  ~FilterQueue() override = default;
998 
999  StateId Head() const final { return queue_->Head(); }
1000 
1001  // Enqueues only if allowed by state filter.
1002  void Enqueue(StateId s) final {
1003  if (filter_(s)) queue_->Enqueue(s);
1004  }
1005 
1006  void Dequeue() final { queue_->Dequeue(); }
1007 
1008  void Update(StateId s) final {}
1009 
1010  bool Empty() const final { return queue_->Empty(); }
1011 
1012  void Clear() final { queue_->Clear(); }
1013 
1014  private:
1015  std::unique_ptr<Queue> queue_;
1016  const Filter &filter_;
1017 };
1018 
1019 } // namespace fst
1020 
1021 #endif // FST_QUEUE_H_
virtual bool Empty() const =0
void Clear() override
Definition: queue.h:171
constexpr uint64_t kCyclic
Definition: properties.h:109
bool Empty() const final
Definition: queue.h:195
bool Empty() const final
Definition: queue.h:141
void Enqueue(StateId s) override
Definition: queue.h:163
StateId Head() const final
Definition: queue.h:456
void Update(StateId s) override
Definition: queue.h:944
void Enqueue(StateId s) override
Definition: queue.h:221
void Clear() final
Definition: queue.h:143
StateId Head() const final
Definition: queue.h:187
void SetError(bool error)
Definition: queue.h:101
virtual uint64_t Properties(uint64_t mask, bool test) const =0
void Clear() override
Definition: queue.h:249
FilterQueue(std::unique_ptr< Queue > queue, const Filter &filter)
Definition: queue.h:992
void Update(StateId) final
Definition: queue.h:139
QueueType
Definition: queue.h:76
NaturalAStarQueue(const std::vector< Weight > &distance, const Estimate &estimate)
Definition: queue.h:882
StateWeightCompare(const std::vector< Weight > &weights, const Less &less)
Definition: queue.h:272
bool Empty() const final
Definition: queue.h:524
StateId Head() const final
Definition: queue.h:133
typename Queue::StateId StateId
Definition: queue.h:990
void Enqueue(StateId s) final
Definition: queue.h:573
void Clear() final
Definition: queue.h:526
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:64
virtual ~QueueBase()=default
void Enqueue(StateId s) final
Definition: queue.h:135
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
Definition: dfs-visit.h:112
bool operator()(const Weight &, const Weight &) const
Definition: queue.h:294
ShortestFirstQueue(Compare comp)
Definition: queue.h:214
void Enqueue(StateId s) final
Definition: queue.h:747
void Enqueue(StateId s) final
Definition: queue.h:505
StateId Head() const override
Definition: queue.h:219
void Clear() final
Definition: queue.h:617
StateId Head() const override
Definition: queue.h:360
void Update(StateId) override
Definition: queue.h:167
bool operator()(StateId s1, StateId s2) const
Definition: queue.h:859
typename NaturalLess< Weight >::Weight Weight
Definition: queue.h:914
std::bool_constant<(W::Properties()&kIdempotent)!=0 > IsIdempotent
Definition: weight.h:159
void Update(StateId) final
Definition: queue.h:193
bool Empty() const override
Definition: queue.h:169
TopOrderQueue(const Fst< Arc > &fst, ArcFilter filter)
Definition: queue.h:430
constexpr uint64_t kTopSorted
Definition: properties.h:119
const Weight & operator()(StateId s) const
Definition: queue.h:835
NaturalAStarEstimate(const std::vector< Weight > &beta)
Definition: queue.h:833
void Dequeue() final
Definition: queue.h:137
bool Error() const
Definition: queue.h:103
StateId Head() const override
Definition: queue.h:927
constexpr int kNoStateId
Definition: fst.h:195
virtual void Dequeue()=0
void Dequeue() final
Definition: queue.h:517
QueueBase(QueueType type)
Definition: queue.h:99
void Update(StateId) final
Definition: queue.h:474
StateId Head() const final
Definition: queue.h:558
#define FSTERROR()
Definition: util.h:57
void Enqueue(StateId s) override
Definition: queue.h:929
void Update(StateId s) final
Definition: queue.h:1008
StateId operator()(StateId s) const
Definition: queue.h:900
bool Empty() const final
Definition: queue.h:603
virtual void Update(StateId)=0
void Update(StateId s) final
Definition: queue.h:751
void Enqueue(StateId s) final
Definition: queue.h:189
constexpr uint64_t kAcyclic
Definition: properties.h:111
virtual void Clear()=0
virtual void Enqueue(StateId)=0
constexpr Weight operator()(StateId) const
Definition: queue.h:826
typename NaturalLess< Weight >::Weight Weight
Definition: queue.h:853
AutoQueue(const Fst< Arc > &fst, const std::vector< typename Arc::Weight > *distance, ArcFilter filter)
Definition: queue.h:648
#define VLOG(level)
Definition: log.h:54
virtual StateId Start() const =0
void Dequeue() final
Definition: queue.h:469
void Dequeue() final
Definition: queue.h:1006
void Dequeue() override
Definition: queue.h:230
const Estimate & GetEstimate() const
Definition: queue.h:865
void Clear() final
Definition: queue.h:478
bool Empty() const override
Definition: queue.h:952
void Dequeue() override
Definition: queue.h:165
TopOrderQueue(const std::vector< StateId > &order)
Definition: queue.h:447
AStarWeightCompare(const std::vector< Weight > &weights, const Less &less, const Estimate &estimate)
Definition: queue.h:855
StateId Head() const final
Definition: queue.h:999
void Dequeue() final
Definition: queue.h:591
StateId Head() const final
Definition: queue.h:503
constexpr uint64_t kUnweighted
Definition: properties.h:106
StateId Head() const final
Definition: queue.h:745
void Enqueue(StateId s) final
Definition: queue.h:458
void Dequeue() final
Definition: queue.h:191
void Update(StateId s) override
Definition: queue.h:238
void Clear() final
Definition: queue.h:197
void Update(StateId) final
Definition: queue.h:522
NaturalPruneQueue(const std::vector< Weight > &distance, std::unique_ptr< Queue > queue, const ClassFnc &class_fnc, Weight threshold)
Definition: queue.h:974
CustomShortestFirstQueue(const std::vector< Weight > &distance)
Definition: queue.h:309
void Enqueue(StateId s) override
Definition: queue.h:371
SccQueue(const std::vector< StateId > &scc, std::vector< std::unique_ptr< Queue >> *queue)
Definition: queue.h:548
bool Empty() const override
Definition: queue.h:247
void Update(StateId s) final
Definition: queue.h:599
QueueType Type() const
Definition: queue.h:105
Queue::StateId StateId
Definition: queue.h:93
bool operator()(const StateId s1, const StateId s2) const
Definition: queue.h:275
void Enqueue(StateId s) final
Definition: queue.h:1002
PruneQueue(const std::vector< Weight > &distance, std::unique_ptr< Queue > queue, const Less &less, const ClassFnc &class_fnc, Weight threshold)
Definition: queue.h:916
void Dequeue() override
Definition: queue.h:942
bool Empty() const final
Definition: queue.h:1010
PruneNaturalShortestFirstQueue(const std::vector< Weight > &distance, ssize_t arc_threshold, ssize_t state_limit=0)
Definition: queue.h:350
void Clear() final
Definition: queue.h:755
ssize_t Size() const
Definition: queue.h:254
virtual StateId Head() const =0
typename Queue::StateId StateId
Definition: queue.h:972
void Clear() override
Definition: queue.h:954
void Clear() final
Definition: queue.h:1012
void Dequeue() final
Definition: queue.h:749
StateId Head() const override
Definition: queue.h:161
bool Empty() const final
Definition: queue.h:476
std::bool_constant<(W::Properties()&kPath)!=0 > IsPath
Definition: weight.h:162
bool Empty() const final
Definition: queue.h:753
const Compare & GetCompare() const
Definition: queue.h:256