20 #ifndef FST_SHORTEST_DISTANCE_H_ 21 #define FST_SHORTEST_DISTANCE_H_ 41 template <
class Arc,
class Queue,
class ArcFilter>
61 float delta = kShortestDelta,
bool first_path =
false)
62 : state_queue(state_queue),
63 arc_filter(arc_filter),
66 first_path(first_path) {}
78 template <
class Arc,
class Queue,
class ArcFilter,
92 weight_equal_(opts.
delta),
100 distance_->reserve(num_states);
101 adder_.reserve(num_states);
102 radder_.reserve(num_states);
103 enqueued_.reserve(num_states);
109 bool Error()
const {
return error_; }
112 void EnsureDistanceIndexIsValid(std::size_t index) {
113 while (distance_->size() <= index) {
114 distance_->push_back(Weight::Zero());
117 enqueued_.push_back(
false);
122 void EnsureSourcesIndexIsValid(std::size_t index) {
123 while (sources_.size() <= index) {
130 std::vector<Weight> *distance_;
132 ArcFilter arc_filter_;
133 WeightEqual weight_equal_;
134 const bool first_path_;
137 std::vector<Adder<Weight>> adder_;
138 std::vector<Adder<Weight>> radder_;
139 std::vector<bool> enqueued_;
140 std::vector<StateId> sources_;
148 template <
class Arc,
class Queue,
class ArcFilter,
class WeightEqual>
152 if (fst_.Properties(
kError,
false)) error_ =
true;
156 FSTERROR() <<
"ShortestDistance: Weight needs to be right distributive: " 161 if (first_path_ && !(Weight::Properties() &
kPath)) {
162 FSTERROR() <<
"ShortestDistance: The first_path option is disallowed when " 163 <<
"Weight does not have the path property: " << Weight::Type();
167 state_queue_->Clear();
174 if (source ==
kNoStateId) source = fst_.Start();
175 EnsureDistanceIndexIsValid(source);
177 EnsureSourcesIndexIsValid(source);
178 sources_[
source] = source_id_;
180 (*distance_)[
source] = Weight::One();
181 adder_[
source].Reset(Weight::One());
182 radder_[
source].Reset(Weight::One());
184 state_queue_->Enqueue(source);
185 while (!state_queue_->Empty()) {
186 const auto state = state_queue_->Head();
187 state_queue_->Dequeue();
188 EnsureDistanceIndexIsValid(state);
189 if (first_path_ && (fst_.Final(state) != Weight::Zero()))
break;
190 enqueued_[state] =
false;
191 const auto r = radder_[state].Sum();
192 radder_[state].Reset();
195 const auto &arc = aiter.Value();
196 const auto nextstate = arc.nextstate;
197 if (!arc_filter_(arc))
continue;
198 EnsureDistanceIndexIsValid(nextstate);
200 EnsureSourcesIndexIsValid(nextstate);
201 if (sources_[nextstate] != source_id_) {
202 (*distance_)[nextstate] = Weight::Zero();
203 adder_[nextstate].Reset();
204 radder_[nextstate].Reset();
205 enqueued_[nextstate] =
false;
206 sources_[nextstate] = source_id_;
209 auto &nd = (*distance_)[nextstate];
210 auto &na = adder_[nextstate];
211 auto &nr = radder_[nextstate];
212 auto weight =
Times(r, arc.weight);
213 if (!weight_equal_(nd,
Plus(nd, weight))) {
216 if (!nd.Member() || !nr.Sum().Member()) {
220 if (!enqueued_[nextstate]) {
221 state_queue_->Enqueue(nextstate);
222 enqueued_[nextstate] =
true;
224 state_queue_->Update(nextstate);
230 if (fst_.Properties(
kError,
false)) error_ =
true;
258 template <
class Arc,
class Queue,
class ArcFilter>
260 const Fst<Arc> &
fst, std::vector<typename Arc::Weight> *distance,
265 if (sd_state.
Error()) {
266 distance->assign(1, Arc::Weight::NoWeight());
300 std::vector<typename Arc::Weight> *distance,
301 bool reverse =
false,
float delta = kShortestDelta) {
302 using StateId =
typename Arc::StateId;
315 std::vector<ReverseWeight> rdistance;
322 if (rdistance.size() == 1 && !rdistance[0].Member()) {
323 distance->assign(1, Arc::Weight::NoWeight());
327 distance->reserve(rdistance.size() - 1);
328 while (distance->size() < rdistance.size() - 1) {
329 distance->push_back(rdistance[distance->size() + 1].Reverse());
339 float delta = kShortestDelta) {
340 using StateId =
typename Arc::StateId;
341 using Weight =
typename Arc::Weight;
342 std::vector<Weight> distance;
345 if (distance.size() == 1 && !distance[0].Member()) {
346 return Arc::Weight::NoWeight();
349 for (
StateId state = 0; state < distance.size(); ++state) {
355 const auto state = fst.
Start();
356 if (distance.size() == 1 && !distance[0].Member()) {
357 return Arc::Weight::NoWeight();
359 return state !=
kNoStateId && state < distance.size() ? distance[state]
366 #endif // FST_SHORTEST_DISTANCE_H_
virtual uint64_t Properties(uint64_t mask, bool test) const =0
void ShortestDistance(StateId source)
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
constexpr uint64_t kError
virtual Weight Final(StateId) const =0
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
constexpr uint64_t kRightSemiring
Weight Add(const Weight &w)
typename Arc::Weight::ReverseWeight Weight
typename Arc::StateId StateId
virtual StateId Start() const =0
constexpr float kShortestDelta
typename Arc::StateId StateId
Arc::StateId CountStates(const Fst< Arc > &fst)
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
ShortestDistanceState(const Fst< Arc > &fst, std::vector< Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts, bool retain)
ShortestDistanceOptions(Queue *state_queue, ArcFilter arc_filter, StateId source=kNoStateId, float delta=kShortestDelta, bool first_path=false)
constexpr uint64_t kExpanded
typename Arc::Weight Weight