FST  openfst-1.8.3
OpenFst Library
queue.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Functions and classes for various FST state queues with a unified interface.
19 
20 #ifndef FST_QUEUE_H_
21 #define FST_QUEUE_H_
22 
23 #include <sys/types.h>
24 
25 #include <algorithm>
26 #include <cstdint>
27 #include <memory>
28 #include <queue>
29 #include <type_traits>
30 #include <utility>
31 #include <vector>
32 
33 #include <fst/log.h>
34 #include <fst/arcfilter.h>
35 #include <fst/cc-visitors.h>
36 #include <fst/dfs-visit.h>
37 #include <fst/fst.h>
38 #include <fst/heap.h>
39 #include <fst/properties.h>
40 #include <fst/topsort.h>
41 #include <fst/util.h>
42 #include <fst/weight.h>
43 
44 namespace fst {
45 
// The Queue interface is:
//
// template <class S>
// class Queue {
//  public:
//   using StateId = S;
//
//   // Constructor: may need args (e.g., FST, comparator) for some queues.
//   Queue(...);
//
//   // Returns the head of the queue.
//   StateId Head() const;
//
//   // Inserts a state.
//   void Enqueue(StateId s);
//
//   // Removes the head of the queue.
//   void Dequeue();
//
//   // Updates ordering of state s when weight changes, if necessary.
//   void Update(StateId s);
//
//   // Is the queue empty?
//   bool Empty() const;
//
//   // Removes all states from the queue.
//   void Clear();
// };
//
// Queues usable through AutoQueue implement this interface by deriving from
// QueueBase<S> and overriding its pure virtual methods.
74 
75 // State queue types.
76 enum QueueType {
77  TRIVIAL_QUEUE = 0, // Single state queue.
78  FIFO_QUEUE = 1, // First-in, first-out queue.
79  LIFO_QUEUE = 2, // Last-in, first-out queue.
80  SHORTEST_FIRST_QUEUE = 3, // Shortest-first queue.
81  TOP_ORDER_QUEUE = 4, // Topologically-ordered queue.
82  STATE_ORDER_QUEUE = 5, // State ID-ordered queue.
83  SCC_QUEUE = 6, // Component graph top-ordered meta-queue.
84  AUTO_QUEUE = 7, // Auto-selected queue.
86 };
87 
88 // QueueBase, templated on the StateId, is a virtual base class shared by all
89 // queues considered by AutoQueue.
90 template <class S>
91 class QueueBase {
92  public:
93  using StateId = S;
94 
95  virtual ~QueueBase() = default;
96 
97  // Concrete implementation.
98 
99  explicit QueueBase(QueueType type) : queue_type_(type), error_(false) {}
100 
101  void SetError(bool error) { error_ = error; }
102 
103  bool Error() const { return error_; }
104 
105  QueueType Type() const { return queue_type_; }
106 
107  // Virtual interface.
108 
109  virtual StateId Head() const = 0;
110  virtual void Enqueue(StateId) = 0;
111  virtual void Dequeue() = 0;
112  virtual void Update(StateId) = 0;
113  virtual bool Empty() const = 0;
114  virtual void Clear() = 0;
115 
116  private:
117  QueueType queue_type_;
118  bool error_;
119 };
120 
121 // Trivial queue discipline; one may enqueue at most one state at a time. It
122 // can be used for strongly connected components with only one state and no
123 // self-loops.
124 template <class S>
125 class TrivialQueue : public QueueBase<S> {
126  public:
127  using StateId = S;
128 
130 
131  ~TrivialQueue() override = default;
132 
133  StateId Head() const final { return front_; }
134 
135  void Enqueue(StateId s) final { front_ = s; }
136 
137  void Dequeue() final { front_ = kNoStateId; }
138 
139  void Update(StateId) final {}
140 
141  bool Empty() const final { return front_ == kNoStateId; }
142 
143  void Clear() final { front_ = kNoStateId; }
144 
145  private:
146  StateId front_;
147 };
148 
149 // First-in, first-out queue discipline.
150 //
151 // This is not a final class.
152 template <class S>
153 class FifoQueue : public QueueBase<S> {
154  public:
155  using StateId = S;
156 
158 
159  ~FifoQueue() override = default;
160 
161  StateId Head() const override { return queue_.front(); }
162 
163  void Enqueue(StateId s) override { queue_.push(s); }
164 
165  void Dequeue() override { queue_.pop(); }
166 
167  void Update(StateId) override {}
168 
169  bool Empty() const override { return queue_.empty(); }
170 
171  void Clear() override { queue_ = std::queue<StateId>(); }
172 
173  private:
174  std::queue<StateId> queue_;
175 };
176 
177 // Last-in, first-out queue discipline.
178 template <class S>
179 class LifoQueue : public QueueBase<S> {
180  public:
181  using StateId = S;
182 
184 
185  ~LifoQueue() override = default;
186 
187  StateId Head() const final { return stack_.back(); }
188 
189  void Enqueue(StateId s) final { stack_.push_back(s); }
190 
191  void Dequeue() final { stack_.pop_back(); }
192 
193  void Update(StateId) final {}
194 
195  bool Empty() const final { return stack_.empty(); }
196 
197  void Clear() final { stack_.clear(); }
198 
199  private:
200  std::vector<StateId> stack_;
201 };
202 
// Shortest-first queue discipline, templated on the StateId as well as a
// comparison functor used to compare two StateIds. If a (single) state's order
// changes, it can be reordered in the queue with a call to Update(). If update
// is false, calls to Update() do not reorder the queue.
207 //
208 // This is not a final class.
209 template <typename S, typename Compare, bool update = true>
210 class ShortestFirstQueue : public QueueBase<S> {
211  public:
212  using StateId = S;
213 
214  explicit ShortestFirstQueue(Compare comp)
215  : QueueBase<StateId>(SHORTEST_FIRST_QUEUE), heap_(comp) {}
216 
217  ~ShortestFirstQueue() override = default;
218 
219  StateId Head() const override { return heap_.Top(); }
220 
221  void Enqueue(StateId s) override {
222  if (update) {
223  for (StateId i = key_.size(); i <= s; ++i) key_.push_back(kNoStateId);
224  key_[s] = heap_.Insert(s);
225  } else {
226  heap_.Insert(s);
227  }
228  }
229 
230  void Dequeue() override {
231  if (update) {
232  key_[heap_.Pop()] = kNoStateId;
233  } else {
234  heap_.Pop();
235  }
236  }
237 
238  void Update(StateId s) override {
239  if (!update) return;
240  if (s >= key_.size() || key_[s] == kNoStateId) {
241  Enqueue(s);
242  } else {
243  heap_.Update(key_[s], s);
244  }
245  }
246 
247  bool Empty() const override { return heap_.Empty(); }
248 
249  void Clear() override {
250  heap_.Clear();
251  if (update) key_.clear();
252  }
253 
254  ssize_t Size() const { return heap_.Size(); }
255 
256  const Compare &GetCompare() const { return heap_.GetCompare(); }
257 
258  private:
260  std::vector<ssize_t> key_;
261 };
262 
263 namespace internal {
264 
265 // Given a vector that maps from states to weights, and a comparison functor
266 // for weights, this class defines a comparison function object between states.
267 template <typename StateId, typename Less>
269  public:
270  using Weight = typename Less::Weight;
271 
272  StateWeightCompare(const std::vector<Weight> &weights, const Less &less)
273  : weights_(weights), less_(less) {}
274 
275  bool operator()(const StateId s1, const StateId s2) const {
276  return less_(weights_[s1], weights_[s2]);
277  }
278 
279  private:
280  // Borrowed references.
281  const std::vector<Weight> &weights_;
282  const Less &less_;
283 };
284 
285 // Comparison that can never be instantiated. Useful only to pass a pointer to
286 // this to a function that needs a comparison when it is known that the pointer
287 // will always be null.
288 template <class W>
289 struct ErrorLess {
290  using Weight = W;
292  FSTERROR() << "ErrorLess: instantiated for Weight " << Weight::Type();
293  }
294  bool operator()(const Weight &, const Weight &) const { return false; }
295 };
296 
297 } // namespace internal
298 
299 // Shortest-first queue discipline, templated on the StateId and Weight, is
300 // specialized to use the weight's natural order for the comparison function.
301 // Requires Weight is idempotent (due to use of NaturalLess).
302 template <typename S, typename Weight>
304  : public ShortestFirstQueue<
305  S, internal::StateWeightCompare<S, NaturalLess<Weight>>> {
306  public:
307  using StateId = S;
310 
311  explicit NaturalShortestFirstQueue(const std::vector<Weight> &distance)
312  : ShortestFirstQueue<StateId, Compare>(Compare(distance, Less())) {}
313 
314  ~NaturalShortestFirstQueue() override = default;
315 };
316 
317 // In a shortest path computation on a lattice-like FST, we may keep many old
318 // nonviable paths as a part of the search. Since the search process always
319 // expands the lowest cost path next, that lowest cost path may be a very old
320 // nonviable path instead of one we expect to lead to a shortest path.
321 //
322 // For instance, suppose that the current best path in an alignment has
323 // traversed 500 arcs with a cost of 10. We may also have a bad path in
324 // the queue that has traversed only 40 arcs but also has a cost of 10.
325 // This path is very unlikely to lead to a reasonable alignment, so this queue
326 // can prune it from the search space.
327 //
328 // This queue relies on the caller using a shortest-first exploration order
329 // like this:
330 // while (true) {
331 // StateId head = queue.Head();
332 // queue.Dequeue();
333 // for (const auto& arc : GetArcs(fst, head)) {
334 // queue.Enqueue(arc.nextstate);
335 // }
336 // }
337 // We use this assumption to guess that there is an arc between Head and the
338 // Enqueued state; this is how the number of path steps is measured.
339 template <typename S, typename Weight>
341  : public NaturalShortestFirstQueue<S, Weight> {
342  public:
343  using StateId = S;
345 
346  PruneNaturalShortestFirstQueue(const std::vector<Weight> &distance,
347  ssize_t arc_threshold, ssize_t state_limit = 0)
348  : Base(distance),
349  arc_threshold_(arc_threshold),
350  state_limit_(state_limit),
351  head_steps_(0),
352  max_head_steps_(0) {}
353 
354  ~PruneNaturalShortestFirstQueue() override = default;
355 
356  StateId Head() const override {
357  const auto head = Base::Head();
358  // Stores the number of steps from the start of the graph to this state
359  // along the shortest-weight path.
360  if (head < steps_.size()) {
361  max_head_steps_ = std::max(steps_[head], max_head_steps_);
362  head_steps_ = steps_[head];
363  }
364  return head;
365  }
366 
367  void Enqueue(StateId s) override {
368  // We assume that there is an arc between the Head() state and this
369  // Enqueued state.
370  const ssize_t state_steps = head_steps_ + 1;
371  if (s >= steps_.size()) {
372  steps_.resize(s + 1, state_steps);
373  }
374  // This is the number of arcs in the minimum cost path from Start to s.
375  steps_[s] = state_steps;
376 
377  // Adjust the threshold in cases where path step thresholding wasn't
378  // enough to keep the queue small.
379  ssize_t adjusted_threshold = arc_threshold_;
380  if (Base::Size() > state_limit_ && state_limit_ > 0) {
381  adjusted_threshold = std::max<ssize_t>(
382  0, arc_threshold_ - (Base::Size() / state_limit_) - 1);
383  }
384 
385  if (state_steps > (max_head_steps_ - adjusted_threshold) ||
386  arc_threshold_ < 0) {
387  if (adjusted_threshold == 0 && state_limit_ > 0) {
388  // If the queue is continuing to grow without bound, we follow any
389  // path that makes progress and clear the rest.
390  Base::Clear();
391  }
392  Base::Enqueue(s);
393  }
394  }
395 
396  private:
397  // A dense map from StateId to the number of arcs in the minimum weight
398  // path from Start to this state.
399  std::vector<ssize_t> steps_;
400  // We only keep paths that are within this number of arcs (not weight!)
401  // of the longest path.
402  const ssize_t arc_threshold_;
403  // If the size of the queue climbs above this number, we increase the
404  // threshold to reduce the amount of work we have to do.
405  const ssize_t state_limit_;
406 
407  // The following are mutable because Head() is const.
408  // The number of arcs traversed in the minimum cost path from the start
409  // state to the current Head() state.
410  mutable ssize_t head_steps_;
411  // The maximum number of arcs traversed by any low-cost path so far.
412  mutable ssize_t max_head_steps_;
413 };
414 
415 // Topological-order queue discipline, templated on the StateId. States are
416 // ordered in the queue topologically. The FST must be acyclic.
417 template <class S>
418 class TopOrderQueue : public QueueBase<S> {
419  public:
420  using StateId = S;
421 
422  // This constructor computes the topological order. It accepts an arc filter
423  // to limit the transitions considered in that computation (e.g., only the
424  // epsilon graph).
425  template <class Arc, class ArcFilter>
426  TopOrderQueue(const Fst<Arc> &fst, ArcFilter filter)
428  front_(0),
429  back_(kNoStateId),
430  order_(0),
431  state_(0) {
432  bool acyclic;
433  TopOrderVisitor<Arc> top_order_visitor(&order_, &acyclic);
434  DfsVisit(fst, &top_order_visitor, filter);
435  if (!acyclic) {
436  FSTERROR() << "TopOrderQueue: FST is not acyclic";
438  }
439  state_.resize(order_.size(), kNoStateId);
440  }
441 
442  // This constructor is passed the pre-computed topological order.
443  explicit TopOrderQueue(const std::vector<StateId> &order)
445  front_(0),
446  back_(kNoStateId),
447  order_(order),
448  state_(order.size(), kNoStateId) {}
449 
450  ~TopOrderQueue() override = default;
451 
452  StateId Head() const final { return state_[front_]; }
453 
454  void Enqueue(StateId s) final {
455  if (front_ > back_) {
456  front_ = back_ = order_[s];
457  } else if (order_[s] > back_) {
458  back_ = order_[s];
459  } else if (order_[s] < front_) {
460  front_ = order_[s];
461  }
462  state_[order_[s]] = s;
463  }
464 
465  void Dequeue() final {
466  state_[front_] = kNoStateId;
467  while ((front_ <= back_) && (state_[front_] == kNoStateId)) ++front_;
468  }
469 
470  void Update(StateId) final {}
471 
472  bool Empty() const final { return front_ > back_; }
473 
474  void Clear() final {
475  for (StateId s = front_; s <= back_; ++s) state_[s] = kNoStateId;
476  back_ = kNoStateId;
477  front_ = 0;
478  }
479 
480  private:
481  StateId front_;
482  StateId back_;
483  std::vector<StateId> order_;
484  std::vector<StateId> state_;
485 };
486 
487 // State order queue discipline, templated on the StateId. States are ordered in
488 // the queue by state ID.
489 template <class S>
490 class StateOrderQueue : public QueueBase<S> {
491  public:
492  using StateId = S;
493 
495  : QueueBase<StateId>(STATE_ORDER_QUEUE), front_(0), back_(kNoStateId) {}
496 
497  ~StateOrderQueue() override = default;
498 
499  StateId Head() const final { return front_; }
500 
501  void Enqueue(StateId s) final {
502  if (front_ > back_) {
503  front_ = back_ = s;
504  } else if (s > back_) {
505  back_ = s;
506  } else if (s < front_) {
507  front_ = s;
508  }
509  while (enqueued_.size() <= s) enqueued_.push_back(false);
510  enqueued_[s] = true;
511  }
512 
513  void Dequeue() final {
514  enqueued_[front_] = false;
515  while ((front_ <= back_) && (enqueued_[front_] == false)) ++front_;
516  }
517 
518  void Update(StateId) final {}
519 
520  bool Empty() const final { return front_ > back_; }
521 
522  void Clear() final {
523  for (StateId i = front_; i <= back_; ++i) enqueued_[i] = false;
524  front_ = 0;
525  back_ = kNoStateId;
526  }
527 
528  private:
529  StateId front_;
530  StateId back_;
531  std::vector<bool> enqueued_;
532 };
533 
// SCC topological-order meta-queue discipline, templated on the StateId and a
// queue used inside each SCC. It visits the SCCs of an FST in topological
// order. Its constructor is passed the queues to use within each SCC.
537 template <class S, class Queue>
538 class SccQueue : public QueueBase<S> {
539  public:
540  using StateId = S;
541 
542  // Constructor takes a vector specifying the SCC number per state and a
543  // vector giving the queue to use per SCC number.
544  SccQueue(const std::vector<StateId> &scc,
545  std::vector<std::unique_ptr<Queue>> *queue)
547  queue_(queue),
548  scc_(scc),
549  front_(0),
550  back_(kNoStateId) {}
551 
552  ~SccQueue() override = default;
553 
554  StateId Head() const final {
555  while ((front_ <= back_) &&
556  (((*queue_)[front_] && (*queue_)[front_]->Empty()) ||
557  (((*queue_)[front_] == nullptr) &&
558  ((front_ >= trivial_queue_.size()) ||
559  (trivial_queue_[front_] == kNoStateId))))) {
560  ++front_;
561  }
562  if ((*queue_)[front_]) {
563  return (*queue_)[front_]->Head();
564  } else {
565  return trivial_queue_[front_];
566  }
567  }
568 
569  void Enqueue(StateId s) final {
570  if (front_ > back_) {
571  front_ = back_ = scc_[s];
572  } else if (scc_[s] > back_) {
573  back_ = scc_[s];
574  } else if (scc_[s] < front_) {
575  front_ = scc_[s];
576  }
577  if ((*queue_)[scc_[s]]) {
578  (*queue_)[scc_[s]]->Enqueue(s);
579  } else {
580  while (trivial_queue_.size() <= scc_[s]) {
581  trivial_queue_.push_back(kNoStateId);
582  }
583  trivial_queue_[scc_[s]] = s;
584  }
585  }
586 
587  void Dequeue() final {
588  if ((*queue_)[front_]) {
589  (*queue_)[front_]->Dequeue();
590  } else if (front_ < trivial_queue_.size()) {
591  trivial_queue_[front_] = kNoStateId;
592  }
593  }
594 
595  void Update(StateId s) final {
596  if ((*queue_)[scc_[s]]) (*queue_)[scc_[s]]->Update(s);
597  }
598 
599  bool Empty() const final {
600  // Queues SCC number back_ is not empty unless back_ == front_.
601  if (front_ < back_) {
602  return false;
603  } else if (front_ > back_) {
604  return true;
605  } else if ((*queue_)[front_]) {
606  return (*queue_)[front_]->Empty();
607  } else {
608  return (front_ >= trivial_queue_.size()) ||
609  (trivial_queue_[front_] == kNoStateId);
610  }
611  }
612 
613  void Clear() final {
614  for (StateId i = front_; i <= back_; ++i) {
615  if ((*queue_)[i]) {
616  (*queue_)[i]->Clear();
617  } else if (i < trivial_queue_.size()) {
618  trivial_queue_[i] = kNoStateId;
619  }
620  }
621  front_ = 0;
622  back_ = kNoStateId;
623  }
624 
625  private:
626  std::vector<std::unique_ptr<Queue>> *queue_;
627  const std::vector<StateId> &scc_;
628  mutable StateId front_;
629  StateId back_;
630  std::vector<StateId> trivial_queue_;
631 };
632 
633 // Automatic queue discipline. It selects a queue discipline for a given FST
634 // based on its properties.
635 template <class S>
636 class AutoQueue : public QueueBase<S> {
637  public:
638  using StateId = S;
639 
  // This constructor takes a state distance vector that, if non-null and if
  // the Weight type has the path property, allows a shortest-first queue
  // ordered by the natural order with respect to the distance to be selected.
643  template <class Arc, class ArcFilter>
645  const std::vector<typename Arc::Weight> *distance, ArcFilter filter)
647  using Weight = typename Arc::Weight;
648  // We need to have variables of type Less and Compare, so we use
649  // ErrorLess if the type NaturalLess<Weight> cannot be instantiated due
650  // to lack of path property.
651  using Less = std::conditional_t<IsPath<Weight>::value, NaturalLess<Weight>,
654  // First checks if the FST is known to have these properties.
655  const auto props =
657  if ((props & kTopSorted) || fst.Start() == kNoStateId) {
658  queue_ = std::make_unique<StateOrderQueue<StateId>>();
659  VLOG(2) << "AutoQueue: using state-order discipline";
660  } else if (props & kAcyclic) {
661  queue_ = std::make_unique<TopOrderQueue<StateId>>(fst, filter);
662  VLOG(2) << "AutoQueue: using top-order discipline";
663  } else if ((props & kUnweighted) && IsIdempotent<Weight>::value) {
664  queue_ = std::make_unique<LifoQueue<StateId>>();
665  VLOG(2) << "AutoQueue: using LIFO discipline";
666  } else {
667  uint64_t properties;
668  // Decomposes into strongly-connected components.
669  SccVisitor<Arc> scc_visitor(&scc_, nullptr, nullptr, &properties);
670  DfsVisit(fst, &scc_visitor, filter);
671  auto nscc = *std::max_element(scc_.begin(), scc_.end()) + 1;
672  std::vector<QueueType> queue_types(nscc);
673  std::unique_ptr<Less> less;
674  std::unique_ptr<Compare> comp;
675  if constexpr (IsPath<Weight>::value) {
676  if (distance) {
677  less = std::make_unique<Less>();
678  comp = std::make_unique<Compare>(*distance, *less);
679  }
680  }
681  // Finds the queue type to use per SCC.
682  bool unweighted;
683  bool all_trivial;
684  SccQueueType(fst, scc_, &queue_types, filter, less.get(), &all_trivial,
685  &unweighted);
686  // If unweighted and semiring is idempotent, uses LIFO queue.
687  if (unweighted) {
688  queue_ = std::make_unique<LifoQueue<StateId>>();
689  VLOG(2) << "AutoQueue: using LIFO discipline";
690  return;
691  }
692  // If all the SCC are trivial, the FST is acyclic and the scc number gives
693  // the topological order.
694  if (all_trivial) {
695  queue_ = std::make_unique<TopOrderQueue<StateId>>(scc_);
696  VLOG(2) << "AutoQueue: using top-order discipline";
697  return;
698  }
699  VLOG(2) << "AutoQueue: using SCC meta-discipline";
700  queues_.resize(nscc);
701  for (StateId i = 0; i < nscc; ++i) {
702  switch (queue_types[i]) {
703  case TRIVIAL_QUEUE:
704  queues_[i].reset();
705  VLOG(3) << "AutoQueue: SCC #" << i << ": using trivial discipline";
706  break;
708  // The IsPath test is not needed for correctness. It just saves
709  // instantiating a ShortestFirstQueue that can never be called.
710  if constexpr (IsPath<Weight>::value) {
711  queues_[i] =
712  std::make_unique<ShortestFirstQueue<StateId, Compare, false>>(
713  *comp);
714  VLOG(3) << "AutoQueue: SCC #" << i
715  << ": using shortest-first discipline";
716  } else {
717  // SccQueueType should ensure this can never happen.
718  FSTERROR() << "Got SHORTEST_FIRST_QUEUE for non-Path Weight "
719  << Weight::Type();
720  queues_[i].reset();
721  }
722  break;
723  case LIFO_QUEUE:
724  queues_[i] = std::make_unique<LifoQueue<StateId>>();
725  VLOG(3) << "AutoQueue: SCC #" << i << ": using LIFO discipline";
726  break;
727  case FIFO_QUEUE:
728  default:
729  queues_[i] = std::make_unique<FifoQueue<StateId>>();
730  VLOG(3) << "AutoQueue: SCC #" << i << ": using FIFO discipine";
731  break;
732  }
733  }
734  queue_ = std::make_unique<SccQueue<StateId, QueueBase<StateId>>>(
735  scc_, &queues_);
736  }
737  }
738 
739  ~AutoQueue() override = default;
740 
741  StateId Head() const final { return queue_->Head(); }
742 
743  void Enqueue(StateId s) final { queue_->Enqueue(s); }
744 
745  void Dequeue() final { queue_->Dequeue(); }
746 
747  void Update(StateId s) final { queue_->Update(s); }
748 
749  bool Empty() const final { return queue_->Empty(); }
750 
751  void Clear() final { queue_->Clear(); }
752 
753  private:
754  template <class Arc, class ArcFilter, class Less>
755  static void SccQueueType(const Fst<Arc> &fst, const std::vector<StateId> &scc,
756  std::vector<QueueType> *queue_types,
757  ArcFilter filter, Less *less, bool *all_trivial,
758  bool *unweighted);
759 
760  std::unique_ptr<QueueBase<StateId>> queue_;
761  std::vector<std::unique_ptr<QueueBase<StateId>>> queues_;
762  std::vector<StateId> scc_;
763 };
764 
765 // Examines the states in an FST's strongly connected components and determines
766 // which type of queue to use per SCC. Stores result as a vector of QueueTypes
767 // which is assumed to have length equal to the number of SCCs. An arc filter
768 // is used to limit the transitions considered (e.g., only the epsilon graph).
769 // The argument all_trivial is set to true if every queue is the trivial queue.
770 // The argument unweighted is set to true if the semiring is idempotent and all
771 // the arc weights are equal to Zero() or One().
772 template <class StateId>
773 template <class Arc, class ArcFilter, class Less>
775  const std::vector<StateId> &scc,
776  std::vector<QueueType> *queue_type,
777  ArcFilter filter, Less *less,
778  bool *all_trivial, bool *unweighted) {
779  using StateId = typename Arc::StateId;
780  using Weight = typename Arc::Weight;
781  *all_trivial = true;
782  *unweighted = true;
783  for (StateId i = 0; i < queue_type->size(); ++i) {
784  (*queue_type)[i] = TRIVIAL_QUEUE;
785  }
786  for (StateIterator<Fst<Arc>> sit(fst); !sit.Done(); sit.Next()) {
787  const auto state = sit.Value();
788  for (ArcIterator<Fst<Arc>> ait(fst, state); !ait.Done(); ait.Next()) {
789  const auto &arc = ait.Value();
790  if (!filter(arc)) continue;
791  if (scc[state] == scc[arc.nextstate]) {
792  auto &type = (*queue_type)[scc[state]];
793  if constexpr (!IsPath<Weight>::value) {
794  type = FIFO_QUEUE;
795  } else if (!less || (*less)(arc.weight, Weight::One())) {
796  type = FIFO_QUEUE;
797  } else if ((type == TRIVIAL_QUEUE) || (type == LIFO_QUEUE)) {
799  (arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
800  type = SHORTEST_FIRST_QUEUE;
801  } else {
802  type = LIFO_QUEUE;
803  }
804  }
805  if (type != TRIVIAL_QUEUE) *all_trivial = false;
806  }
808  (arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
809  *unweighted = false;
810  }
811  }
812  }
813 }
814 
815 // An A* estimate is a function object that maps from a state ID to an
816 // estimate of the shortest distance to the final states.
817 
// A trivial A* estimate (always One()), yielding a queue which behaves the same
// as the shortest-first queue used in Dijkstra's algorithm.
820 template <typename StateId, typename Weight>
822  constexpr Weight operator()(StateId) const { return Weight::One(); }
823 };
824 
825 // A non-trivial A* estimate using a vector of the estimated future costs.
826 template <typename StateId, typename Weight>
828  public:
829  NaturalAStarEstimate(const std::vector<Weight> &beta) : beta_(beta) {}
830 
831  const Weight &operator()(StateId s) const {
832  return (s < beta_.size()) ? beta_[s] : kZero;
833  }
834 
835  private:
836  static constexpr Weight kZero = Weight::Zero();
837 
838  const std::vector<Weight> &beta_;
839 };
840 
841 // Given a vector that maps from states to weights representing the shortest
842 // distance from the initial state, a comparison function object between
843 // weights, and an estimate of the shortest distance to the final states, this
844 // class defines a comparison function object between states.
845 template <typename S, typename Less, typename Estimate>
847  public:
848  using StateId = S;
849  using Weight = typename Less::Weight;
850 
851  AStarWeightCompare(const std::vector<Weight> &weights, const Less &less,
852  const Estimate &estimate)
853  : weights_(weights), less_(less), estimate_(estimate) {}
854 
855  bool operator()(StateId s1, StateId s2) const {
856  const auto w1 = Times(weights_[s1], estimate_(s1));
857  const auto w2 = Times(weights_[s2], estimate_(s2));
858  return less_(w1, w2);
859  }
860 
861  const Estimate &GetEstimate() const { return estimate_; }
862 
863  private:
864  const std::vector<Weight> &weights_;
865  const Less &less_;
866  const Estimate &estimate_;
867 };
868 
869 // A* queue discipline templated on StateId, Weight, and Estimate.
870 template <typename S, typename Weight, typename Estimate>
872  : public ShortestFirstQueue<
873  S, AStarWeightCompare<S, NaturalLess<Weight>, Estimate>> {
874  public:
875  using StateId = S;
877 
878  NaturalAStarQueue(const std::vector<Weight> &distance,
879  const Estimate &estimate)
881  Compare(distance, less_, estimate)) {}
882 
883  ~NaturalAStarQueue() override = default;
884 
885  private:
886  // This is non-static because the constructor for non-idempotent weights will
887  // result in an error.
888  const NaturalLess<Weight> less_{};
889 };
890 
891 // A state equivalence class is a function object that maps from a state ID to
892 // an equivalence class (state) ID. The trivial equivalence class maps a state
893 // ID to itself.
894 template <typename StateId>
896  StateId operator()(StateId s) const { return s; }
897 };
898 
// Distance-based pruning queue discipline: enqueues a state only when its
// shortest distance (so far), as given by distance, is less than (as
// determined by less) the shortest distance to any state in the same
// equivalence class Times() the threshold, where equivalence classes are given
// by the functor class_fnc. The underlying queue discipline is specified by
// queue.
904 //
905 // This is not a final class.
906 template <typename Queue, typename Less, typename ClassFnc>
907 class PruneQueue : public QueueBase<typename Queue::StateId> {
908  public:
909  using StateId = typename Queue::StateId;
910  using Weight = typename Less::Weight;
911 
912  PruneQueue(const std::vector<Weight> &distance, std::unique_ptr<Queue> queue,
913  const Less &less, const ClassFnc &class_fnc, Weight threshold)
915  distance_(distance),
916  queue_(std::move(queue)),
917  less_(less),
918  class_fnc_(class_fnc),
919  threshold_(std::move(threshold)) {}
920 
921  ~PruneQueue() override = default;
922 
923  StateId Head() const override { return queue_->Head(); }
924 
925  void Enqueue(StateId s) override {
926  const auto c = class_fnc_(s);
927  if (c >= class_distance_.size()) {
928  class_distance_.resize(c + 1, Weight::Zero());
929  }
930  if (less_(distance_[s], class_distance_[c])) {
931  class_distance_[c] = distance_[s];
932  }
933  // Enqueues only if below threshold limit.
934  const auto limit = Times(class_distance_[c], threshold_);
935  if (less_(distance_[s], limit)) queue_->Enqueue(s);
936  }
937 
938  void Dequeue() override { queue_->Dequeue(); }
939 
940  void Update(StateId s) override {
941  const auto c = class_fnc_(s);
942  if (less_(distance_[s], class_distance_[c])) {
943  class_distance_[c] = distance_[s];
944  }
945  queue_->Update(s);
946  }
947 
948  bool Empty() const override { return queue_->Empty(); }
949 
950  void Clear() override { queue_->Clear(); }
951 
952  private:
953  const std::vector<Weight> &distance_; // Shortest distance to state.
954  std::unique_ptr<Queue> queue_;
955  const Less &less_; // Borrowed reference.
956  const ClassFnc &class_fnc_; // Equivalence class functor.
957  Weight threshold_; // Pruning weight threshold.
958  std::vector<Weight> class_distance_; // Shortest distance to class.
959 };
960 
961 // Pruning queue discipline (see above) using the weight's natural order for the
962 // comparison function. The ownership of the queue argument is given to this
963 // class.
964 template <typename Queue, typename Weight, typename ClassFnc>
965 class NaturalPruneQueue final
966  : public PruneQueue<Queue, NaturalLess<Weight>, ClassFnc> {
967  public:
968  using StateId = typename Queue::StateId;
969 
970  NaturalPruneQueue(const std::vector<Weight> &distance,
971  std::unique_ptr<Queue> queue, const ClassFnc &class_fnc,
972  Weight threshold)
973  : PruneQueue<Queue, NaturalLess<Weight>, ClassFnc>(
974  distance, std::move(queue), NaturalLess<Weight>(), class_fnc,
975  threshold) {}
976 
977  ~NaturalPruneQueue() override = default;
978 };
979 
980 // Filter-based pruning queue discipline: enqueues a state only if allowed by
981 // the filter, specified by the state filter functor argument. The underlying
982 // queue discipline is specified by the queue argument.
983 template <typename Queue, typename Filter>
984 class FilterQueue : public QueueBase<typename Queue::StateId> {
985  public:
986  using StateId = typename Queue::StateId;
987 
988  FilterQueue(std::unique_ptr<Queue> queue, const Filter &filter)
990  queue_(std::move(queue)),
991  filter_(filter) {}
992 
993  ~FilterQueue() override = default;
994 
995  StateId Head() const final { return queue_->Head(); }
996 
997  // Enqueues only if allowed by state filter.
998  void Enqueue(StateId s) final {
999  if (filter_(s)) queue_->Enqueue(s);
1000  }
1001 
1002  void Dequeue() final { queue_->Dequeue(); }
1003 
1004  void Update(StateId s) final {}
1005 
1006  bool Empty() const final { return queue_->Empty(); }
1007 
1008  void Clear() final { queue_->Clear(); }
1009 
1010  private:
1011  std::unique_ptr<Queue> queue_;
1012  const Filter &filter_;
1013 };
1014 
1015 } // namespace fst
1016 
1017 #endif // FST_QUEUE_H_
virtual bool Empty() const =0
void Clear() override
Definition: queue.h:171
constexpr uint64_t kCyclic
Definition: properties.h:109
bool Empty() const final
Definition: queue.h:195
bool Empty() const final
Definition: queue.h:141
void Enqueue(StateId s) override
Definition: queue.h:163
StateId Head() const final
Definition: queue.h:452
void Update(StateId s) override
Definition: queue.h:940
void Enqueue(StateId s) override
Definition: queue.h:221
typename NaturalLess< Weight >::Weight Weight
Definition: queue.h:270
void Clear() final
Definition: queue.h:143
StateId Head() const final
Definition: queue.h:187
void SetError(bool error)
Definition: queue.h:101
virtual uint64_t Properties(uint64_t mask, bool test) const =0
void Clear() override
Definition: queue.h:249
FilterQueue(std::unique_ptr< Queue > queue, const Filter &filter)
Definition: queue.h:988
void Update(StateId) final
Definition: queue.h:139
QueueType
Definition: queue.h:76
NaturalAStarQueue(const std::vector< Weight > &distance, const Estimate &estimate)
Definition: queue.h:878
NaturalShortestFirstQueue(const std::vector< Weight > &distance)
Definition: queue.h:311
StateWeightCompare(const std::vector< Weight > &weights, const Less &less)
Definition: queue.h:272
bool Empty() const final
Definition: queue.h:520
StateId Head() const final
Definition: queue.h:133
typename Queue::StateId StateId
Definition: queue.h:986
void Enqueue(StateId s) final
Definition: queue.h:569
void Clear() final
Definition: queue.h:522
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:64
virtual ~QueueBase()=default
void Enqueue(StateId s) final
Definition: queue.h:135
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
Definition: dfs-visit.h:112
bool operator()(const Weight &, const Weight &) const
Definition: queue.h:294
ShortestFirstQueue(Compare comp)
Definition: queue.h:214
void Enqueue(StateId s) final
Definition: queue.h:743
void Enqueue(StateId s) final
Definition: queue.h:501
StateId Head() const override
Definition: queue.h:219
void Clear() final
Definition: queue.h:613
StateId Head() const override
Definition: queue.h:356
void Update(StateId) override
Definition: queue.h:167
bool operator()(StateId s1, StateId s2) const
Definition: queue.h:855
typename NaturalLess< Weight >::Weight Weight
Definition: queue.h:910
std::bool_constant<(W::Properties()&kIdempotent)!=0 > IsIdempotent
Definition: weight.h:159
void Update(StateId) final
Definition: queue.h:193
bool Empty() const override
Definition: queue.h:169
TopOrderQueue(const Fst< Arc > &fst, ArcFilter filter)
Definition: queue.h:426
constexpr uint64_t kTopSorted
Definition: properties.h:119
const Weight & operator()(StateId s) const
Definition: queue.h:831
NaturalAStarEstimate(const std::vector< Weight > &beta)
Definition: queue.h:829
void Dequeue() final
Definition: queue.h:137
bool Error() const
Definition: queue.h:103
StateId Head() const override
Definition: queue.h:923
constexpr int kNoStateId
Definition: fst.h:196
virtual void Dequeue()=0
void Dequeue() final
Definition: queue.h:513
QueueBase(QueueType type)
Definition: queue.h:99
void Update(StateId) final
Definition: queue.h:470
StateId Head() const final
Definition: queue.h:554
#define FSTERROR()
Definition: util.h:56
void Enqueue(StateId s) override
Definition: queue.h:925
void Update(StateId s) final
Definition: queue.h:1004
StateId operator()(StateId s) const
Definition: queue.h:896
bool Empty() const final
Definition: queue.h:599
virtual void Update(StateId)=0
void Update(StateId s) final
Definition: queue.h:747
void Enqueue(StateId s) final
Definition: queue.h:189
constexpr uint64_t kAcyclic
Definition: properties.h:111
virtual void Clear()=0
virtual void Enqueue(StateId)=0
constexpr Weight operator()(StateId) const
Definition: queue.h:822
typename NaturalLess< Weight >::Weight Weight
Definition: queue.h:849
AutoQueue(const Fst< Arc > &fst, const std::vector< typename Arc::Weight > *distance, ArcFilter filter)
Definition: queue.h:644
#define VLOG(level)
Definition: log.h:54
virtual StateId Start() const =0
void Dequeue() final
Definition: queue.h:465
void Dequeue() final
Definition: queue.h:1002
void Dequeue() override
Definition: queue.h:230
const Estimate & GetEstimate() const
Definition: queue.h:861
void Clear() final
Definition: queue.h:474
bool Empty() const override
Definition: queue.h:948
void Dequeue() override
Definition: queue.h:165
TopOrderQueue(const std::vector< StateId > &order)
Definition: queue.h:443
AStarWeightCompare(const std::vector< Weight > &weights, const Less &less, const Estimate &estimate)
Definition: queue.h:851
StateId Head() const final
Definition: queue.h:995
void Dequeue() final
Definition: queue.h:587
StateId Head() const final
Definition: queue.h:499
constexpr uint64_t kUnweighted
Definition: properties.h:106
StateId Head() const final
Definition: queue.h:741
void Enqueue(StateId s) final
Definition: queue.h:454
void Dequeue() final
Definition: queue.h:191
void Update(StateId s) override
Definition: queue.h:238
void Clear() final
Definition: queue.h:197
void Update(StateId) final
Definition: queue.h:518
NaturalPruneQueue(const std::vector< Weight > &distance, std::unique_ptr< Queue > queue, const ClassFnc &class_fnc, Weight threshold)
Definition: queue.h:970
void Enqueue(StateId s) override
Definition: queue.h:367
SccQueue(const std::vector< StateId > &scc, std::vector< std::unique_ptr< Queue >> *queue)
Definition: queue.h:544
bool Empty() const override
Definition: queue.h:247
void Update(StateId s) final
Definition: queue.h:595
QueueType Type() const
Definition: queue.h:105
Queue::StateId StateId
Definition: queue.h:93
bool operator()(const StateId s1, const StateId s2) const
Definition: queue.h:275
void Enqueue(StateId s) final
Definition: queue.h:998
PruneQueue(const std::vector< Weight > &distance, std::unique_ptr< Queue > queue, const Less &less, const ClassFnc &class_fnc, Weight threshold)
Definition: queue.h:912
void Dequeue() override
Definition: queue.h:938
bool Empty() const final
Definition: queue.h:1006
PruneNaturalShortestFirstQueue(const std::vector< Weight > &distance, ssize_t arc_threshold, ssize_t state_limit=0)
Definition: queue.h:346
void Clear() final
Definition: queue.h:751
ssize_t Size() const
Definition: queue.h:254
virtual StateId Head() const =0
typename Queue::StateId StateId
Definition: queue.h:968
void Clear() override
Definition: queue.h:950
void Clear() final
Definition: queue.h:1008
void Dequeue() final
Definition: queue.h:745
StateId Head() const override
Definition: queue.h:161
bool Empty() const final
Definition: queue.h:472
std::bool_constant<(W::Properties()&kPath)!=0 > IsPath
Definition: weight.h:162
bool Empty() const final
Definition: queue.h:749
const Compare & GetCompare() const
Definition: queue.h:256