20 #ifndef FST_EXTENSIONS_PDT_SHORTEST_PATH_H_ 21 #define FST_EXTENSIONS_PDT_SHORTEST_PATH_H_ 23 #include <sys/types.h> 29 #include <unordered_map> 46 #include <unordered_map> 50 template <
class Arc,
class Queue>
57 : keep_parentheses(keep_parentheses), path_gc(path_gc) {}
86 using Label =
typename Arc::Label;
95 : state(s), start(t) {}
98 if (&other ==
this)
return true;
99 return other.
state == state && other.
start == start;
109 : paren_id(paren_id), src_start(src_start), dest_start(dest_start) {}
116 if (&other ==
this)
return true;
117 return (other.
paren_id == paren_id &&
125 : distance(
Weight::Zero()),
137 : gc_(gc), nstates_(0), ngc_(0), finished_(false) {}
140 VLOG(1) <<
"opm size: " << paren_map_.size();
141 VLOG(1) <<
"# of search states: " << nstates_;
142 if (gc_)
VLOG(1) <<
"# of GC'd search states: " << ngc_;
147 search_multimap_.clear();
160 return GetSearchData(paren)->distance;
166 return GetSearchData(paren)->parent;
174 GetSearchData(s)->distance = std::move(weight);
178 GetSearchData(paren)->distance = std::move(weight);
184 GetSearchData(paren)->parent = p;
189 FSTERROR() <<
"PdtShortestPathData: Paren ID does not fit in an int16_t";
191 GetSearchData(s)->paren_id = p;
195 auto *data = GetSearchData(s);
196 data->flags &= ~mask;
197 data->flags |= f & mask;
206 struct SearchStateHash {
208 static constexpr
auto prime = 7853;
215 size_t operator()(
const ParenSpec &paren)
const {
216 static constexpr
auto prime0 = 7853;
217 static constexpr
auto prime1 = 7867;
224 std::unordered_map<SearchState, SearchData, SearchStateHash>;
226 using SearchMultimap = std::unordered_multimap<StateId, StateId>;
229 using ParenMap = std::unordered_map<ParenSpec, SearchData, ParenHash>;
232 if (s == state_)
return state_data_;
234 auto it = search_map_.find(s);
235 if (it == search_map_.end())
return &null_search_data_;
237 return state_data_ = &(it->second);
240 state_data_ = &search_map_[s];
241 if (!(state_data_->flags & kPdtInited)) {
243 if (gc_) search_multimap_.insert(std::make_pair(s.
start, s.
state));
251 if (paren == paren_)
return paren_data_;
253 auto it = paren_map_.find(paren);
254 if (it == paren_map_.end())
return &null_search_data_;
256 return state_data_ = &(it->second);
259 return paren_data_ = &paren_map_[paren];
263 mutable SearchMap search_map_;
264 mutable SearchMultimap search_multimap_;
265 mutable ParenMap paren_map_;
271 mutable size_t nstates_;
286 std::vector<StateId> finals;
287 for (
auto it = search_multimap_.find(start);
288 it != search_multimap_.end() && it->first == start; ++it) {
290 if (search_map_[s].flags & kPdtFinal) finals.push_back(s.
state);
293 for (
const auto state : finals) {
296 auto &sdata = search_map_[ss];
297 if (sdata.flags & kPdtMarked)
break;
299 const auto p = sdata.parent;
300 if (p.start != start && p.start !=
kNoLabel) {
302 ss = paren_map_[paren].parent;
309 auto it = search_multimap_.find(start);
310 while (it != search_multimap_.end() && it->first == start) {
312 auto mit = search_map_.find(s);
314 if (!(data.
flags & kPdtMarked)) {
315 search_map_.erase(mit);
318 search_multimap_.erase(it++);
339 template <
class Arc,
class Queue>
353 const std::vector<std::pair<Label, Label>> &parens,
355 : ifst_(ifst.Copy()),
358 start_(ifst.Start()),
368 FSTERROR() <<
"PdtShortestPath: Weight needs to have the path" 369 <<
" property and be right distributive: " << Weight::Type();
372 for (
Label i = 0; i < parens.size(); ++i) {
373 const auto &pair = parens[i];
374 paren_map_[pair.first] = i;
375 paren_map_[pair.second] = i;
381 VLOG(1) <<
"# of enqueued: " << nenqueued_;
382 VLOG(1) <<
"cpmm size: " << close_paren_multimap_.size();
402 std::unordered_multimap<internal::ParenState<Arc>, Arc,
406 return close_paren_multimap_;
412 void GetDistance(
StateId start);
434 std::unique_ptr<Fst<Arc>> ifst_;
436 const std::vector<std::pair<Label, Label>> &parens_;
443 std::unordered_map<Label, Label> paren_map_;
449 static constexpr uint8_t kEnqueued = 0x10;
450 static constexpr uint8_t
kExpanded = 0x20;
451 static constexpr uint8_t kFinished = 0x40;
456 template <
class Arc,
class Queue>
463 fdistance_ = Weight::Zero();
466 close_paren_multimap_.clear();
467 balance_data_.Clear();
471 const auto s = siter.Value();
473 const auto &arc = aiter.Value();
474 const auto it = paren_map_.find(arc.ilabel);
475 if (it != paren_map_.end()) {
476 const auto paren_id = it->second;
477 if (arc.ilabel == parens_[paren_id].first) {
478 balance_data_.OpenInsert(paren_id, arc.nextstate);
481 close_paren_multimap_.emplace(paren_state, arc);
490 template <
class Arc,
class Queue>
494 state_queue_ = &state_queue;
497 sp_data_.SetDistance(q, Weight::One());
498 while (!state_queue_->Empty()) {
499 const auto state = state_queue_->Head();
500 state_queue_->Dequeue();
502 sp_data_.SetFlags(s, 0, kEnqueued);
507 sp_data_.SetFlags(q, kFinished, kFinished);
508 balance_data_.FinishInsert(start);
513 template <
class Arc,
class Queue>
515 if (ifst_->Final(s.state) != Weight::Zero() && s.start == start_) {
516 const auto weight =
Times(sp_data_.Distance(s), ifst_->Final(s.state));
517 if (fdistance_ !=
Plus(fdistance_, weight)) {
522 fdistance_ =
Plus(fdistance_, weight);
529 template <
class Arc,
class Queue>
533 const auto &arc = aiter.Value();
534 const auto weight =
Times(sp_data_.Distance(s), arc.weight);
535 const auto it = paren_map_.find(arc.ilabel);
536 if (it != paren_map_.end()) {
537 const auto paren_id = it->second;
538 if (arc.ilabel == parens_[paren_id].first) {
539 ProcOpenParen(paren_id, s, arc.nextstate, weight);
541 ProcCloseParen(paren_id, s, weight);
544 ProcNonParen(s, arc.nextstate, weight);
553 template <
class Arc,
class Queue>
559 const ParenSpec paren(paren_id, s.start, d.start);
560 const auto pdist = sp_data_.Distance(paren);
561 if (pdist !=
Plus(pdist, weight)) {
562 sp_data_.SetDistance(paren, weight);
563 sp_data_.SetParent(paren, s);
564 const auto dist = sp_data_.Distance(d);
565 if (dist == Weight::Zero()) {
566 auto *state_queue = state_queue_;
567 GetDistance(d.start);
568 state_queue_ = state_queue;
569 }
else if (!(sp_data_.Flags(d) & kFinished)) {
571 <<
"PdtShortestPath: open parenthesis recursion: not bounded stack";
574 for (
auto set_iter = balance_data_.Find(paren_id, nextstate);
575 !set_iter.Done(); set_iter.Next()) {
576 const SearchState cpstate(set_iter.Element(), d.start);
578 for (
auto cpit = close_paren_multimap_.find(paren_state);
579 cpit != close_paren_multimap_.end() && paren_state == cpit->first;
581 const auto &cparc = cpit->second;
583 Times(weight,
Times(sp_data_.Distance(cpstate), cparc.weight));
584 Relax(cpstate, s, cparc.nextstate, cpw, paren_id);
593 template <
class Arc,
class Queue>
599 balance_data_.CloseInsert(paren_id, s.start, s.state);
605 template <
class Arc,
class Queue>
609 Relax(s, s, nextstate, weight,
kNoLabel);
615 template <
class Arc,
class Queue>
621 Weight dist = sp_data_.Distance(d);
622 if (dist !=
Plus(dist, weight)) {
623 sp_data_.SetParent(d, s);
624 sp_data_.SetParenId(d, paren_id);
625 sp_data_.SetDistance(d,
Plus(dist, weight));
630 template <
class Arc,
class Queue>
632 if (!(sp_data_.Flags(s) & kEnqueued)) {
633 state_queue_->Enqueue(s.state);
634 sp_data_.SetFlags(s, kEnqueued, kEnqueued);
637 state_queue_->Update(s.state);
643 template <
class Arc,
class Queue>
651 std::stack<ParenSpec> paren_stack;
654 s_p = ofst_->AddState();
656 ofst_->SetFinal(s_p, ifst_->Final(f_parent_.state));
659 if (arc.ilabel == parens_[paren_id].first) {
662 const ParenSpec paren(paren_id, d.start, s.start);
663 paren_stack.push(paren);
665 if (!keep_parens_) arc.ilabel = arc.olabel = 0;
668 ofst_->AddArc(s_p, arc);
671 s = sp_data_.Parent(d);
672 paren_id = sp_data_.ParenId(d);
674 arc = GetPathArc(s, d, paren_id,
false);
675 }
else if (!paren_stack.empty()) {
676 const ParenSpec paren = paren_stack.top();
677 s = sp_data_.Parent(paren);
678 paren_id = paren.paren_id;
679 arc = GetPathArc(s, d, paren_id,
true);
682 ofst_->SetStart(s_p);
683 ofst_->SetProperties(
690 template <
class Arc,
class Queue>
692 Label paren_id,
bool open_paren) {
696 const auto &arc = aiter.Value();
697 if (arc.nextstate != d.state)
continue;
699 const auto it = paren_map_.find(arc.ilabel);
700 if (it != paren_map_.end()) {
701 arc_paren_id = it->second;
702 bool arc_open_paren = (arc.ilabel == parens_[arc_paren_id].first);
703 if (arc_open_paren != open_paren)
continue;
705 if (arc_paren_id != paren_id)
continue;
706 if (arc.weight ==
Plus(arc.weight, path_arc.weight)) path_arc = arc;
709 FSTERROR() <<
"PdtShortestPath::GetPathArc: Failed to find arc";
715 template <
class Arc,
class Queue>
721 template <
class Arc,
class Queue>
724 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
734 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
745 #endif // FST_EXTENSIONS_PDT_SHORTEST_PATH_H_
bool operator==(const ParenSpec &other) const
internal::PdtBalanceData< Arc > * GetBalanceData()
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
void SetFlags(SearchState s, uint8_t f, uint8_t mask)
uint8_t Flags(SearchState s) const
PdtShortestPathData(bool gc)
SearchState Parent(SearchState s) const
uint64_t ShortestPathProperties(uint64_t props, bool tree=false)
typename Arc::StateId StateId
void SetDistance(SearchState s, Weight weight)
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
SearchState Parent(const ParenSpec &paren) const
constexpr uint64_t kError
constexpr uint8_t kPdtFinal
virtual void SetInputSymbols(const SymbolTable *isyms)=0
ParenSpec(Label paren_id=kNoLabel, StateId src_start=kNoStateId, StateId dest_start=kNoStateId)
std::unordered_multimap< internal::ParenState< Arc >, Arc, typename internal::ParenState< Arc >::Hash > CloseParenMultimap
void SetParent(SearchState s, SearchState p)
constexpr uint64_t kRightSemiring
typename SpData::SearchState SearchState
typename Arc::Label Label
SearchState(StateId s=kNoStateId, StateId t=kNoStateId)
typename Collection< ssize_t, StateId >::SetIterator SetIterator
virtual void SetProperties(uint64_t props, uint64_t mask)=0
typename Arc::Label Label
PdtShortestPath(const Fst< Arc > &ifst, const std::vector< std::pair< Label, Label >> &parens, const PdtShortestPathOptions< Arc, Queue > &opts)
const CloseParenMultimap & GetCloseParenMultimap() const
Weight Distance(const ParenSpec &paren) const
typename Arc::Weight Weight
void SetParenId(SearchState s, Label p)
PdtShortestPathOptions(bool keep_parentheses=false, bool path_gc=true)
Label ParenId(SearchState s) const
constexpr uint8_t kPdtInited
void SetDistance(const ParenSpec &paren, Weight weight)
void ShortestPath(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, MutableFst< Arc > *ofst, const PdtShortestPathOptions< Arc, Queue > &opts)
constexpr uint64_t kFstProperties
void ShortestPath(MutableFst< Arc > *ofst)
typename Arc::Weight Weight
Arc::StateId CountStates(const Fst< Arc > &fst)
bool operator==(const SearchState &other) const
virtual void DeleteStates(const std::vector< StateId > &)=0
typename internal::PdtBalanceData< Arc >::SetIterator CloseSourceIterator
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
Weight Distance(SearchState s) const
const internal::PdtShortestPathData< Arc > & GetShortestPathData() const
constexpr uint64_t kExpanded
void SetParent(const ParenSpec &paren, SearchState p)
constexpr uint8_t kPdtMarked
typename SpData::ParenSpec ParenSpec
typename Arc::StateId StateId