FST  openfst-1.8.2
OpenFst Library
queue.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Functions and classes for various FST state queues with a unified interface.
19 
20 #ifndef FST_QUEUE_H_
21 #define FST_QUEUE_H_
22 
23 #include <cstdint>
24 #include <memory>
25 #include <queue>
26 #include <type_traits>
27 #include <utility>
28 #include <vector>
29 
30 #include <fst/log.h>
31 
32 #include <fst/arcfilter.h>
33 #include <fst/connect.h>
34 #include <fst/heap.h>
35 #include <fst/topsort.h>
36 #include <fst/weight.h>
37 
38 namespace fst {
39 
40 // The Queue interface is:
41 //
42 // template <class S>
43 // class Queue {
44 // public:
45 // using StateId = S;
46 //
47 // // Constructor: may need args (e.g., FST, comparator) for some queues.
48 // Queue(...);
49 //
50 // // Returns the head of the queue.
51 // StateId Head() const override;
52 //
53 // // Inserts a state.
54 // void Enqueue(StateId s) override;
55 //
56 // // Removes the head of the queue.
57 // void Dequeue() override;
58 //
59 // // Updates ordering of state s when weight changes, if necessary.
60 // void Update(StateId s) override;
61 //
62 // // Is the queue empty?
63 // bool Empty() const override;
64 //
65 // // Removes all states from the queue.
66 // void Clear() override;
67 // };
68 
// State queue types. Each queue passes one of these tags to QueueBase, which
// reports it back through QueueBase::Type().
enum QueueType {
  TRIVIAL_QUEUE = 0,         // Single state queue.
  FIFO_QUEUE = 1,            // First-in, first-out queue.
  LIFO_QUEUE = 2,            // Last-in, first-out queue.
  SHORTEST_FIRST_QUEUE = 3,  // Shortest-first queue.
  TOP_ORDER_QUEUE = 4,       // Topologically-ordered queue.
  STATE_ORDER_QUEUE = 5,     // State ID-ordered queue.
  SCC_QUEUE = 6,             // Component graph top-ordered meta-queue.
  AUTO_QUEUE = 7,            // Auto-selected queue.
  OTHER_QUEUE = 8            // Any queue discipline not listed above.
};
81 
82 // QueueBase, templated on the StateId, is a virtual base class shared by all
83 // queues considered by AutoQueue.
84 template <class S>
85 class QueueBase {
86  public:
87  using StateId = S;
88 
89  virtual ~QueueBase() {}
90 
91  // Concrete implementation.
92 
93  explicit QueueBase(QueueType type) : queue_type_(type), error_(false) {}
94 
95  void SetError(bool error) { error_ = error; }
96 
97  bool Error() const { return error_; }
98 
99  QueueType Type() const { return queue_type_; }
100 
101  // Virtual interface.
102 
103  virtual StateId Head() const = 0;
104  virtual void Enqueue(StateId) = 0;
105  virtual void Dequeue() = 0;
106  virtual void Update(StateId) = 0;
107  virtual bool Empty() const = 0;
108  virtual void Clear() = 0;
109 
110  private:
111  QueueType queue_type_;
112  bool error_;
113 };
114 
115 // Trivial queue discipline; one may enqueue at most one state at a time. It
116 // can be used for strongly connected components with only one state and no
117 // self-loops.
118 template <class S>
119 class TrivialQueue : public QueueBase<S> {
120  public:
121  using StateId = S;
122 
124 
125  ~TrivialQueue() override = default;
126 
127  StateId Head() const final { return front_; }
128 
129  void Enqueue(StateId s) final { front_ = s; }
130 
131  void Dequeue() final { front_ = kNoStateId; }
132 
133  void Update(StateId) final {}
134 
135  bool Empty() const final { return front_ == kNoStateId; }
136 
137  void Clear() final { front_ = kNoStateId; }
138 
139  private:
140  StateId front_;
141 };
142 
143 // First-in, first-out queue discipline.
144 //
145 // This is not a final class.
146 template <class S>
147 class FifoQueue : public QueueBase<S> {
148  public:
149  using StateId = S;
150 
152 
153  ~FifoQueue() override = default;
154 
155  StateId Head() const override { return queue_.front(); }
156 
157  void Enqueue(StateId s) override { queue_.push(s); }
158 
159  void Dequeue() override { queue_.pop(); }
160 
161  void Update(StateId) override {}
162 
163  bool Empty() const override { return queue_.empty(); }
164 
165  void Clear() override { queue_ = std::queue<StateId>(); }
166 
167  private:
168  std::queue<StateId> queue_;
169 };
170 
171 // Last-in, first-out queue discipline.
172 template <class S>
173 class LifoQueue : public QueueBase<S> {
174  public:
175  using StateId = S;
176 
178 
179  ~LifoQueue() override = default;
180 
181  StateId Head() const final { return stack_.back(); }
182 
183  void Enqueue(StateId s) final { stack_.push_back(s); }
184 
185  void Dequeue() final { stack_.pop_back(); }
186 
187  void Update(StateId) final {}
188 
189  bool Empty() const final { return stack_.empty(); }
190 
191  void Clear() final { stack_.clear(); }
192 
193  private:
194  std::vector<StateId> stack_;
195 };
196 
197 // Shortest-first queue discipline, templated on the StateId and as well as a
198 // comparison functor used to compare two StateIds. If a (single) state's order
199 // changes, it can be reordered in the queue with a call to Update(). If update
200 // is false, call to Update() does not reorder the queue.
201 //
202 // This is not a final class.
203 template <typename S, typename Compare, bool update = true>
204 class ShortestFirstQueue : public QueueBase<S> {
205  public:
206  using StateId = S;
207 
208  explicit ShortestFirstQueue(Compare comp)
209  : QueueBase<StateId>(SHORTEST_FIRST_QUEUE), heap_(comp) {}
210 
211  ~ShortestFirstQueue() override = default;
212 
213  StateId Head() const override { return heap_.Top(); }
214 
215  void Enqueue(StateId s) override {
216  if (update) {
217  for (StateId i = key_.size(); i <= s; ++i) key_.push_back(kNoStateId);
218  key_[s] = heap_.Insert(s);
219  } else {
220  heap_.Insert(s);
221  }
222  }
223 
224  void Dequeue() override {
225  if (update) {
226  key_[heap_.Pop()] = kNoStateId;
227  } else {
228  heap_.Pop();
229  }
230  }
231 
232  void Update(StateId s) override {
233  if (!update) return;
234  if (s >= key_.size() || key_[s] == kNoStateId) {
235  Enqueue(s);
236  } else {
237  heap_.Update(key_[s], s);
238  }
239  }
240 
241  bool Empty() const override { return heap_.Empty(); }
242 
243  void Clear() override {
244  heap_.Clear();
245  if (update) key_.clear();
246  }
247 
248  ssize_t Size() const { return heap_.Size(); }
249 
250  const Compare &GetCompare() const { return heap_.GetCompare(); }
251 
252  private:
254  std::vector<ssize_t> key_;
255 };
256 
257 namespace internal {
258 
// Given a vector that maps from states to weights, and a comparison functor
// for weights, this class defines a comparison function object between states.
//
// The weight vector is borrowed and must outlive this object (and any copy of
// it). The comparison functor is copied: callers commonly pass a temporary,
// and a stored reference to it would dangle as soon as the constructor
// returned.
template <typename StateId, typename Less>
class StateWeightCompare {
 public:
  using Weight = typename Less::Weight;

  StateWeightCompare(const std::vector<Weight> &weights, const Less &less)
      : weights_(weights), less_(less) {}

  // Returns true iff the weight of s1 orders before the weight of s2. Both
  // states must be valid indices into the weight vector.
  bool operator()(const StateId s1, const StateId s2) const {
    return less_(weights_[s1], weights_[s2]);
  }

 private:
  // Borrowed reference.
  const std::vector<Weight> &weights_;
  // Owned copy.
  const Less less_;
};
278 
279 // Comparison that can never be instantiated. Useful only to pass a pointer to
280 // this to a function that needs a comparison when it is known that the pointer
281 // will always be null.
282 template <class W>
283 struct ErrorLess {
284  using Weight = W;
286  FSTERROR() << "ErrorLess: instantiated for Weight " << Weight::Type();
287  }
288  bool operator()(const Weight &, const Weight &) const { return false; }
289 };
290 
291 } // namespace internal
292 
293 // Shortest-first queue discipline, templated on the StateId and Weight, is
294 // specialized to use the weight's natural order for the comparison function.
295 // Requires Weight is idempotent (due to use of NaturalLess).
296 template <typename S, typename Weight>
298  : public ShortestFirstQueue<
299  S, internal::StateWeightCompare<S, NaturalLess<Weight>>> {
300  public:
301  using StateId = S;
304 
305  explicit NaturalShortestFirstQueue(const std::vector<Weight> &distance)
306  : ShortestFirstQueue<StateId, Compare>(Compare(distance, Less())) {}
307 
308  ~NaturalShortestFirstQueue() override = default;
309 };
310 
311 // In a shortest path computation on a lattice-like FST, we may keep many old
312 // nonviable paths as a part of the search. Since the search process always
313 // expands the lowest cost path next, that lowest cost path may be a very old
314 // nonviable path instead of one we expect to lead to a shortest path.
315 //
316 // For instance, suppose that the current best path in an alignment has
317 // traversed 500 arcs with a cost of 10. We may also have a bad path in
318 // the queue that has traversed only 40 arcs but also has a cost of 10.
319 // This path is very unlikely to lead to a reasonable alignment, so this queue
320 // can prune it from the search space.
321 //
322 // This queue relies on the caller using a shortest-first exploration order
323 // like this:
324 // while (true) {
325 // StateId head = queue.Head();
326 // queue.Dequeue();
327 // for (const auto& arc : GetArcs(fst, head)) {
328 // queue.Enqueue(arc.nextstate);
329 // }
330 // }
331 // We use this assumption to guess that there is an arc between Head and the
332 // Enqueued state; this is how the number of path steps is measured.
333 template <typename S, typename Weight>
335  : public NaturalShortestFirstQueue<S, Weight> {
336  public:
337  using StateId = S;
339 
340  PruneNaturalShortestFirstQueue(const std::vector<Weight> &distance,
341  ssize_t arc_threshold, ssize_t state_limit = 0)
342  : Base(distance),
343  arc_threshold_(arc_threshold),
344  state_limit_(state_limit),
345  head_steps_(0),
346  max_head_steps_(0) {}
347 
348  ~PruneNaturalShortestFirstQueue() override = default;
349 
350  StateId Head() const override {
351  const auto head = Base::Head();
352  // Stores the number of steps from the start of the graph to this state
353  // along the shortest-weight path.
354  if (head < steps_.size()) {
355  max_head_steps_ = std::max(steps_[head], max_head_steps_);
356  head_steps_ = steps_[head];
357  }
358  return head;
359  }
360 
361  void Enqueue(StateId s) override {
362  // We assume that there is an arc between the Head() state and this
363  // Enqueued state.
364  const ssize_t state_steps = head_steps_ + 1;
365  if (s >= steps_.size()) {
366  steps_.resize(s + 1, state_steps);
367  }
368  // This is the number of arcs in the minimum cost path from Start to s.
369  steps_[s] = state_steps;
370 
371  // Adjust the threshold in cases where path step thresholding wasn't
372  // enough to keep the queue small.
373  ssize_t adjusted_threshold = arc_threshold_;
374  if (Base::Size() > state_limit_ && state_limit_ > 0) {
375  adjusted_threshold = std::max<ssize_t>(
376  0, arc_threshold_ - (Base::Size() / state_limit_) - 1);
377  }
378 
379  if (state_steps > (max_head_steps_ - adjusted_threshold) ||
380  arc_threshold_ < 0) {
381  if (adjusted_threshold == 0 && state_limit_ > 0) {
382  // If the queue is continuing to grow without bound, we follow any
383  // path that makes progress and clear the rest.
384  Base::Clear();
385  }
386  Base::Enqueue(s);
387  }
388  }
389 
390  private:
391  // A dense map from StateId to the number of arcs in the minimum weight
392  // path from Start to this state.
393  std::vector<ssize_t> steps_;
394  // We only keep paths that are within this number of arcs (not weight!)
395  // of the longest path.
396  const ssize_t arc_threshold_;
397  // If the size of the queue climbs above this number, we increase the
398  // threshold to reduce the amount of work we have to do.
399  const ssize_t state_limit_;
400 
401  // The following are mutable because Head() is const.
402  // The number of arcs traversed in the minimum cost path from the start
403  // state to the current Head() state.
404  mutable ssize_t head_steps_;
405  // The maximum number of arcs traversed by any low-cost path so far.
406  mutable ssize_t max_head_steps_;
407 };
408 
409 // Topological-order queue discipline, templated on the StateId. States are
410 // ordered in the queue topologically. The FST must be acyclic.
411 template <class S>
412 class TopOrderQueue : public QueueBase<S> {
413  public:
414  using StateId = S;
415 
416  // This constructor computes the topological order. It accepts an arc filter
417  // to limit the transitions considered in that computation (e.g., only the
418  // epsilon graph).
419  template <class Arc, class ArcFilter>
420  TopOrderQueue(const Fst<Arc> &fst, ArcFilter filter)
422  front_(0),
423  back_(kNoStateId),
424  order_(0),
425  state_(0) {
426  bool acyclic;
427  TopOrderVisitor<Arc> top_order_visitor(&order_, &acyclic);
428  DfsVisit(fst, &top_order_visitor, filter);
429  if (!acyclic) {
430  FSTERROR() << "TopOrderQueue: FST is not acyclic";
432  }
433  state_.resize(order_.size(), kNoStateId);
434  }
435 
436  // This constructor is passed the pre-computed topological order.
437  explicit TopOrderQueue(const std::vector<StateId> &order)
439  front_(0),
440  back_(kNoStateId),
441  order_(order),
442  state_(order.size(), kNoStateId) {}
443 
444  ~TopOrderQueue() override = default;
445 
446  StateId Head() const final { return state_[front_]; }
447 
448  void Enqueue(StateId s) final {
449  if (front_ > back_) {
450  front_ = back_ = order_[s];
451  } else if (order_[s] > back_) {
452  back_ = order_[s];
453  } else if (order_[s] < front_) {
454  front_ = order_[s];
455  }
456  state_[order_[s]] = s;
457  }
458 
459  void Dequeue() final {
460  state_[front_] = kNoStateId;
461  while ((front_ <= back_) && (state_[front_] == kNoStateId)) ++front_;
462  }
463 
464  void Update(StateId) final {}
465 
466  bool Empty() const final { return front_ > back_; }
467 
468  void Clear() final {
469  for (StateId s = front_; s <= back_; ++s) state_[s] = kNoStateId;
470  back_ = kNoStateId;
471  front_ = 0;
472  }
473 
474  private:
475  StateId front_;
476  StateId back_;
477  std::vector<StateId> order_;
478  std::vector<StateId> state_;
479 };
480 
481 // State order queue discipline, templated on the StateId. States are ordered in
482 // the queue by state ID.
483 template <class S>
484 class StateOrderQueue : public QueueBase<S> {
485  public:
486  using StateId = S;
487 
489  : QueueBase<StateId>(STATE_ORDER_QUEUE), front_(0), back_(kNoStateId) {}
490 
491  ~StateOrderQueue() override = default;
492 
493  StateId Head() const final { return front_; }
494 
495  void Enqueue(StateId s) final {
496  if (front_ > back_) {
497  front_ = back_ = s;
498  } else if (s > back_) {
499  back_ = s;
500  } else if (s < front_) {
501  front_ = s;
502  }
503  while (enqueued_.size() <= s) enqueued_.push_back(false);
504  enqueued_[s] = true;
505  }
506 
507  void Dequeue() final {
508  enqueued_[front_] = false;
509  while ((front_ <= back_) && (enqueued_[front_] == false)) ++front_;
510  }
511 
512  void Update(StateId) final {}
513 
514  bool Empty() const final { return front_ > back_; }
515 
516  void Clear() final {
517  for (StateId i = front_; i <= back_; ++i) enqueued_[i] = false;
518  front_ = 0;
519  back_ = kNoStateId;
520  }
521 
522  private:
523  StateId front_;
524  StateId back_;
525  std::vector<bool> enqueued_;
526 };
527 
528 // SCC topological-order meta-queue discipline, templated on the StateId and a
529 // queue used inside each SCC. It visits the SCCs of an FST in topological
530 // order. Its constructor is passed the queues to to use within an SCC.
531 template <class S, class Queue>
532 class SccQueue : public QueueBase<S> {
533  public:
534  using StateId = S;
535 
536  // Constructor takes a vector specifying the SCC number per state and a
537  // vector giving the queue to use per SCC number.
538  SccQueue(const std::vector<StateId> &scc,
539  std::vector<std::unique_ptr<Queue>> *queue)
541  queue_(queue),
542  scc_(scc),
543  front_(0),
544  back_(kNoStateId) {}
545 
546  ~SccQueue() override = default;
547 
548  StateId Head() const final {
549  while ((front_ <= back_) &&
550  (((*queue_)[front_] && (*queue_)[front_]->Empty()) ||
551  (((*queue_)[front_] == nullptr) &&
552  ((front_ >= trivial_queue_.size()) ||
553  (trivial_queue_[front_] == kNoStateId))))) {
554  ++front_;
555  }
556  if ((*queue_)[front_]) {
557  return (*queue_)[front_]->Head();
558  } else {
559  return trivial_queue_[front_];
560  }
561  }
562 
563  void Enqueue(StateId s) final {
564  if (front_ > back_) {
565  front_ = back_ = scc_[s];
566  } else if (scc_[s] > back_) {
567  back_ = scc_[s];
568  } else if (scc_[s] < front_) {
569  front_ = scc_[s];
570  }
571  if ((*queue_)[scc_[s]]) {
572  (*queue_)[scc_[s]]->Enqueue(s);
573  } else {
574  while (trivial_queue_.size() <= scc_[s]) {
575  trivial_queue_.push_back(kNoStateId);
576  }
577  trivial_queue_[scc_[s]] = s;
578  }
579  }
580 
581  void Dequeue() final {
582  if ((*queue_)[front_]) {
583  (*queue_)[front_]->Dequeue();
584  } else if (front_ < trivial_queue_.size()) {
585  trivial_queue_[front_] = kNoStateId;
586  }
587  }
588 
589  void Update(StateId s) final {
590  if ((*queue_)[scc_[s]]) (*queue_)[scc_[s]]->Update(s);
591  }
592 
593  bool Empty() const final {
594  // Queues SCC number back_ is not empty unless back_ == front_.
595  if (front_ < back_) {
596  return false;
597  } else if (front_ > back_) {
598  return true;
599  } else if ((*queue_)[front_]) {
600  return (*queue_)[front_]->Empty();
601  } else {
602  return (front_ >= trivial_queue_.size()) ||
603  (trivial_queue_[front_] == kNoStateId);
604  }
605  }
606 
607  void Clear() final {
608  for (StateId i = front_; i <= back_; ++i) {
609  if ((*queue_)[i]) {
610  (*queue_)[i]->Clear();
611  } else if (i < trivial_queue_.size()) {
612  trivial_queue_[i] = kNoStateId;
613  }
614  }
615  front_ = 0;
616  back_ = kNoStateId;
617  }
618 
619  private:
620  std::vector<std::unique_ptr<Queue>> *queue_;
621  const std::vector<StateId> &scc_;
622  mutable StateId front_;
623  StateId back_;
624  std::vector<StateId> trivial_queue_;
625 };
626 
627 // Automatic queue discipline. It selects a queue discipline for a given FST
628 // based on its properties.
629 template <class S>
630 class AutoQueue : public QueueBase<S> {
631  public:
632  using StateId = S;
633 
634  // This constructor takes a state distance vector that, if non-null and if
635  // the Weight type has the path property, will entertain the shortest-first
636  // queue using the natural order w.r.t to the distance.
637  template <class Arc, class ArcFilter>
639  const std::vector<typename Arc::Weight> *distance, ArcFilter filter)
641  using Weight = typename Arc::Weight;
642  // We need to have variables of type Less and Compare, so we use
643  // ErrorLess if the type NaturalLess<Weight> cannot be instantiated due
644  // to lack of path property.
645  using Less = std::conditional_t<IsPath<Weight>::value, NaturalLess<Weight>,
648  // First checks if the FST is known to have these properties.
649  const auto props =
651  if ((props & kTopSorted) || fst.Start() == kNoStateId) {
652  queue_ = std::make_unique<StateOrderQueue<StateId>>();
653  VLOG(2) << "AutoQueue: using state-order discipline";
654  } else if (props & kAcyclic) {
655  queue_ = std::make_unique<TopOrderQueue<StateId>>(fst, filter);
656  VLOG(2) << "AutoQueue: using top-order discipline";
657  } else if ((props & kUnweighted) && IsIdempotent<Weight>::value) {
658  queue_ = std::make_unique<LifoQueue<StateId>>();
659  VLOG(2) << "AutoQueue: using LIFO discipline";
660  } else {
661  uint64_t properties;
662  // Decomposes into strongly-connected components.
663  SccVisitor<Arc> scc_visitor(&scc_, nullptr, nullptr, &properties);
664  DfsVisit(fst, &scc_visitor, filter);
665  auto nscc = *std::max_element(scc_.begin(), scc_.end()) + 1;
666  std::vector<QueueType> queue_types(nscc);
667  std::unique_ptr<Less> less;
668  std::unique_ptr<Compare> comp;
669  if constexpr (IsPath<Weight>::value) {
670  if (distance) {
671  less = std::make_unique<Less>();
672  comp = std::make_unique<Compare>(*distance, *less);
673  }
674  }
675  // Finds the queue type to use per SCC.
676  bool unweighted;
677  bool all_trivial;
678  SccQueueType(fst, scc_, &queue_types, filter, less.get(), &all_trivial,
679  &unweighted);
680  // If unweighted and semiring is idempotent, uses LIFO queue.
681  if (unweighted) {
682  queue_ = std::make_unique<LifoQueue<StateId>>();
683  VLOG(2) << "AutoQueue: using LIFO discipline";
684  return;
685  }
686  // If all the SCC are trivial, the FST is acyclic and the scc number gives
687  // the topological order.
688  if (all_trivial) {
689  queue_ = std::make_unique<TopOrderQueue<StateId>>(scc_);
690  VLOG(2) << "AutoQueue: using top-order discipline";
691  return;
692  }
693  VLOG(2) << "AutoQueue: using SCC meta-discipline";
694  queues_.resize(nscc);
695  for (StateId i = 0; i < nscc; ++i) {
696  switch (queue_types[i]) {
697  case TRIVIAL_QUEUE:
698  queues_[i].reset();
699  VLOG(3) << "AutoQueue: SCC #" << i << ": using trivial discipline";
700  break;
702  // The IsPath test is not needed for correctness. It just saves
703  // instantiating a ShortestFirstQueue that can never be called.
704  if constexpr (IsPath<Weight>::value) {
705  queues_[i] = std::make_unique<
707  VLOG(3) << "AutoQueue: SCC #" << i
708  << ": using shortest-first discipline";
709  } else {
710  // SccQueueType should ensure this can never happen.
711  FSTERROR() << "Got SHORTEST_FIRST_QUEUE for non-Path Weight "
712  << Weight::Type();
713  queues_[i].reset();
714  }
715  break;
716  case LIFO_QUEUE:
717  queues_[i] = std::make_unique<LifoQueue<StateId>>();
718  VLOG(3) << "AutoQueue: SCC #" << i << ": using LIFO discipline";
719  break;
720  case FIFO_QUEUE:
721  default:
722  queues_[i] = std::make_unique<FifoQueue<StateId>>();
723  VLOG(3) << "AutoQueue: SCC #" << i << ": using FIFO discipine";
724  break;
725  }
726  }
727  queue_ = std::make_unique<SccQueue<StateId, QueueBase<StateId>>>(
728  scc_, &queues_);
729  }
730  }
731 
732  ~AutoQueue() override = default;
733 
734  StateId Head() const final { return queue_->Head(); }
735 
736  void Enqueue(StateId s) final { queue_->Enqueue(s); }
737 
738  void Dequeue() final { queue_->Dequeue(); }
739 
740  void Update(StateId s) final { queue_->Update(s); }
741 
742  bool Empty() const final { return queue_->Empty(); }
743 
744  void Clear() final { queue_->Clear(); }
745 
746  private:
747  template <class Arc, class ArcFilter, class Less>
748  static void SccQueueType(const Fst<Arc> &fst, const std::vector<StateId> &scc,
749  std::vector<QueueType> *queue_types,
750  ArcFilter filter, Less *less, bool *all_trivial,
751  bool *unweighted);
752 
753  std::unique_ptr<QueueBase<StateId>> queue_;
754  std::vector<std::unique_ptr<QueueBase<StateId>>> queues_;
755  std::vector<StateId> scc_;
756 };
757 
758 // Examines the states in an FST's strongly connected components and determines
759 // which type of queue to use per SCC. Stores result as a vector of QueueTypes
760 // which is assumed to have length equal to the number of SCCs. An arc filter
761 // is used to limit the transitions considered (e.g., only the epsilon graph).
762 // The argument all_trivial is set to true if every queue is the trivial queue.
763 // The argument unweighted is set to true if the semiring is idempotent and all
764 // the arc weights are equal to Zero() or One().
765 template <class StateId>
766 template <class Arc, class ArcFilter, class Less>
768  const std::vector<StateId> &scc,
769  std::vector<QueueType> *queue_type,
770  ArcFilter filter, Less *less,
771  bool *all_trivial, bool *unweighted) {
772  using StateId = typename Arc::StateId;
773  using Weight = typename Arc::Weight;
774  *all_trivial = true;
775  *unweighted = true;
776  for (StateId i = 0; i < queue_type->size(); ++i) {
777  (*queue_type)[i] = TRIVIAL_QUEUE;
778  }
779  for (StateIterator<Fst<Arc>> sit(fst); !sit.Done(); sit.Next()) {
780  const auto state = sit.Value();
781  for (ArcIterator<Fst<Arc>> ait(fst, state); !ait.Done(); ait.Next()) {
782  const auto &arc = ait.Value();
783  if (!filter(arc)) continue;
784  if (scc[state] == scc[arc.nextstate]) {
785  auto &type = (*queue_type)[scc[state]];
786  if constexpr (!IsPath<Weight>::value) {
787  type = FIFO_QUEUE;
788  } else if (!less || (*less)(arc.weight, Weight::One())) {
789  type = FIFO_QUEUE;
790  } else if ((type == TRIVIAL_QUEUE) || (type == LIFO_QUEUE)) {
792  (arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
793  type = SHORTEST_FIRST_QUEUE;
794  } else {
795  type = LIFO_QUEUE;
796  }
797  }
798  if (type != TRIVIAL_QUEUE) *all_trivial = false;
799  }
801  (arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
802  *unweighted = false;
803  }
804  }
805  }
806 }
807 
808 // An A* estimate is a function object that maps from a state ID to an
809 // estimate of the shortest distance to the final states.
810 
// A trivial A* estimate, yielding a queue which behaves the same in Dijkstra's
// algorithm. The estimate is Weight::One() for every state; the state argument
// is ignored.
template <typename StateId, typename Weight>
struct TrivialAStarEstimate {
  constexpr Weight operator()(StateId) const { return Weight::One(); }
};
817 
// A non-trivial A* estimate using a vector of the estimated future costs.
// The vector is borrowed and must outlive this object. States with no entry
// in the vector get Weight::Zero() as their estimate.
template <typename StateId, typename Weight>
class NaturalAStarEstimate {
 public:
  NaturalAStarEstimate(const std::vector<Weight> &beta) : beta_(beta) {}

  // Returns a reference either into the borrowed vector or to kZero.
  const Weight &operator()(StateId s) const {
    return (s < beta_.size()) ? beta_[s] : kZero;
  }

 private:
  static constexpr Weight kZero = Weight::Zero();

  const std::vector<Weight> &beta_;
};
833 
// Given a vector that maps from states to weights representing the shortest
// distance from the initial state, a comparison function object between
// weights, and an estimate of the shortest distance to the final states, this
// class defines a comparison function object between states.
//
// All three constructor arguments are borrowed (only references are stored)
// and must outlive this object and any copy of it.
template <typename S, typename Less, typename Estimate>
class AStarWeightCompare {
 public:
  using StateId = S;
  using Weight = typename Less::Weight;

  AStarWeightCompare(const std::vector<Weight> &weights, const Less &less,
                     const Estimate &estimate)
      : weights_(weights), less_(less), estimate_(estimate) {}

  // Orders states by Times(distance so far, estimated remaining distance).
  bool operator()(StateId s1, StateId s2) const {
    const auto w1 = Times(weights_[s1], estimate_(s1));
    const auto w2 = Times(weights_[s2], estimate_(s2));
    return less_(w1, w2);
  }

  const Estimate &GetEstimate() const { return estimate_; }

 private:
  const std::vector<Weight> &weights_;
  const Less &less_;
  const Estimate &estimate_;
};
861 
862 // A* queue discipline templated on StateId, Weight, and Estimate.
863 template <typename S, typename Weight, typename Estimate>
865  : public ShortestFirstQueue<
866  S, AStarWeightCompare<S, NaturalLess<Weight>, Estimate>> {
867  public:
868  using StateId = S;
870 
871  NaturalAStarQueue(const std::vector<Weight> &distance,
872  const Estimate &estimate)
874  Compare(distance, less_, estimate)) {}
875 
876  ~NaturalAStarQueue() override = default;
877 
878  private:
879  // This is non-static because the constructor for non-idempotent weights will
880  // result in an error.
881  const NaturalLess<Weight> less_{};
882 };
883 
// A state equivalence class is a function object that maps from a state ID to
// an equivalence class (state) ID. The trivial equivalence class maps a state
// ID to itself.
template <typename StateId>
struct TrivialStateEquivClass {
  StateId operator()(StateId s) const { return s; }
};
891 
892 // Distance-based pruning queue discipline: Enqueues a state only when its
893 // shortest distance (so far), as specified by distance, is less than (as
894 // specified by comp) the shortest distance Times() the threshold to any state
895 // in the same equivalence class, as specified by the functor class_func. The
896 // underlying queue discipline is specified by queue.
897 //
898 // This is not a final class.
899 template <typename Queue, typename Less, typename ClassFnc>
900 class PruneQueue : public QueueBase<typename Queue::StateId> {
901  public:
902  using StateId = typename Queue::StateId;
903  using Weight = typename Less::Weight;
904 
905  PruneQueue(const std::vector<Weight> &distance, std::unique_ptr<Queue> queue,
906  const Less &less, const ClassFnc &class_fnc, Weight threshold)
908  distance_(distance),
909  queue_(std::move(queue)),
910  less_(less),
911  class_fnc_(class_fnc),
912  threshold_(std::move(threshold)) {}
913 
914  ~PruneQueue() override = default;
915 
916  StateId Head() const override { return queue_->Head(); }
917 
918  void Enqueue(StateId s) override {
919  const auto c = class_fnc_(s);
920  if (c >= class_distance_.size()) {
921  class_distance_.resize(c + 1, Weight::Zero());
922  }
923  if (less_(distance_[s], class_distance_[c])) {
924  class_distance_[c] = distance_[s];
925  }
926  // Enqueues only if below threshold limit.
927  const auto limit = Times(class_distance_[c], threshold_);
928  if (less_(distance_[s], limit)) queue_->Enqueue(s);
929  }
930 
931  void Dequeue() override { queue_->Dequeue(); }
932 
933  void Update(StateId s) override {
934  const auto c = class_fnc_(s);
935  if (less_(distance_[s], class_distance_[c])) {
936  class_distance_[c] = distance_[s];
937  }
938  queue_->Update(s);
939  }
940 
941  bool Empty() const override { return queue_->Empty(); }
942 
943  void Clear() override { queue_->Clear(); }
944 
945  private:
946  const std::vector<Weight> &distance_; // Shortest distance to state.
947  std::unique_ptr<Queue> queue_;
948  const Less &less_; // Borrowed reference.
949  const ClassFnc &class_fnc_; // Equivalence class functor.
950  Weight threshold_; // Pruning weight threshold.
951  std::vector<Weight> class_distance_; // Shortest distance to class.
952 };
953 
954 // Pruning queue discipline (see above) using the weight's natural order for the
955 // comparison function. The ownership of the queue argument is given to this
956 // class.
957 template <typename Queue, typename Weight, typename ClassFnc>
958 class NaturalPruneQueue final
959  : public PruneQueue<Queue, NaturalLess<Weight>, ClassFnc> {
960  public:
961  using StateId = typename Queue::StateId;
962 
963  NaturalPruneQueue(const std::vector<Weight> &distance,
964  std::unique_ptr<Queue> queue, const ClassFnc &class_fnc,
965  Weight threshold)
966  : PruneQueue<Queue, NaturalLess<Weight>, ClassFnc>(
967  distance, std::move(queue), NaturalLess<Weight>(), class_fnc,
968  threshold) {}
969 
970  ~NaturalPruneQueue() override = default;
971 };
972 
973 // Filter-based pruning queue discipline: enqueues a state only if allowed by
974 // the filter, specified by the state filter functor argument. The underlying
975 // queue discipline is specified by the queue argument.
976 template <typename Queue, typename Filter>
977 class FilterQueue : public QueueBase<typename Queue::StateId> {
978  public:
979  using StateId = typename Queue::StateId;
980 
981  FilterQueue(std::unique_ptr<Queue> queue, const Filter &filter)
983  queue_(std::move(queue)),
984  filter_(filter) {}
985 
986  ~FilterQueue() override = default;
987 
988  StateId Head() const final { return queue_->Head(); }
989 
990  // Enqueues only if allowed by state filter.
991  void Enqueue(StateId s) final {
992  if (filter_(s)) queue_->Enqueue(s);
993  }
994 
995  void Dequeue() final { queue_->Dequeue(); }
996 
997  void Update(StateId s) final {}
998 
999  bool Empty() const final { return queue_->Empty(); }
1000 
1001  void Clear() final { queue_->Clear(); }
1002 
1003  private:
1004  std::unique_ptr<Queue> queue_;
1005  const Filter &filter_;
1006 };
1007 
1008 } // namespace fst
1009 
1010 #endif // FST_QUEUE_H_
virtual bool Empty() const =0
void Clear() override
Definition: queue.h:165
constexpr uint64_t kCyclic
Definition: properties.h:108
bool Empty() const final
Definition: queue.h:189
bool Empty() const final
Definition: queue.h:135
void Enqueue(StateId s) override
Definition: queue.h:157
StateId Head() const final
Definition: queue.h:446
void Update(StateId s) override
Definition: queue.h:933
void Enqueue(StateId s) override
Definition: queue.h:215
typename NaturalLess< Weight >::Weight Weight
Definition: queue.h:264
void Clear() final
Definition: queue.h:137
StateId Head() const final
Definition: queue.h:181
void SetError(bool error)
Definition: queue.h:95
virtual uint64_t Properties(uint64_t mask, bool test) const =0
void Clear() override
Definition: queue.h:243
FilterQueue(std::unique_ptr< Queue > queue, const Filter &filter)
Definition: queue.h:981
void Update(StateId) final
Definition: queue.h:133
QueueType
Definition: queue.h:70
NaturalAStarQueue(const std::vector< Weight > &distance, const Estimate &estimate)
Definition: queue.h:871
NaturalShortestFirstQueue(const std::vector< Weight > &distance)
Definition: queue.h:305
StateWeightCompare(const std::vector< Weight > &weights, const Less &less)
Definition: queue.h:266
bool Empty() const final
Definition: queue.h:514
StateId Head() const final
Definition: queue.h:127
typename Queue::StateId StateId
Definition: queue.h:979
void Enqueue(StateId s) final
Definition: queue.h:563
void Clear() final
Definition: queue.h:516
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:63
void Enqueue(StateId s) final
Definition: queue.h:129
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
Definition: dfs-visit.h:109
bool operator()(const Weight &, const Weight &) const
Definition: queue.h:288
ShortestFirstQueue(Compare comp)
Definition: queue.h:208
void Enqueue(StateId s) final
Definition: queue.h:736
void Enqueue(StateId s) final
Definition: queue.h:495
StateId Head() const override
Definition: queue.h:213
void Clear() final
Definition: queue.h:607
StateId Head() const override
Definition: queue.h:350
void Update(StateId) override
Definition: queue.h:161
bool operator()(StateId s1, StateId s2) const
Definition: queue.h:848
typename NaturalLess< Weight >::Weight Weight
Definition: queue.h:903
std::bool_constant<(W::Properties()&kIdempotent)!=0 > IsIdempotent
Definition: weight.h:156
void Update(StateId) final
Definition: queue.h:187
bool Empty() const override
Definition: queue.h:163
TopOrderQueue(const Fst< Arc > &fst, ArcFilter filter)
Definition: queue.h:420
constexpr uint64_t kTopSorted
Definition: properties.h:118
const Weight & operator()(StateId s) const
Definition: queue.h:824
NaturalAStarEstimate(const std::vector< Weight > &beta)
Definition: queue.h:822
void Dequeue() final
Definition: queue.h:131
bool Error() const
Definition: queue.h:97
StateId Head() const override
Definition: queue.h:916
constexpr int kNoStateId
Definition: fst.h:202
virtual void Dequeue()=0
void Dequeue() final
Definition: queue.h:507
QueueBase(QueueType type)
Definition: queue.h:93
void Update(StateId) final
Definition: queue.h:464
StateId Head() const final
Definition: queue.h:548
#define FSTERROR()
Definition: util.h:53
void Enqueue(StateId s) override
Definition: queue.h:918
void Update(StateId s) final
Definition: queue.h:997
StateId operator()(StateId s) const
Definition: queue.h:889
bool Empty() const final
Definition: queue.h:593
virtual void Update(StateId)=0
void Update(StateId s) final
Definition: queue.h:740
void Enqueue(StateId s) final
Definition: queue.h:183
constexpr uint64_t kAcyclic
Definition: properties.h:110
virtual void Clear()=0
virtual void Enqueue(StateId)=0
constexpr Weight operator()(StateId) const
Definition: queue.h:815
typename NaturalLess< Weight >::Weight Weight
Definition: queue.h:842
AutoQueue(const Fst< Arc > &fst, const std::vector< typename Arc::Weight > *distance, ArcFilter filter)
Definition: queue.h:638
#define VLOG(level)
Definition: log.h:50
virtual StateId Start() const =0
void Dequeue() final
Definition: queue.h:459
void Dequeue() final
Definition: queue.h:995
void Dequeue() override
Definition: queue.h:224
const Estimate & GetEstimate() const
Definition: queue.h:854
void Clear() final
Definition: queue.h:468
bool Empty() const override
Definition: queue.h:941
void Dequeue() override
Definition: queue.h:159
TopOrderQueue(const std::vector< StateId > &order)
Definition: queue.h:437
AStarWeightCompare(const std::vector< Weight > &weights, const Less &less, const Estimate &estimate)
Definition: queue.h:844
StateId Head() const final
Definition: queue.h:988
virtual ~QueueBase()
Definition: queue.h:89
void Dequeue() final
Definition: queue.h:581
StateId Head() const final
Definition: queue.h:493
constexpr uint64_t kUnweighted
Definition: properties.h:105
StateId Head() const final
Definition: queue.h:734
void Enqueue(StateId s) final
Definition: queue.h:448
void Dequeue() final
Definition: queue.h:185
void Update(StateId s) override
Definition: queue.h:232
void Clear() final
Definition: queue.h:191
void Update(StateId) final
Definition: queue.h:512
NaturalPruneQueue(const std::vector< Weight > &distance, std::unique_ptr< Queue > queue, const ClassFnc &class_fnc, Weight threshold)
Definition: queue.h:963
void Enqueue(StateId s) override
Definition: queue.h:361
SccQueue(const std::vector< StateId > &scc, std::vector< std::unique_ptr< Queue >> *queue)
Definition: queue.h:538
bool Empty() const override
Definition: queue.h:241
void Update(StateId s) final
Definition: queue.h:589
QueueType Type() const
Definition: queue.h:99
Queue::StateId StateId
Definition: queue.h:87
bool operator()(const StateId s1, const StateId s2) const
Definition: queue.h:269
void Enqueue(StateId s) final
Definition: queue.h:991
PruneQueue(const std::vector< Weight > &distance, std::unique_ptr< Queue > queue, const Less &less, const ClassFnc &class_fnc, Weight threshold)
Definition: queue.h:905
void Dequeue() override
Definition: queue.h:931
bool Empty() const final
Definition: queue.h:999
PruneNaturalShortestFirstQueue(const std::vector< Weight > &distance, ssize_t arc_threshold, ssize_t state_limit=0)
Definition: queue.h:340
void Clear() final
Definition: queue.h:744
ssize_t Size() const
Definition: queue.h:248
virtual StateId Head() const =0
typename Queue::StateId StateId
Definition: queue.h:961
void Clear() override
Definition: queue.h:943
void Clear() final
Definition: queue.h:1001
void Dequeue() final
Definition: queue.h:738
StateId Head() const override
Definition: queue.h:155
bool Empty() const final
Definition: queue.h:466
std::bool_constant<(W::Properties()&kPath)!=0 > IsPath
Definition: weight.h:159
bool Empty() const final
Definition: queue.h:742
const Compare & GetCompare() const
Definition: queue.h:250