20 #ifndef FST_SHORTEST_PATH_H_ 21 #define FST_SHORTEST_PATH_H_ 27 #include <type_traits> 48 template <
class Arc,
class Queue,
class ArcFilter>
70 bool unique =
false,
bool has_distance =
false,
72 Weight weight_threshold = Weight::Zero(),
78 has_distance(has_distance),
79 first_path(first_path),
80 weight_threshold(std::move(weight_threshold)),
81 state_threshold(state_threshold) {}
96 const std::vector<std::pair<typename Arc::StateId, size_t>> &parent,
97 typename Arc::StateId f_parent) {
98 using StateId =
typename Arc::StateId;
105 d = state, state = parent[state].first) {
112 aiter.
Seek(parent[d].second);
113 auto arc = aiter.
Value();
115 ofst->
AddArc(s_p, std::move(arc));
132 template <
typename S,
typename W,
typename Queue>
140 template <
typename S,
typename W,
typename Estimate>
146 : estimate_(state_queue.GetCompare().GetEstimate()) {}
149 return f ==
Plus(
Times(d, estimate_(s)), f);
153 const Estimate &estimate_;
169 template <
class Arc,
class Queue,
class ArcFilter>
171 const Fst<Arc> &ifst, std::vector<typename Arc::Weight> *distance,
173 typename Arc::StateId *f_parent,
174 std::vector<std::pair<typename Arc::StateId, size_t>> *parent) {
175 using StateId =
typename Arc::StateId;
176 using Weight =
typename Arc::Weight;
179 "Weight must be right distributive.");
183 std::vector<bool> enqueued;
186 bool final_seen =
false;
187 auto f_distance = Weight::Zero();
190 while (distance->size() <
source) {
191 distance->push_back(Weight::Zero());
192 enqueued.push_back(
false);
195 distance->push_back(Weight::One());
198 enqueued.push_back(
true);
203 const auto sd = (*distance)[s];
211 if (ifst.
Final(s) != Weight::Zero()) {
213 if (f_distance != plus) {
217 if (!f_distance.Member())
return false;
221 const auto &arc = aiter.Value();
222 while (distance->size() <= arc.nextstate) {
223 distance->push_back(Weight::Zero());
224 enqueued.push_back(
false);
227 auto &nd = (*distance)[arc.nextstate];
228 const auto weight =
Times(sd, arc.weight);
229 if (nd !=
Plus(nd, weight)) {
230 nd =
Plus(nd, weight);
231 if (!nd.Member())
return false;
232 (*parent)[arc.nextstate] = std::make_pair(s, aiter.Position());
233 if (!enqueued[arc.nextstate]) {
235 enqueued[arc.nextstate] =
true;
245 template <
class StateId,
class Weight>
249 const std::vector<Weight> &distance,
StateId superfinal,
253 superfinal_(superfinal),
257 const auto &px = pairs_[x];
258 const auto &py = pairs_[y];
259 const auto wx =
Times(PWeight(px.first), px.second);
260 const auto wy =
Times(PWeight(py.first), py.second);
264 if (px.first == superfinal_ && py.first != superfinal_) {
265 return less_(wy, wx) ||
ApproxEqual(wx, wy, delta_);
266 }
else if (py.first == superfinal_ && px.first != superfinal_) {
267 return less_(wy, wx) && !
ApproxEqual(wx, wy, delta_);
269 return less_(wy, wx);
275 return (state == superfinal_) ? Weight::One()
276 : (state < distance_.size()) ? distance_[state]
280 const std::vector<std::pair<StateId, Weight>> &pairs_;
281 const std::vector<Weight> &distance_;
319 template <
class Arc,
class RevArc>
321 const std::vector<typename Arc::Weight> &distance,
323 typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
324 typename Arc::StateId state_threshold =
kNoStateId) {
325 using StateId =
typename Arc::StateId;
326 using Weight =
typename Arc::Weight;
327 using Pair = std::pair<StateId, Weight>;
328 static_assert((Weight::Properties() &
kPath) ==
kPath,
329 "Weight must have path property.");
331 "Weight must be distributive.");
332 if (nshortest <= 0)
return;
340 std::vector<Pair> pairs;
348 distance[ifst.
Start()] == Weight::Zero() ||
349 less(weight_threshold, Weight::One()) || state_threshold == 0) {
354 const auto final_state = ofst->
AddState();
356 while (pairs.size() <= final_state) {
357 pairs.emplace_back(
kNoStateId, Weight::Zero());
359 pairs[final_state] = std::make_pair(ifst.
Start(), Weight::One());
360 std::vector<StateId> heap;
361 heap.push_back(final_state);
367 while (!heap.empty()) {
368 std::pop_heap(heap.begin(), heap.end(), compare);
369 const auto state = heap.back();
370 const auto p = pairs[state];
372 const auto d = (p.first ==
kNoStateId) ? Weight::One()
373 : (p.first < distance.size()) ? distance[p.first]
375 if (less(limit,
Times(d, p.second)) ||
380 while (r.size() <= p.first + 1) r.push_back(0);
384 if (r[p.first + 1] > nshortest)
continue;
388 const auto &rarc = aiter.Value();
389 Arc arc(rarc.ilabel, rarc.olabel, rarc.weight.Reverse(), rarc.nextstate);
390 const auto weight =
Times(p.second, arc.weight);
392 pairs.emplace_back(arc.nextstate, weight);
393 arc.nextstate = state;
394 ofst->
AddArc(next, std::move(arc));
395 heap.push_back(next);
396 std::push_heap(heap.begin(), heap.end(), compare);
398 const auto final_weight = ifst.
Final(p.first).Reverse();
399 if (final_weight != Weight::Zero()) {
400 const auto weight =
Times(p.second, final_weight);
403 ofst->
AddArc(next, Arc(0, 0, final_weight, state));
404 heap.push_back(next);
405 std::push_heap(heap.begin(), heap.end(), compare);
442 template <
class Arc,
class Queue,
class ArcFilter>
444 std::vector<typename Arc::Weight> *distance,
446 using StateId =
typename Arc::StateId;
447 using Weight =
typename Arc::Weight;
450 "ShortestPath: Weight needs to have the path property and " 454 std::vector<std::pair<StateId, size_t>> parent;
467 if (distance->size() == 1 && !(*distance)[0].Member()) {
477 auto d = Weight::Zero();
480 const auto &arc = aiter.Value();
481 const auto state = arc.nextstate - 1;
482 if (state < distance->size()) {
483 d =
Plus(d,
Times(arc.weight.Reverse(), (*distance)[state]));
487 distance->insert(distance->begin(), d);
492 std::vector<Weight> ddistance;
499 distance->erase(distance->begin());
513 int32_t nshortest = 1,
bool unique =
false,
514 bool first_path =
false,
515 typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
516 typename Arc::StateId state_threshold =
kNoStateId,
518 using StateId =
typename Arc::StateId;
519 std::vector<typename Arc::Weight> distance;
523 &state_queue, arc_filter, nshortest, unique,
false,
delta, first_path,
524 weight_threshold, state_threshold);
530 #endif // FST_SHORTEST_PATH_H_ constexpr uint64_t kSemiring
virtual uint64_t Properties(uint64_t mask, bool test) const =0
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
uint64_t ShortestPathProperties(uint64_t props, bool tree=false)
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
constexpr uint64_t kError
virtual void SetInputSymbols(const SymbolTable *isyms)=0
virtual Weight Final(StateId) const =0
virtual void SetStart(StateId)=0
void Connect(MutableFst< Arc > *fst)
const Arc & Value() const
bool SingleShortestPath(const Fst< Arc > &ifst, std::vector< typename Arc::Weight > *distance, const ShortestPathOptions< Arc, Queue, ArcFilter > &opts, typename Arc::StateId *f_parent, std::vector< std::pair< typename Arc::StateId, size_t >> *parent)
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
constexpr uint64_t kRightSemiring
bool operator()(S s, W d, W f) const
void SingleShortestPathBacktrace(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, const std::vector< std::pair< typename Arc::StateId, size_t >> &parent, typename Arc::StateId f_parent)
ShortestPathOptions(Queue *queue, ArcFilter filter, int32_t nshortest=1, bool unique=false, bool has_distance=false, float delta=kShortestDelta, bool first_path=false, Weight weight_threshold=Weight::Zero(), StateId state_threshold=kNoStateId)
typename Arc::StateId StateId
virtual void SetProperties(uint64_t props, uint64_t mask)=0
virtual StateId Start() const =0
constexpr float kShortestDelta
void NShortestPath(const Fst< RevArc > &ifst, MutableFst< Arc > *ofst, const std::vector< typename Arc::Weight > &distance, int32_t nshortest, float delta=kShortestDelta, typename Arc::Weight weight_threshold=Arc::Weight::Zero(), typename Arc::StateId state_threshold=kNoStateId)
void ShortestPath(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, MutableFst< Arc > *ofst, const PdtShortestPathOptions< Arc, Queue > &opts)
typename Arc::StateId StateId
FirstPathSelect(const Queue &)
constexpr uint64_t kFstProperties
virtual void AddArc(StateId, const Arc &)=0
bool operator()(S s, W d, W f) const
virtual const SymbolTable * InputSymbols() const =0
virtual StateId AddState()=0
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
virtual void DeleteStates(const std::vector< StateId > &)=0
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
bool operator()(const StateId x, const StateId y) const
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
FirstPathSelect(const Queue &state_queue)
virtual StateId NumStates() const =0
ShortestPathCompare(const std::vector< std::pair< StateId, Weight >> &pairs, const std::vector< Weight > &distance, StateId superfinal, float delta)
typename Arc::Weight Weight
std::bool_constant<(W::Properties()&kPath)!=0 > IsPath
bool ApproxEqual(const ErrorWeight &, const ErrorWeight &, float)
virtual const SymbolTable * OutputSymbols() const =0