20 #ifndef FST_EXTENSIONS_PDT_EXPAND_H_ 21 #define FST_EXTENSIONS_PDT_EXPAND_H_ 24 #include <forward_list> 37 #include <unordered_map> 53 keep_parentheses(keep_parentheses),
55 state_table(state_table) {}
64 using Label =
typename Arc::Label;
86 const std::vector<std::pair<Label, Label>> &parens,
91 state_table_(opts.state_table ? opts.state_table
93 own_stack_(opts.stack == nullptr),
94 own_state_table_(opts.state_table == nullptr),
95 keep_parentheses_(opts.keep_parentheses) {
105 fst_(impl.fst_->Copy(true)),
109 own_state_table_(true),
110 keep_parentheses_(impl.keep_parentheses_) {
118 if (own_stack_)
delete stack_;
119 if (own_state_table_)
delete state_table_;
124 const auto s = fst_->Start();
127 const auto start = state_table_->FindState(tuple);
135 const auto &tuple = state_table_->Tuple(s);
136 const auto weight = fst_->Final(tuple.state_id);
137 if (weight != Weight::Zero() && tuple.stack_id == 0)
140 SetFinal(s, Weight::Zero());
146 if (!HasArcs(s)) ExpandState(s);
151 if (!HasArcs(s)) ExpandState(s);
156 if (!HasArcs(s)) ExpandState(s);
161 if (!HasArcs(s)) ExpandState(s);
171 auto arc = aiter.Value();
172 const auto stack_id = stack_->Find(tuple.
stack_id, arc.ilabel);
173 if (stack_id == -1) {
175 }
else if ((stack_id != tuple.
stack_id) && !keep_parentheses_) {
181 arc.nextstate = state_table_->FindState(ntuple);
190 return *state_table_;
195 inline uint64_t PdtExpandProperties(uint64_t inprops) {
199 std::unique_ptr<const Fst<Arc>> fst_;
203 bool own_state_table_;
204 bool keep_parentheses_;
236 const std::vector<std::pair<Label, Label>> &parens)
241 const std::vector<std::pair<Label, Label>> &parens,
257 GetMutableImpl()->InitArcIterator(s, data);
261 return GetImpl()->GetStack();
265 return GetImpl()->GetStateTable();
300 data->
base = std::make_unique<StateIterator<PdtExpandFst<Arc>>>(*this);
336 const std::vector<std::pair<Label, Label>> &parens,
337 bool keep_parentheses =
false,
339 : ifst_(ifst.Copy()),
340 keep_parentheses_(keep_parentheses),
344 queue_(state_table_, stack_, stack_length_, distance_, fdistance_),
346 Reverse(*ifst_, parens, &rfst_);
351 reverse_shortest_path_->ShortestPath(&path);
353 balance_data_.reset(reverse_shortest_path_->GetBalanceData()->Reverse(
354 rfst_.NumStates(), 10, -1));
355 InitCloseParenMultimap(parens);
358 bool Error()
const {
return error_; }
365 static constexpr uint8_t kEnqueued = 0x01;
366 static constexpr uint8_t
kExpanded = 0x02;
367 static constexpr uint8_t kSourceState = 0x04;
377 const std::vector<StackId> &stack_length,
378 const std::vector<Weight> &distance,
379 const std::vector<Weight> &fdistance)
380 : state_table_(state_table),
382 stack_length_(stack_length),
384 fdistance_(fdistance) {}
387 auto si1 = state_table_.Tuple(s1).stack_id;
388 auto si2 = state_table_.Tuple(s2).stack_id;
389 if (stack_length_[si1] < stack_length_[si2])
return true;
390 if (stack_length_[si1] > stack_length_[si2])
return false;
393 return less_(Distance(s1), Distance(s2));
396 for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) {
397 if (stack_.Top(si1) < stack_.Top(si2))
return true;
398 if (stack_.Top(si1) > stack_.Top(si2))
return false;
405 return (s < distance_.size()) && (s < fdistance_.size())
406 ?
Times(distance_[s], fdistance_[s])
412 const std::vector<StackId> &stack_length_;
413 const std::vector<Weight> &distance_;
414 const std::vector<Weight> &fdistance_;
419 class ShortestStackFirstQueue
424 const std::vector<StackId> &stack_length,
425 const std::vector<Weight> &distance,
426 const std::vector<Weight> &fdistance)
428 state_table, stack, stack_length, distance, fdistance)) {}
431 void InitCloseParenMultimap(
432 const std::vector<std::pair<Label, Label>> &parens);
436 uint8_t Flags(
StateId s)
const;
452 void AddStateAndEnqueue(
StateId s);
456 bool PruneArc(
StateId s,
const Arc &arc);
462 bool ProcNonParen(
StateId s,
const Arc &arc,
bool add_arc);
466 bool ProcCloseParen(
StateId s,
const Arc &arc);
471 std::unique_ptr<Fst<Arc>> ifst_;
475 const bool keep_parentheses_;
483 std::vector<StackId> stack_length_;
485 std::vector<Weight> distance_;
487 std::vector<Weight> fdistance_;
489 ShortestStackFirstQueue queue_;
493 std::vector<uint8_t> flags_;
495 std::vector<StateId> sources_;
497 std::unique_ptr<PdtShortestPath<Arc, FifoQueue<StateId>>>
498 reverse_shortest_path_;
499 std::unique_ptr<internal::PdtBalanceData<Arc>> balance_data_;
502 close_paren_multimap_;
508 std::unordered_map<StateId, Weight> dest_map_;
512 ssize_t current_paren_id_;
513 ssize_t cached_stack_id_;
520 std::forward_list<std::pair<StateId, Weight>> cached_dest_list_;
528 const std::vector<std::pair<Label, Label>> &parens) {
529 std::unordered_map<Label, Label> paren_map;
530 for (
size_t i = 0; i < parens.size(); ++i) {
531 const auto &pair = parens[i];
532 paren_map[pair.first] = i;
533 paren_map[pair.second] = i;
536 const auto s = siter.Value();
538 const auto &arc = aiter.Value();
539 const auto it = paren_map.find(arc.ilabel);
540 if (it == paren_map.end())
continue;
541 if (arc.ilabel == parens[it->second].second) {
543 close_paren_multimap_.emplace(key, arc);
556 const SearchState ss(source + 1, dest + 1);
557 const auto distance =
558 reverse_shortest_path_->GetShortestPathData().Distance(ss);
559 VLOG(2) <<
"D(" << source <<
", " << dest <<
") =" << distance;
566 return s < flags_.size() ? flags_[s] : 0;
572 while (flags_.size() <= s) flags_.push_back(0);
574 flags_[s] |= flags & mask;
580 return s < distance_.size() ? distance_[s] : Weight::Zero();
586 while (distance_.size() <= s) distance_.push_back(Weight::Zero());
587 distance_[s] = std::move(weight);
593 return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
599 while (fdistance_.size() <= s) fdistance_.push_back(Weight::Zero());
600 fdistance_[s] = std::move(weight);
606 return s < sources_.size() ? sources_[s] :
kNoStateId;
612 while (sources_.size() <= s) sources_.push_back(
kNoStateId);
620 if (!(Flags(s) & (kEnqueued |
kExpanded))) {
621 while (ofst_->NumStates() <= s) ofst_->AddState();
624 }
else if (Flags(s) & kEnqueued) {
639 const auto nd =
Times(Distance(s), arc.weight);
640 if (less_(nd, Distance(arc.nextstate))) {
641 SetDistance(arc.nextstate, nd);
642 SetSourceState(arc.nextstate, SourceState(s));
644 if (less_(fd, FinalDistance(arc.nextstate))) {
645 SetFinalDistance(arc.nextstate, fd);
647 VLOG(2) <<
"Relax: " << s <<
", d[s] = " << Distance(s) <<
", to " 648 << arc.nextstate <<
", d[ns] = " << Distance(arc.nextstate)
655 VLOG(2) <<
"Prune ?";
656 auto fd = Weight::Zero();
657 if ((cached_source_ != SourceState(s)) ||
658 (cached_stack_id_ != current_stack_id_)) {
659 cached_source_ = SourceState(s);
660 cached_stack_id_ = current_stack_id_;
661 cached_dest_list_.clear();
662 if (cached_source_ != ifst_->Start()) {
664 balance_data_->Find(current_paren_id_, cached_source_);
665 !set_iter.Done(); set_iter.Next()) {
666 auto dest = set_iter.Element();
667 const auto it = dest_map_.find(dest);
668 cached_dest_list_.push_front(*it);
674 cached_dest_list_.push_front(
675 std::make_pair(rfst_.Start() - 1, Weight::One()));
678 for (
auto it = cached_dest_list_.begin(); it != cached_dest_list_.end();
681 DistanceToDest(state_table_.Tuple(arc.nextstate).state_id, it->first);
685 return less_(limit_,
Times(Distance(s),
Times(arc.weight, fd)));
692 const auto s = efst_.Start();
693 AddStateAndEnqueue(s);
695 SetSourceState(s, ifst_->Start());
696 current_stack_id_ = 0;
697 current_paren_id_ = -1;
698 stack_length_.push_back(0);
699 const auto r = rfst_.Start() - 1;
700 cached_source_ = ifst_->Start();
701 cached_stack_id_ = 0;
702 cached_dest_list_.push_front(std::make_pair(r, Weight::One()));
704 SetFinalDistance(state_table_.FindState(tuple), Weight::One());
705 SetDistance(s, Weight::One());
706 const auto d = DistanceToDest(ifst_->Start(), r);
707 SetFinalDistance(s, d);
715 const auto weight = efst_.Final(s);
716 if (weight == Weight::Zero())
return;
717 if (less_(limit_,
Times(Distance(s), weight)))
return;
718 ofst_->SetFinal(s, weight);
726 VLOG(2) <<
"ProcNonParen: " << s <<
" to " << arc.nextstate <<
", " 727 << arc.ilabel <<
":" << arc.olabel <<
" / " << arc.weight
728 <<
", add_arc = " << (add_arc ?
"true" :
"false");
729 if (PruneArc(s, arc))
return false;
730 if (add_arc) ofst_->AddArc(s, arc);
731 AddStateAndEnqueue(arc.nextstate);
749 while (stack_length_.size() <= nsi) stack_length_.push_back(-1);
750 if (stack_length_[nsi] == -1) stack_length_[nsi] = stack_length_[si] + 1;
751 const auto ns = arc.nextstate;
752 VLOG(2) <<
"Open paren: " << s <<
"(" << state_table_.Tuple(s).state_id
753 <<
") to " << ns <<
"(" << state_table_.Tuple(ns).state_id <<
")";
754 bool proc_arc =
false;
755 auto fd = Weight::Zero();
756 const auto paren_id = stack_.ParenId(arc.ilabel);
757 std::forward_list<StateId> sources;
759 balance_data_->Find(paren_id, state_table_.Tuple(ns).state_id);
760 !set_iter.Done(); set_iter.Next()) {
761 sources.push_front(set_iter.Element());
763 for (
const auto source : sources) {
764 VLOG(2) <<
"Close paren source: " << source;
766 for (
auto it = close_paren_multimap_.find(paren_state);
767 it != close_paren_multimap_.end() && paren_state == it->first; ++it) {
768 auto meta_arc = it->second;
770 meta_arc.nextstate = state_table_.FindState(tuple);
771 const auto state_id = state_table_.Tuple(ns).state_id;
772 const auto d = DistanceToDest(state_id, source);
773 VLOG(2) << state_id <<
", " << source;
774 VLOG(2) <<
"Meta arc weight = " << arc.weight <<
" Times " << d
775 <<
" Times " << meta_arc.weight;
776 meta_arc.weight =
Times(arc.weight,
Times(d, meta_arc.weight));
777 proc_arc |= ProcNonParen(s, meta_arc,
false);
780 Times(
Times(DistanceToDest(state_table_.Tuple(ns).state_id, source),
782 FinalDistance(meta_arc.nextstate)));
786 VLOG(2) <<
"Proc open paren " << s <<
" to " << arc.nextstate;
788 s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
789 AddStateAndEnqueue(arc.nextstate);
790 const auto nd =
Times(Distance(s), arc.weight);
791 if (less_(nd, Distance(arc.nextstate))) SetDistance(arc.nextstate, nd);
794 if (less_(fd, FinalDistance(arc.nextstate)))
795 SetFinalDistance(arc.nextstate, fd);
796 SetFlags(arc.nextstate, kSourceState, kSourceState);
806 Times(Distance(s),
Times(arc.weight, FinalDistance(arc.nextstate)));
807 if (less_(limit_, weight))
return false;
809 keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
820 if (!(Flags(s) & kSourceState))
return;
821 if (si != current_stack_id_) {
823 current_stack_id_ = si;
824 current_paren_id_ = stack_.Top(current_stack_id_);
825 VLOG(2) <<
"StackID " << si <<
" dequeued for first time";
829 SetSourceState(s, state_table_.Tuple(s).state_id);
830 const auto paren_id = stack_.Top(si);
832 balance_data_->Find(paren_id, state_table_.Tuple(s).state_id);
833 !set_iter.Done(); set_iter.Next()) {
834 const auto dest_state = set_iter.Element();
835 if (dest_map_.find(dest_state) != dest_map_.end())
continue;
836 auto dest_weight = Weight::Zero();
838 for (
auto it = close_paren_multimap_.find(paren_state);
839 it != close_paren_multimap_.end() && paren_state == it->first; ++it) {
840 const auto &arc = it->second;
845 Times(arc.weight, FinalDistance(state_table_.FindState(tuple))));
847 dest_map_[dest_state] = dest_weight;
848 VLOG(2) <<
"State " << dest_state <<
" is a dest state for stack ID " << si
849 <<
" with weight " << dest_weight;
856 const typename Arc::Weight &threshold) {
862 ofst_->DeleteStates();
863 ofst_->SetInputSymbols(ifst_->InputSymbols());
864 ofst_->SetOutputSymbols(ifst_->OutputSymbols());
865 limit_ =
Times(DistanceToDest(ifst_->Start(), rfst_.Start() - 1), threshold);
868 while (!queue_.Empty()) {
869 const auto s = queue_.Head();
872 VLOG(2) << s <<
" dequeued!";
874 StackId stack_id = state_table_.Tuple(s).stack_id;
875 ProcDestStates(s, stack_id);
878 const auto &arc = aiter.Value();
879 const auto nextstack_id = state_table_.Tuple(arc.nextstate).stack_id;
880 if (stack_id == nextstack_id) {
881 ProcNonParen(s, arc,
true);
882 }
else if (stack_id == stack_.Pop(nextstack_id)) {
883 ProcOpenParen(s, arc, stack_id, nextstack_id);
885 ProcCloseParen(s, arc);
888 VLOG(2) <<
"d[" << s <<
"] = " << Distance(s) <<
", fd[" << s
889 <<
"] = " << FinalDistance(s);
904 Weight weight_threshold = Weight::Zero())
906 keep_parentheses(keep_parentheses),
907 weight_threshold(std::move(weight_threshold)) {}
919 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
922 using Weight =
typename Arc::Weight;
932 FSTERROR() <<
"Expand: non-Zero weight_threshold with non-idempotent" 933 <<
" Weight " << Weight::Type();
948 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
950 MutableFst<Arc> *ofst,
bool connect =
true,
bool keep_parentheses =
false) {
952 Expand(ifst, parens, ofst, opts);
957 #endif // FST_EXTENSIONS_PDT_EXPAND_H_ ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
typename Arc::StateId StateId
const PdtStateTable< StateId, StackId > & GetStateTable() const
void InitStateIterator(StateIteratorData< Arc > *data) const override
typename Arc::Label Label
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
ArcIterator(const PdtExpandFst< Arc > &fst, StateId s)
CacheOptions(bool gc=FST_FLAGS_fst_default_cache_gc, size_t gc_limit=FST_FLAGS_fst_default_cache_gc_limit)
virtual uint64_t Properties(uint64_t mask, bool test) const =0
typename Store::State State
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
size_t NumOutputEpsilons(StateId s)
size_t NumInputEpsilons(StateId s)
PdtExpandFst(const PdtExpandFst< Arc > &fst, bool safe=false)
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
const SymbolTable * OutputSymbols() const
void SetFlags(const char *usage, int *argc, char ***argv, bool remove_flags, const char *src="")
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
PdtExpandFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens)
constexpr uint64_t kError
typename Arc::StateId StateId
constexpr uint64_t kInitialAcyclic
typename Arc::StateId StateId
PdtStack< typename Arc::StateId, typename Arc::Label > * stack
typename PdtExpandFst< Arc >::Arc Arc
void Connect(MutableFst< Arc > *fst)
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
std::bool_constant<(W::Properties()&kIdempotent)!=0 > IsIdempotent
typename internal::PdtBalanceData< Arc >::SetIterator SetIterator
~PdtExpandFstImpl() override
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
typename Arc::Weight Weight
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
void ExpandState(StateId s)
const PdtStack< StackId, Label > & GetStack() const
PdtExpandFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, const PdtExpandFstOptions< Arc > &opts)
typename Arc::Weight Weight
typename Collection< ssize_t, StateId >::SetIterator SetIterator
size_t NumArcs(StateId s)
PdtStateTable< typename Arc::StateId, typename Arc::StateId > * state_table
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
PdtExpandFstOptions(const CacheOptions &opts=CacheOptions(), bool keep_parentheses=false, PdtStack< typename Arc::StateId, typename Arc::Label > *stack=nullptr, PdtStateTable< typename Arc::StateId, typename Arc::StateId > *state_table=nullptr)
virtual uint64_t Properties() const
typename Arc::StateId StateId
constexpr uint64_t kCopyProperties
constexpr uint64_t kAcyclic
virtual void SetProperties(uint64_t props, uint64_t mask)=0
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
PdtExpandOptions(bool connect=true, bool keep_parentheses=false, Weight weight_threshold=Weight::Zero())
void Expand(MutableFst< Arc > *ofst, const Weight &threshold)
std::unique_ptr< StateIteratorBase< Arc > > base
typename Arc::Weight Weight
PdtExpandFst< Arc > * Copy(bool safe=false) const override
PdtExpandFstImpl(const PdtExpandFstImpl &impl)
StateIterator(const PdtExpandFst< Arc > &fst)
constexpr uint64_t kFstProperties
constexpr uint64_t kUnweighted
const PdtStateTable< StateId, StackId > & GetStateTable() const
typename PdtExpandFst< Arc >::Arc Arc
virtual const SymbolTable * InputSymbols() const =0
const SymbolTable * InputSymbols() const
PdtExpandFstImpl(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, const PdtExpandFstOptions< Arc > &opts)
const PdtStack< StackId, Label > & GetStack() const
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
typename Arc::Weight Weight
typename CacheState< Arc >::Arc Arc
Impl * GetMutableImpl() const
typename Arc::Label Label
constexpr uint64_t kExpanded
typename Arc::Label Label
uint64_t Properties(uint64_t mask, bool test) const override
const Impl * GetImpl() const
PdtPrunedExpand(const Fst< Arc > &ifst, const std::vector< std::pair< Label, Label >> &parens, bool keep_parentheses=false, const CacheOptions &opts=CacheOptions())
constexpr uint64_t kAcceptor
virtual const SymbolTable * OutputSymbols() const =0