20 #ifndef FST_EXTENSIONS_PDT_SHORTEST_PATH_H_ 21 #define FST_EXTENSIONS_PDT_SHORTEST_PATH_H_ 25 #include <unordered_map> 33 #include <unordered_map> 37 template <
class Arc,
class Queue>
44 : keep_parentheses(keep_parentheses), path_gc(path_gc) {}
73 using Label =
typename Arc::Label;
82 : state(s), start(t) {}
85 if (&other ==
this)
return true;
86 return other.
state == state && other.
start == start;
96 : paren_id(paren_id), src_start(src_start), dest_start(dest_start) {}
103 if (&other ==
this)
return true;
104 return (other.
paren_id == paren_id &&
112 : distance(
Weight::Zero()),
124 : gc_(gc), nstates_(0), ngc_(0), finished_(false) {}
127 VLOG(1) <<
"opm size: " << paren_map_.size();
128 VLOG(1) <<
"# of search states: " << nstates_;
129 if (gc_)
VLOG(1) <<
"# of GC'd search states: " << ngc_;
134 search_multimap_.clear();
147 return GetSearchData(paren)->distance;
153 return GetSearchData(paren)->parent;
161 GetSearchData(s)->distance = std::move(weight);
165 GetSearchData(paren)->distance = std::move(weight);
171 GetSearchData(paren)->parent = p;
176 FSTERROR() <<
"PdtShortestPathData: Paren ID does not fit in an int16_t";
178 GetSearchData(s)->paren_id = p;
182 auto *data = GetSearchData(s);
183 data->flags &= ~mask;
184 data->flags |= f & mask;
193 struct SearchStateHash {
195 static constexpr
auto prime = 7853;
202 size_t operator()(
const ParenSpec &paren)
const {
203 static constexpr
auto prime0 = 7853;
204 static constexpr
auto prime1 = 7867;
211 std::unordered_map<SearchState, SearchData, SearchStateHash>;
213 using SearchMultimap = std::unordered_multimap<StateId, StateId>;
216 using ParenMap = std::unordered_map<ParenSpec, SearchData, ParenHash>;
219 if (s == state_)
return state_data_;
221 auto it = search_map_.find(s);
222 if (it == search_map_.end())
return &null_search_data_;
224 return state_data_ = &(it->second);
227 state_data_ = &search_map_[s];
228 if (!(state_data_->flags & kPdtInited)) {
230 if (gc_) search_multimap_.insert(std::make_pair(s.
start, s.
state));
238 if (paren == paren_)
return paren_data_;
240 auto it = paren_map_.find(paren);
241 if (it == paren_map_.end())
return &null_search_data_;
243 return state_data_ = &(it->second);
246 return paren_data_ = &paren_map_[paren];
250 mutable SearchMap search_map_;
251 mutable SearchMultimap search_multimap_;
252 mutable ParenMap paren_map_;
258 mutable size_t nstates_;
273 std::vector<StateId> finals;
274 for (
auto it = search_multimap_.find(start);
275 it != search_multimap_.end() && it->first == start; ++it) {
277 if (search_map_[s].flags & kPdtFinal) finals.push_back(s.
state);
280 for (
const auto state : finals) {
283 auto &sdata = search_map_[ss];
284 if (sdata.flags & kPdtMarked)
break;
286 const auto p = sdata.parent;
287 if (p.start != start && p.start !=
kNoLabel) {
289 ss = paren_map_[paren].parent;
296 auto it = search_multimap_.find(start);
297 while (it != search_multimap_.end() && it->first == start) {
299 auto mit = search_map_.find(s);
301 if (!(data.
flags & kPdtMarked)) {
302 search_map_.erase(mit);
305 search_multimap_.erase(it++);
326 template <
class Arc,
class Queue>
340 const std::vector<std::pair<Label, Label>> &parens,
342 : ifst_(ifst.Copy()),
345 start_(ifst.Start()),
355 FSTERROR() <<
"PdtShortestPath: Weight needs to have the path" 356 <<
" property and be right distributive: " << Weight::Type();
359 for (
Label i = 0; i < parens.size(); ++i) {
360 const auto &pair = parens[i];
361 paren_map_[pair.first] = i;
362 paren_map_[pair.second] = i;
368 VLOG(1) <<
"# of enqueued: " << nenqueued_;
369 VLOG(1) <<
"cpmm size: " << close_paren_multimap_.size();
389 std::unordered_multimap<internal::ParenState<Arc>, Arc,
393 return close_paren_multimap_;
399 void GetDistance(
StateId start);
421 std::unique_ptr<Fst<Arc>> ifst_;
423 const std::vector<std::pair<Label, Label>> &parens_;
430 std::unordered_map<Label, Label> paren_map_;
436 static constexpr uint8_t kEnqueued = 0x10;
437 static constexpr uint8_t
kExpanded = 0x20;
438 static constexpr uint8_t kFinished = 0x40;
443 template <
class Arc,
class Queue>
450 fdistance_ = Weight::Zero();
453 close_paren_multimap_.clear();
454 balance_data_.Clear();
458 const auto s = siter.Value();
460 const auto &arc = aiter.Value();
461 const auto it = paren_map_.find(arc.ilabel);
462 if (it != paren_map_.end()) {
463 const auto paren_id = it->second;
464 if (arc.ilabel == parens_[paren_id].first) {
465 balance_data_.OpenInsert(paren_id, arc.nextstate);
468 close_paren_multimap_.emplace(paren_state, arc);
477 template <
class Arc,
class Queue>
481 state_queue_ = &state_queue;
484 sp_data_.SetDistance(q, Weight::One());
485 while (!state_queue_->Empty()) {
486 const auto state = state_queue_->Head();
487 state_queue_->Dequeue();
489 sp_data_.SetFlags(s, 0, kEnqueued);
494 sp_data_.SetFlags(q, kFinished, kFinished);
495 balance_data_.FinishInsert(start);
500 template <
class Arc,
class Queue>
502 if (ifst_->Final(s.state) != Weight::Zero() && s.start == start_) {
503 const auto weight =
Times(sp_data_.Distance(s), ifst_->Final(s.state));
504 if (fdistance_ !=
Plus(fdistance_, weight)) {
509 fdistance_ =
Plus(fdistance_, weight);
516 template <
class Arc,
class Queue>
520 const auto &arc = aiter.Value();
521 const auto weight =
Times(sp_data_.Distance(s), arc.weight);
522 const auto it = paren_map_.find(arc.ilabel);
523 if (it != paren_map_.end()) {
524 const auto paren_id = it->second;
525 if (arc.ilabel == parens_[paren_id].first) {
526 ProcOpenParen(paren_id, s, arc.nextstate, weight);
528 ProcCloseParen(paren_id, s, weight);
531 ProcNonParen(s, arc.nextstate, weight);
540 template <
class Arc,
class Queue>
546 const ParenSpec paren(paren_id, s.start, d.start);
547 const auto pdist = sp_data_.Distance(paren);
548 if (pdist !=
Plus(pdist, weight)) {
549 sp_data_.SetDistance(paren, weight);
550 sp_data_.SetParent(paren, s);
551 const auto dist = sp_data_.Distance(d);
552 if (dist == Weight::Zero()) {
553 auto *state_queue = state_queue_;
554 GetDistance(d.start);
555 state_queue_ = state_queue;
556 }
else if (!(sp_data_.Flags(d) & kFinished)) {
558 <<
"PdtShortestPath: open parenthesis recursion: not bounded stack";
561 for (
auto set_iter = balance_data_.Find(paren_id, nextstate);
562 !set_iter.Done(); set_iter.Next()) {
563 const SearchState cpstate(set_iter.Element(), d.start);
565 for (
auto cpit = close_paren_multimap_.find(paren_state);
566 cpit != close_paren_multimap_.end() && paren_state == cpit->first;
568 const auto &cparc = cpit->second;
570 Times(weight,
Times(sp_data_.Distance(cpstate), cparc.weight));
571 Relax(cpstate, s, cparc.nextstate, cpw, paren_id);
580 template <
class Arc,
class Queue>
586 balance_data_.CloseInsert(paren_id, s.start, s.state);
592 template <
class Arc,
class Queue>
596 Relax(s, s, nextstate, weight,
kNoLabel);
602 template <
class Arc,
class Queue>
608 Weight dist = sp_data_.Distance(d);
609 if (dist !=
Plus(dist, weight)) {
610 sp_data_.SetParent(d, s);
611 sp_data_.SetParenId(d, paren_id);
612 sp_data_.SetDistance(d,
Plus(dist, weight));
617 template <
class Arc,
class Queue>
619 if (!(sp_data_.Flags(s) & kEnqueued)) {
620 state_queue_->Enqueue(s.state);
621 sp_data_.SetFlags(s, kEnqueued, kEnqueued);
624 state_queue_->Update(s.state);
630 template <
class Arc,
class Queue>
638 std::stack<ParenSpec> paren_stack;
641 s_p = ofst_->AddState();
643 ofst_->SetFinal(s_p, ifst_->Final(f_parent_.state));
646 if (arc.ilabel == parens_[paren_id].first) {
649 const ParenSpec paren(paren_id, d.start, s.start);
650 paren_stack.push(paren);
652 if (!keep_parens_) arc.ilabel = arc.olabel = 0;
655 ofst_->AddArc(s_p, arc);
658 s = sp_data_.Parent(d);
659 paren_id = sp_data_.ParenId(d);
661 arc = GetPathArc(s, d, paren_id,
false);
662 }
else if (!paren_stack.empty()) {
663 const ParenSpec paren = paren_stack.top();
664 s = sp_data_.Parent(paren);
665 paren_id = paren.paren_id;
666 arc = GetPathArc(s, d, paren_id,
true);
669 ofst_->SetStart(s_p);
670 ofst_->SetProperties(
677 template <
class Arc,
class Queue>
679 Label paren_id,
bool open_paren) {
683 const auto &arc = aiter.Value();
684 if (arc.nextstate != d.state)
continue;
686 const auto it = paren_map_.find(arc.ilabel);
687 if (it != paren_map_.end()) {
688 arc_paren_id = it->second;
689 bool arc_open_paren = (arc.ilabel == parens_[arc_paren_id].first);
690 if (arc_open_paren != open_paren)
continue;
692 if (arc_paren_id != paren_id)
continue;
693 if (arc.weight ==
Plus(arc.weight, path_arc.weight)) path_arc = arc;
696 FSTERROR() <<
"PdtShortestPath::GetPathArc: Failed to find arc";
702 template <
class Arc,
class Queue>
708 template <
class Arc,
class Queue>
711 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
721 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
732 #endif // FST_EXTENSIONS_PDT_SHORTEST_PATH_H_ bool operator==(const ParenSpec &other) const
internal::PdtBalanceData< Arc > * GetBalanceData()
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
void SetFlags(SearchState s, uint8_t f, uint8_t mask)
uint8_t Flags(SearchState s) const
PdtShortestPathData(bool gc)
SearchState Parent(SearchState s) const
uint64_t ShortestPathProperties(uint64_t props, bool tree=false)
typename Arc::StateId StateId
void SetDistance(SearchState s, Weight weight)
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
SearchState Parent(const ParenSpec &paren) const
constexpr uint64_t kError
constexpr uint8_t kPdtFinal
virtual void SetInputSymbols(const SymbolTable *isyms)=0
ParenSpec(Label paren_id=kNoLabel, StateId src_start=kNoStateId, StateId dest_start=kNoStateId)
std::unordered_multimap< internal::ParenState< Arc >, Arc, typename internal::ParenState< Arc >::Hash > CloseParenMultimap
void SetParent(SearchState s, SearchState p)
constexpr uint64_t kRightSemiring
typename SpData::SearchState SearchState
typename Arc::Label Label
SearchState(StateId s=kNoStateId, StateId t=kNoStateId)
typename Collection< ssize_t, StateId >::SetIterator SetIterator
virtual void SetProperties(uint64_t props, uint64_t mask)=0
typename Arc::Label Label
PdtShortestPath(const Fst< Arc > &ifst, const std::vector< std::pair< Label, Label >> &parens, const PdtShortestPathOptions< Arc, Queue > &opts)
const CloseParenMultimap & GetCloseParenMultimap() const
Weight Distance(const ParenSpec &paren) const
typename Arc::Weight Weight
void SetParenId(SearchState s, Label p)
PdtShortestPathOptions(bool keep_parentheses=false, bool path_gc=true)
Label ParenId(SearchState s) const
constexpr uint8_t kPdtInited
void SetDistance(const ParenSpec &paren, Weight weight)
void ShortestPath(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, MutableFst< Arc > *ofst, const PdtShortestPathOptions< Arc, Queue > &opts)
constexpr uint64_t kFstProperties
void ShortestPath(MutableFst< Arc > *ofst)
typename Arc::Weight Weight
Arc::StateId CountStates(const Fst< Arc > &fst)
bool operator==(const SearchState &other) const
virtual void DeleteStates(const std::vector< StateId > &)=0
typename internal::PdtBalanceData< Arc >::SetIterator CloseSourceIterator
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
Weight Distance(SearchState s) const
const internal::PdtShortestPathData< Arc > & GetShortestPathData() const
constexpr uint64_t kExpanded
void SetParent(const ParenSpec &paren, SearchState p)
constexpr uint8_t kPdtMarked
typename SpData::ParenSpec ParenSpec
typename Arc::StateId StateId