20 #ifndef FST_SHORTEST_PATH_H_ 21 #define FST_SHORTEST_PATH_H_ 25 #include <type_traits> 40 template <
class Arc,
class Queue,
class ArcFilter>
62 bool unique =
false,
bool has_distance =
false,
64 Weight weight_threshold = Weight::Zero(),
70 has_distance(has_distance),
71 first_path(first_path),
72 weight_threshold(std::move(weight_threshold)),
73 state_threshold(state_threshold) {}
88 const std::vector<std::pair<typename Arc::StateId, size_t>> &parent,
89 typename Arc::StateId f_parent) {
90 using StateId =
typename Arc::StateId;
97 d = state, state = parent[state].first) {
104 aiter.
Seek(parent[d].second);
105 auto arc = aiter.
Value();
107 ofst->
AddArc(s_p, std::move(arc));
124 template <
typename S,
typename W,
typename Queue>
132 template <
typename S,
typename W,
typename Estimate>
138 : estimate_(state_queue.GetCompare().GetEstimate()) {}
141 return f ==
Plus(
Times(d, estimate_(s)), f);
145 const Estimate &estimate_;
161 template <
class Arc,
class Queue,
class ArcFilter>
163 const Fst<Arc> &ifst, std::vector<typename Arc::Weight> *distance,
165 typename Arc::StateId *f_parent,
166 std::vector<std::pair<typename Arc::StateId, size_t>> *parent) {
167 using StateId =
typename Arc::StateId;
168 using Weight =
typename Arc::Weight;
171 "Weight must be right distributive.");
175 std::vector<bool> enqueued;
178 bool final_seen =
false;
179 auto f_distance = Weight::Zero();
182 while (distance->size() <
source) {
183 distance->push_back(Weight::Zero());
184 enqueued.push_back(
false);
187 distance->push_back(Weight::One());
190 enqueued.push_back(
true);
195 const auto sd = (*distance)[s];
203 if (ifst.
Final(s) != Weight::Zero()) {
205 if (f_distance != plus) {
209 if (!f_distance.Member())
return false;
213 const auto &arc = aiter.Value();
214 while (distance->size() <= arc.nextstate) {
215 distance->push_back(Weight::Zero());
216 enqueued.push_back(
false);
219 auto &nd = (*distance)[arc.nextstate];
220 const auto weight =
Times(sd, arc.weight);
221 if (nd !=
Plus(nd, weight)) {
222 nd =
Plus(nd, weight);
223 if (!nd.Member())
return false;
224 (*parent)[arc.nextstate] = std::make_pair(s, aiter.Position());
225 if (!enqueued[arc.nextstate]) {
227 enqueued[arc.nextstate] =
true;
237 template <
class StateId,
class Weight>
241 const std::vector<Weight> &distance,
StateId superfinal,
245 superfinal_(superfinal),
249 const auto &px = pairs_[x];
250 const auto &py = pairs_[y];
251 const auto wx =
Times(PWeight(px.first), px.second);
252 const auto wy =
Times(PWeight(py.first), py.second);
256 if (px.first == superfinal_ && py.first != superfinal_) {
257 return less_(wy, wx) ||
ApproxEqual(wx, wy, delta_);
258 }
else if (py.first == superfinal_ && px.first != superfinal_) {
259 return less_(wy, wx) && !
ApproxEqual(wx, wy, delta_);
261 return less_(wy, wx);
267 return (state == superfinal_)
269 : (state < distance_.size()) ? distance_[state] : Weight::Zero();
272 const std::vector<std::pair<StateId, Weight>> &pairs_;
273 const std::vector<Weight> &distance_;
311 template <
class Arc,
class RevArc>
313 const std::vector<typename Arc::Weight> &distance,
315 typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
316 typename Arc::StateId state_threshold =
kNoStateId) {
317 using StateId =
typename Arc::StateId;
318 using Weight =
typename Arc::Weight;
319 using Pair = std::pair<StateId, Weight>;
320 static_assert((Weight::Properties() &
kPath) ==
kPath,
321 "Weight must have path property.");
323 "Weight must be distributive.");
324 if (nshortest <= 0)
return;
332 std::vector<Pair> pairs;
340 distance[ifst.
Start()] == Weight::Zero() ||
341 less(weight_threshold, Weight::One()) || state_threshold == 0) {
346 const auto final_state = ofst->
AddState();
348 while (pairs.size() <= final_state) {
349 pairs.emplace_back(
kNoStateId, Weight::Zero());
351 pairs[final_state] = std::make_pair(ifst.
Start(), Weight::One());
352 std::vector<StateId> heap;
353 heap.push_back(final_state);
359 while (!heap.empty()) {
360 std::pop_heap(heap.begin(), heap.end(), compare);
361 const auto state = heap.back();
362 const auto p = pairs[state];
367 : (p.first < distance.size()) ? distance[p.first] : Weight::Zero();
368 if (less(limit,
Times(d, p.second)) ||
373 while (r.size() <= p.first + 1) r.push_back(0);
377 if (r[p.first + 1] > nshortest)
continue;
381 const auto &rarc = aiter.Value();
382 Arc arc(rarc.ilabel, rarc.olabel, rarc.weight.Reverse(), rarc.nextstate);
383 const auto weight =
Times(p.second, arc.weight);
385 pairs.emplace_back(arc.nextstate, weight);
386 arc.nextstate = state;
387 ofst->
AddArc(next, std::move(arc));
388 heap.push_back(next);
389 std::push_heap(heap.begin(), heap.end(), compare);
391 const auto final_weight = ifst.
Final(p.first).Reverse();
392 if (final_weight != Weight::Zero()) {
393 const auto weight =
Times(p.second, final_weight);
396 ofst->
AddArc(next, Arc(0, 0, final_weight, state));
397 heap.push_back(next);
398 std::push_heap(heap.begin(), heap.end(), compare);
435 template <
class Arc,
class Queue,
class ArcFilter>
437 std::vector<typename Arc::Weight> *distance,
439 using StateId =
typename Arc::StateId;
440 using Weight =
typename Arc::Weight;
443 "ShortestPath: Weight needs to have the path property and " 447 std::vector<std::pair<StateId, size_t>> parent;
460 if (distance->size() == 1 && !(*distance)[0].Member()) {
470 auto d = Weight::Zero();
473 const auto &arc = aiter.Value();
474 const auto state = arc.nextstate - 1;
475 if (state < distance->size()) {
476 d =
Plus(d,
Times(arc.weight.Reverse(), (*distance)[state]));
480 distance->insert(distance->begin(), d);
485 std::vector<Weight> ddistance;
492 distance->erase(distance->begin());
506 int32_t nshortest = 1,
bool unique =
false,
507 bool first_path =
false,
508 typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
509 typename Arc::StateId state_threshold =
kNoStateId,
511 using StateId =
typename Arc::StateId;
512 std::vector<typename Arc::Weight> distance;
516 &state_queue, arc_filter, nshortest, unique,
false,
delta, first_path,
517 weight_threshold, state_threshold);
523 #endif // FST_SHORTEST_PATH_H_ constexpr uint64_t kSemiring
virtual uint64_t Properties(uint64_t mask, bool test) const =0
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
uint64_t ShortestPathProperties(uint64_t props, bool tree=false)
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
constexpr uint64_t kError
virtual void SetInputSymbols(const SymbolTable *isyms)=0
virtual Weight Final(StateId) const =0
virtual void SetStart(StateId)=0
void Connect(MutableFst< Arc > *fst)
const Arc & Value() const
bool SingleShortestPath(const Fst< Arc > &ifst, std::vector< typename Arc::Weight > *distance, const ShortestPathOptions< Arc, Queue, ArcFilter > &opts, typename Arc::StateId *f_parent, std::vector< std::pair< typename Arc::StateId, size_t >> *parent)
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
constexpr uint64_t kRightSemiring
bool operator()(S s, W d, W f) const
void SingleShortestPathBacktrace(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, const std::vector< std::pair< typename Arc::StateId, size_t >> &parent, typename Arc::StateId f_parent)
ShortestPathOptions(Queue *queue, ArcFilter filter, int32_t nshortest=1, bool unique=false, bool has_distance=false, float delta=kShortestDelta, bool first_path=false, Weight weight_threshold=Weight::Zero(), StateId state_threshold=kNoStateId)
typename Arc::StateId StateId
virtual void SetProperties(uint64_t props, uint64_t mask)=0
virtual StateId Start() const =0
constexpr float kShortestDelta
void NShortestPath(const Fst< RevArc > &ifst, MutableFst< Arc > *ofst, const std::vector< typename Arc::Weight > &distance, int32_t nshortest, float delta=kShortestDelta, typename Arc::Weight weight_threshold=Arc::Weight::Zero(), typename Arc::StateId state_threshold=kNoStateId)
void ShortestPath(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, MutableFst< Arc > *ofst, const PdtShortestPathOptions< Arc, Queue > &opts)
typename Arc::StateId StateId
FirstPathSelect(const Queue &)
constexpr uint64_t kFstProperties
virtual void AddArc(StateId, const Arc &)=0
bool operator()(S s, W d, W f) const
virtual const SymbolTable * InputSymbols() const =0
virtual StateId AddState()=0
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
virtual void DeleteStates(const std::vector< StateId > &)=0
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
bool operator()(const StateId x, const StateId y) const
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
FirstPathSelect(const Queue &state_queue)
virtual StateId NumStates() const =0
ShortestPathCompare(const std::vector< std::pair< StateId, Weight >> &pairs, const std::vector< Weight > &distance, StateId superfinal, float delta)
typename Arc::Weight Weight
std::bool_constant<(W::Properties()&kPath)!=0 > IsPath
bool ApproxEqual(const ErrorWeight &, const ErrorWeight &, float)
virtual const SymbolTable * OutputSymbols() const =0