23 #include <sys/types.h> 29 #include <type_traits> 103 bool Error()
const {
return error_; }
113 virtual bool Empty()
const = 0;
114 virtual void Clear() = 0;
169 bool Empty()
const override {
return queue_.empty(); }
171 void Clear()
override { queue_ = std::queue<StateId>(); }
174 std::queue<StateId> queue_;
195 bool Empty() const final {
return stack_.empty(); }
197 void Clear() final { stack_.clear(); }
200 std::vector<StateId> stack_;
209 template <
typename S,
typename Compare,
bool update = true>
224 key_[s] = heap_.Insert(s);
240 if (s >= key_.size() || key_[s] ==
kNoStateId) {
243 heap_.Update(key_[s], s);
247 bool Empty()
const override {
return heap_.Empty(); }
251 if (update) key_.clear();
254 ssize_t
Size()
const {
return heap_.Size(); }
256 const Compare &
GetCompare()
const {
return heap_.GetCompare(); }
260 std::vector<ssize_t> key_;
267 template <
typename StateId,
typename Less>
273 : weights_(weights), less_(less) {}
276 return less_(weights_[s1], weights_[s2]);
281 const std::vector<Weight> &weights_;
292 FSTERROR() <<
"ErrorLess: instantiated for Weight " << Weight::Type();
302 template <
typename S,
typename Weight,
typename Less>
317 template <
typename S,
typename Weight>
343 template <
typename S,
typename Weight>
351 ssize_t arc_threshold, ssize_t state_limit = 0)
353 arc_threshold_(arc_threshold),
354 state_limit_(state_limit),
356 max_head_steps_(0) {}
361 const auto head = Base::Head();
364 if (head < steps_.size()) {
365 max_head_steps_ = std::max(steps_[head], max_head_steps_);
366 head_steps_ = steps_[head];
374 const ssize_t state_steps = head_steps_ + 1;
375 if (s >= steps_.size()) {
376 steps_.resize(s + 1, state_steps);
379 steps_[s] = state_steps;
383 ssize_t adjusted_threshold = arc_threshold_;
384 if (Base::Size() > state_limit_ && state_limit_ > 0) {
385 adjusted_threshold = std::max<ssize_t>(
386 0, arc_threshold_ - (Base::Size() / state_limit_) - 1);
389 if (state_steps > (max_head_steps_ - adjusted_threshold) ||
390 arc_threshold_ < 0) {
391 if (adjusted_threshold == 0 && state_limit_ > 0) {
403 std::vector<ssize_t> steps_;
406 const ssize_t arc_threshold_;
409 const ssize_t state_limit_;
414 mutable ssize_t head_steps_;
416 mutable ssize_t max_head_steps_;
429 template <
class Arc,
class ArcFilter>
438 DfsVisit(fst, &top_order_visitor, filter);
440 FSTERROR() <<
"TopOrderQueue: FST is not acyclic";
459 if (front_ > back_) {
460 front_ = back_ = order_[s];
461 }
else if (order_[s] > back_) {
463 }
else if (order_[s] < front_) {
466 state_[order_[s]] = s;
471 while ((front_ <= back_) && (state_[front_] ==
kNoStateId)) ++front_;
476 bool Empty() const final {
return front_ > back_; }
487 std::vector<StateId> order_;
488 std::vector<StateId> state_;
506 if (front_ > back_) {
508 }
else if (s > back_) {
510 }
else if (s < front_) {
513 while (enqueued_.size() <= s) enqueued_.push_back(
false);
518 enqueued_[front_] =
false;
519 while ((front_ <= back_) && (enqueued_[front_] ==
false)) ++front_;
524 bool Empty() const final {
return front_ > back_; }
527 for (
StateId i = front_; i <= back_; ++i) enqueued_[i] =
false;
535 std::vector<bool> enqueued_;
541 template <
class S,
class Queue>
549 std::vector<std::unique_ptr<Queue>> *queue)
559 while ((front_ <= back_) &&
560 (((*queue_)[front_] && (*queue_)[front_]->
Empty()) ||
561 (((*queue_)[front_] ==
nullptr) &&
562 ((front_ >= trivial_queue_.size()) ||
566 if ((*queue_)[front_]) {
567 return (*queue_)[front_]->Head();
569 return trivial_queue_[front_];
574 if (front_ > back_) {
575 front_ = back_ = scc_[s];
576 }
else if (scc_[s] > back_) {
578 }
else if (scc_[s] < front_) {
581 if ((*queue_)[scc_[s]]) {
582 (*queue_)[scc_[s]]->Enqueue(s);
584 while (trivial_queue_.size() <= scc_[s]) {
587 trivial_queue_[scc_[s]] = s;
592 if ((*queue_)[front_]) {
593 (*queue_)[front_]->Dequeue();
594 }
else if (front_ < trivial_queue_.size()) {
600 if ((*queue_)[scc_[s]]) (*queue_)[scc_[s]]->Update(s);
605 if (front_ < back_) {
607 }
else if (front_ > back_) {
609 }
else if ((*queue_)[front_]) {
610 return (*queue_)[front_]->Empty();
612 return (front_ >= trivial_queue_.size()) ||
618 for (
StateId i = front_; i <= back_; ++i) {
620 (*queue_)[i]->Clear();
621 }
else if (i < trivial_queue_.size()) {
630 std::vector<std::unique_ptr<Queue>> *queue_;
631 const std::vector<StateId> &scc_;
634 std::vector<StateId> trivial_queue_;
647 template <
class Arc,
class ArcFilter>
649 const std::vector<typename Arc::Weight> *distance, ArcFilter filter)
651 using Weight =
typename Arc::Weight;
662 queue_ = std::make_unique<StateOrderQueue<StateId>>();
663 VLOG(2) <<
"AutoQueue: using state-order discipline";
665 queue_ = std::make_unique<TopOrderQueue<StateId>>(fst, filter);
666 VLOG(2) <<
"AutoQueue: using top-order discipline";
668 queue_ = std::make_unique<LifoQueue<StateId>>();
669 VLOG(2) <<
"AutoQueue: using LIFO discipline";
674 DfsVisit(fst, &scc_visitor, filter);
675 auto nscc = *std::max_element(scc_.begin(), scc_.end()) + 1;
676 std::vector<QueueType> queue_types(nscc);
677 std::unique_ptr<Less> less;
678 std::unique_ptr<Compare> comp;
681 less = std::make_unique<Less>();
682 comp = std::make_unique<Compare>(*distance, *less);
688 SccQueueType(fst, scc_, &queue_types, filter, less.get(), &all_trivial,
692 queue_ = std::make_unique<LifoQueue<StateId>>();
693 VLOG(2) <<
"AutoQueue: using LIFO discipline";
699 queue_ = std::make_unique<TopOrderQueue<StateId>>(scc_);
700 VLOG(2) <<
"AutoQueue: using top-order discipline";
703 VLOG(2) <<
"AutoQueue: using SCC meta-discipline";
704 queues_.resize(nscc);
705 for (
StateId i = 0; i < nscc; ++i) {
706 switch (queue_types[i]) {
709 VLOG(3) <<
"AutoQueue: SCC #" << i <<
": using trivial discipline";
716 std::make_unique<ShortestFirstQueue<StateId, Compare, false>>(
718 VLOG(3) <<
"AutoQueue: SCC #" << i
719 <<
": using shortest-first discipline";
722 FSTERROR() <<
"Got SHORTEST_FIRST_QUEUE for non-Path Weight " 728 queues_[i] = std::make_unique<LifoQueue<StateId>>();
729 VLOG(3) <<
"AutoQueue: SCC #" << i <<
": using LIFO discipline";
733 queues_[i] = std::make_unique<FifoQueue<StateId>>();
734 VLOG(3) <<
"AutoQueue: SCC #" << i <<
": using FIFO discipine";
738 queue_ = std::make_unique<SccQueue<StateId, QueueBase<StateId>>>(
753 bool Empty() const final {
return queue_->Empty(); }
755 void Clear() final { queue_->Clear(); }
758 template <
class Arc,
class ArcFilter,
class Less>
759 static void SccQueueType(
const Fst<Arc> &
fst,
const std::vector<StateId> &scc,
760 std::vector<QueueType> *queue_types,
761 ArcFilter filter, Less *less,
bool *all_trivial,
764 std::unique_ptr<QueueBase<StateId>> queue_;
765 std::vector<std::unique_ptr<QueueBase<StateId>>> queues_;
766 std::vector<StateId> scc_;
776 template <
class StateId>
777 template <
class Arc,
class ArcFilter,
class Less>
779 const std::vector<StateId> &scc,
780 std::vector<QueueType> *queue_type,
781 ArcFilter filter, Less *less,
782 bool *all_trivial,
bool *unweighted) {
783 using StateId =
typename Arc::StateId;
784 using Weight =
typename Arc::Weight;
787 for (
StateId i = 0; i < queue_type->size(); ++i) {
791 const auto state = sit.Value();
793 const auto &arc = ait.Value();
794 if (!filter(arc))
continue;
795 if (scc[state] == scc[arc.nextstate]) {
796 auto &type = (*queue_type)[scc[state]];
799 }
else if (!less || (*less)(arc.weight, Weight::One())) {
803 (arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
812 (arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
824 template <
typename StateId,
typename Weight>
830 template <
typename StateId,
typename Weight>
836 return (s < beta_.size()) ? beta_[s] : kZero;
840 static constexpr Weight kZero = Weight::Zero();
842 const std::vector<Weight> &beta_;
849 template <
typename S,
typename Less,
typename Estimate>
856 const Estimate &estimate)
857 : weights_(weights), less_(less), estimate_(estimate) {}
860 const auto w1 =
Times(weights_[s1], estimate_(s1));
861 const auto w2 =
Times(weights_[s2], estimate_(s2));
862 return less_(w1, w2);
868 const std::vector<Weight> &weights_;
870 const Estimate &estimate_;
874 template <
typename S,
typename Weight,
typename Estimate>
877 S, AStarWeightCompare<S, NaturalLess<Weight>, Estimate>> {
883 const Estimate &estimate)
885 Compare(distance, less_, estimate)) {}
898 template <
typename StateId>
910 template <
typename Queue,
typename Less,
typename ClassFnc>
916 PruneQueue(
const std::vector<Weight> &distance, std::unique_ptr<Queue> queue,
917 const Less &less,
const ClassFnc &class_fnc,
Weight threshold)
920 queue_(std::move(queue)),
922 class_fnc_(class_fnc),
923 threshold_(std::move(threshold)) {}
930 const auto c = class_fnc_(s);
931 if (c >= class_distance_.size()) {
932 class_distance_.resize(c + 1, Weight::Zero());
934 if (less_(distance_[s], class_distance_[c])) {
935 class_distance_[c] = distance_[s];
938 const auto limit =
Times(class_distance_[c], threshold_);
939 if (less_(distance_[s], limit)) queue_->Enqueue(s);
942 void Dequeue()
override { queue_->Dequeue(); }
945 const auto c = class_fnc_(s);
946 if (less_(distance_[s], class_distance_[c])) {
947 class_distance_[c] = distance_[s];
952 bool Empty()
const override {
return queue_->Empty(); }
954 void Clear()
override { queue_->Clear(); }
957 const std::vector<Weight> &distance_;
958 std::unique_ptr<Queue> queue_;
960 const ClassFnc &class_fnc_;
962 std::vector<Weight> class_distance_;
968 template <
typename Queue,
typename Weight,
typename ClassFnc>
970 :
public PruneQueue<Queue, NaturalLess<Weight>, ClassFnc> {
975 std::unique_ptr<Queue> queue,
const ClassFnc &class_fnc,
987 template <
typename Queue,
typename Filter>
994 queue_(std::move(queue)),
1003 if (filter_(s)) queue_->Enqueue(s);
1010 bool Empty() const final {
return queue_->Empty(); }
1015 std::unique_ptr<Queue> queue_;
1016 const Filter &filter_;
1021 #endif // FST_QUEUE_H_ virtual bool Empty() const =0
constexpr uint64_t kCyclic
void Enqueue(StateId s) override
StateId Head() const final
void Update(StateId s) override
void Enqueue(StateId s) override
typename Less::Weight Weight
StateId Head() const final
void SetError(bool error)
virtual uint64_t Properties(uint64_t mask, bool test) const =0
FilterQueue(std::unique_ptr< Queue > queue, const Filter &filter)
void Update(StateId) final
NaturalAStarQueue(const std::vector< Weight > &distance, const Estimate &estimate)
StateWeightCompare(const std::vector< Weight > &weights, const Less &less)
StateId Head() const final
typename Queue::StateId StateId
void Enqueue(StateId s) final
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
virtual ~QueueBase()=default
void Enqueue(StateId s) final
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
bool operator()(const Weight &, const Weight &) const
ShortestFirstQueue(Compare comp)
void Enqueue(StateId s) final
void Enqueue(StateId s) final
StateId Head() const override
StateId Head() const override
void Update(StateId) override
bool operator()(StateId s1, StateId s2) const
typename NaturalLess< Weight >::Weight Weight
std::bool_constant<(W::Properties()&kIdempotent)!=0 > IsIdempotent
void Update(StateId) final
bool Empty() const override
TopOrderQueue(const Fst< Arc > &fst, ArcFilter filter)
constexpr uint64_t kTopSorted
const Weight & operator()(StateId s) const
NaturalAStarEstimate(const std::vector< Weight > &beta)
StateId Head() const override
QueueBase(QueueType type)
void Update(StateId) final
StateId Head() const final
void Enqueue(StateId s) override
void Update(StateId s) final
StateId operator()(StateId s) const
virtual void Update(StateId)=0
void Update(StateId s) final
void Enqueue(StateId s) final
constexpr uint64_t kAcyclic
virtual void Enqueue(StateId)=0
constexpr Weight operator()(StateId) const
typename NaturalLess< Weight >::Weight Weight
AutoQueue(const Fst< Arc > &fst, const std::vector< typename Arc::Weight > *distance, ArcFilter filter)
virtual StateId Start() const =0
const Estimate & GetEstimate() const
bool Empty() const override
TopOrderQueue(const std::vector< StateId > &order)
AStarWeightCompare(const std::vector< Weight > &weights, const Less &less, const Estimate &estimate)
StateId Head() const final
StateId Head() const final
constexpr uint64_t kUnweighted
StateId Head() const final
void Enqueue(StateId s) final
void Update(StateId s) override
void Update(StateId) final
NaturalPruneQueue(const std::vector< Weight > &distance, std::unique_ptr< Queue > queue, const ClassFnc &class_fnc, Weight threshold)
CustomShortestFirstQueue(const std::vector< Weight > &distance)
void Enqueue(StateId s) override
SccQueue(const std::vector< StateId > &scc, std::vector< std::unique_ptr< Queue >> *queue)
bool Empty() const override
void Update(StateId s) final
bool operator()(const StateId s1, const StateId s2) const
void Enqueue(StateId s) final
PruneQueue(const std::vector< Weight > &distance, std::unique_ptr< Queue > queue, const Less &less, const ClassFnc &class_fnc, Weight threshold)
PruneNaturalShortestFirstQueue(const std::vector< Weight > &distance, ssize_t arc_threshold, ssize_t state_limit=0)
virtual StateId Head() const =0
typename Queue::StateId StateId
StateId Head() const override
std::bool_constant<(W::Properties()&kPath)!=0 > IsPath
const Compare & GetCompare() const