20 #ifndef FST_EXTENSIONS_PDT_EXPAND_H_ 21 #define FST_EXTENSIONS_PDT_EXPAND_H_ 23 #include <sys/types.h> 27 #include <forward_list> 29 #include <unordered_map> 49 #include <unordered_map> 65 keep_parentheses(keep_parentheses),
67 state_table(state_table) {}
76 using Label =
typename Arc::Label;
98 const std::vector<std::pair<Label, Label>> &parens,
103 state_table_(opts.state_table ? opts.state_table
105 own_stack_(opts.stack == nullptr),
106 own_state_table_(opts.state_table == nullptr),
107 keep_parentheses_(opts.keep_parentheses) {
117 fst_(impl.fst_->Copy(true)),
121 own_state_table_(true),
122 keep_parentheses_(impl.keep_parentheses_) {
130 if (own_stack_)
delete stack_;
131 if (own_state_table_)
delete state_table_;
136 const auto s = fst_->Start();
139 const auto start = state_table_->FindState(tuple);
147 const auto &tuple = state_table_->Tuple(s);
148 const auto weight = fst_->Final(tuple.state_id);
149 if (weight != Weight::Zero() && tuple.stack_id == 0)
152 SetFinal(s, Weight::Zero());
158 if (!HasArcs(s)) ExpandState(s);
163 if (!HasArcs(s)) ExpandState(s);
168 if (!HasArcs(s)) ExpandState(s);
173 if (!HasArcs(s)) ExpandState(s);
183 auto arc = aiter.Value();
184 const auto stack_id = stack_->Find(tuple.
stack_id, arc.ilabel);
185 if (stack_id == -1) {
187 }
else if ((stack_id != tuple.
stack_id) && !keep_parentheses_) {
193 arc.nextstate = state_table_->FindState(ntuple);
202 return *state_table_;
207 inline uint64_t PdtExpandProperties(uint64_t inprops) {
211 std::unique_ptr<const Fst<Arc>> fst_;
215 bool own_state_table_;
216 bool keep_parentheses_;
248 const std::vector<std::pair<Label, Label>> &parens)
253 const std::vector<std::pair<Label, Label>> &parens,
269 GetMutableImpl()->InitArcIterator(s, data);
273 return GetImpl()->GetStack();
277 return GetImpl()->GetStateTable();
312 data->
base = std::make_unique<StateIterator<PdtExpandFst<Arc>>>(*this);
348 const std::vector<std::pair<Label, Label>> &parens,
349 bool keep_parentheses =
false,
351 : ifst_(ifst.Copy()),
352 keep_parentheses_(keep_parentheses),
356 queue_(state_table_, stack_, stack_length_, distance_, fdistance_),
358 Reverse(*ifst_, parens, &rfst_);
363 reverse_shortest_path_->ShortestPath(&path);
365 balance_data_.reset(reverse_shortest_path_->GetBalanceData()->Reverse(
366 rfst_.NumStates(), 10, -1));
367 InitCloseParenMultimap(parens);
370 bool Error()
const {
return error_; }
377 static constexpr uint8_t kEnqueued = 0x01;
378 static constexpr uint8_t
kExpanded = 0x02;
379 static constexpr uint8_t kSourceState = 0x04;
389 const std::vector<StackId> &stack_length,
390 const std::vector<Weight> &distance,
391 const std::vector<Weight> &fdistance)
392 : state_table_(state_table),
394 stack_length_(stack_length),
396 fdistance_(fdistance),
400 auto si1 = state_table_.Tuple(s1).stack_id;
401 auto si2 = state_table_.Tuple(s2).stack_id;
402 if (stack_length_[si1] < stack_length_[si2])
return true;
403 if (stack_length_[si1] > stack_length_[si2])
return false;
406 return less_(Distance(s1), Distance(s2));
409 for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) {
410 if (stack_.Top(si1) < stack_.Top(si2))
return true;
411 if (stack_.Top(si1) > stack_.Top(si2))
return false;
418 return (s < distance_.size()) && (s < fdistance_.size())
419 ?
Times(distance_[s], fdistance_[s])
425 const std::vector<StackId> &stack_length_;
426 const std::vector<Weight> &distance_;
427 const std::vector<Weight> &fdistance_;
432 class ShortestStackFirstQueue
437 const std::vector<StackId> &stack_length,
438 const std::vector<Weight> &distance,
439 const std::vector<Weight> &fdistance)
441 state_table, stack, stack_length, distance, fdistance)) {}
444 void InitCloseParenMultimap(
445 const std::vector<std::pair<Label, Label>> &parens);
449 uint8_t Flags(
StateId s)
const;
465 void AddStateAndEnqueue(
StateId s);
469 bool PruneArc(
StateId s,
const Arc &arc);
475 bool ProcNonParen(
StateId s,
const Arc &arc,
bool add_arc);
479 bool ProcCloseParen(
StateId s,
const Arc &arc);
484 std::unique_ptr<Fst<Arc>> ifst_;
488 const bool keep_parentheses_;
496 std::vector<StackId> stack_length_;
498 std::vector<Weight> distance_;
500 std::vector<Weight> fdistance_;
502 ShortestStackFirstQueue queue_;
506 std::vector<uint8_t> flags_;
508 std::vector<StateId> sources_;
510 std::unique_ptr<PdtShortestPath<Arc, FifoQueue<StateId>>>
511 reverse_shortest_path_;
512 std::unique_ptr<internal::PdtBalanceData<Arc>> balance_data_;
515 close_paren_multimap_;
521 std::unordered_map<StateId, Weight> dest_map_;
525 ssize_t current_paren_id_;
526 ssize_t cached_stack_id_;
533 std::forward_list<std::pair<StateId, Weight>> cached_dest_list_;
541 const std::vector<std::pair<Label, Label>> &parens) {
542 std::unordered_map<Label, Label> paren_map;
543 for (
size_t i = 0; i < parens.size(); ++i) {
544 const auto &pair = parens[i];
545 paren_map[pair.first] = i;
546 paren_map[pair.second] = i;
549 const auto s = siter.Value();
551 const auto &arc = aiter.Value();
552 const auto it = paren_map.find(arc.ilabel);
553 if (it == paren_map.end())
continue;
554 if (arc.ilabel == parens[it->second].second) {
556 close_paren_multimap_.emplace(key, arc);
569 const SearchState ss(source + 1, dest + 1);
570 const auto distance =
571 reverse_shortest_path_->GetShortestPathData().Distance(ss);
572 VLOG(2) <<
"D(" << source <<
", " << dest <<
") =" << distance;
579 return s < flags_.size() ? flags_[s] : 0;
585 while (flags_.size() <= s) flags_.push_back(0);
587 flags_[s] |= flags & mask;
593 return s < distance_.size() ? distance_[s] : Weight::Zero();
599 while (distance_.size() <= s) distance_.push_back(Weight::Zero());
600 distance_[s] = std::move(weight);
606 return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
612 while (fdistance_.size() <= s) fdistance_.push_back(Weight::Zero());
613 fdistance_[s] = std::move(weight);
619 return s < sources_.size() ? sources_[s] :
kNoStateId;
625 while (sources_.size() <= s) sources_.push_back(
kNoStateId);
633 if (!(Flags(s) & (kEnqueued |
kExpanded))) {
634 while (ofst_->NumStates() <= s) ofst_->AddState();
637 }
else if (Flags(s) & kEnqueued) {
652 const auto nd =
Times(Distance(s), arc.weight);
653 if (less_(nd, Distance(arc.nextstate))) {
654 SetDistance(arc.nextstate, nd);
655 SetSourceState(arc.nextstate, SourceState(s));
657 if (less_(fd, FinalDistance(arc.nextstate))) {
658 SetFinalDistance(arc.nextstate, fd);
660 VLOG(2) <<
"Relax: " << s <<
", d[s] = " << Distance(s) <<
", to " 661 << arc.nextstate <<
", d[ns] = " << Distance(arc.nextstate)
668 VLOG(2) <<
"Prune ?";
669 auto fd = Weight::Zero();
670 if ((cached_source_ != SourceState(s)) ||
671 (cached_stack_id_ != current_stack_id_)) {
672 cached_source_ = SourceState(s);
673 cached_stack_id_ = current_stack_id_;
674 cached_dest_list_.clear();
675 if (cached_source_ != ifst_->Start()) {
677 balance_data_->Find(current_paren_id_, cached_source_);
678 !set_iter.Done(); set_iter.Next()) {
679 auto dest = set_iter.Element();
680 const auto it = dest_map_.find(dest);
681 cached_dest_list_.push_front(*it);
687 cached_dest_list_.push_front(
688 std::make_pair(rfst_.Start() - 1, Weight::One()));
691 for (
auto it = cached_dest_list_.begin(); it != cached_dest_list_.end();
694 DistanceToDest(state_table_.Tuple(arc.nextstate).state_id, it->first);
698 return less_(limit_,
Times(Distance(s),
Times(arc.weight, fd)));
705 const auto s = efst_.Start();
706 AddStateAndEnqueue(s);
708 SetSourceState(s, ifst_->Start());
709 current_stack_id_ = 0;
710 current_paren_id_ = -1;
711 stack_length_.push_back(0);
712 const auto r = rfst_.Start() - 1;
713 cached_source_ = ifst_->Start();
714 cached_stack_id_ = 0;
715 cached_dest_list_.push_front(std::make_pair(r, Weight::One()));
717 SetFinalDistance(state_table_.FindState(tuple), Weight::One());
718 SetDistance(s, Weight::One());
719 const auto d = DistanceToDest(ifst_->Start(), r);
720 SetFinalDistance(s, d);
728 const auto weight = efst_.Final(s);
729 if (weight == Weight::Zero())
return;
730 if (less_(limit_,
Times(Distance(s), weight)))
return;
731 ofst_->SetFinal(s, weight);
739 VLOG(2) <<
"ProcNonParen: " << s <<
" to " << arc.nextstate <<
", " 740 << arc.ilabel <<
":" << arc.olabel <<
" / " << arc.weight
741 <<
", add_arc = " << (add_arc ?
"true" :
"false");
742 if (PruneArc(s, arc))
return false;
743 if (add_arc) ofst_->AddArc(s, arc);
744 AddStateAndEnqueue(arc.nextstate);
762 while (stack_length_.size() <= nsi) stack_length_.push_back(-1);
763 if (stack_length_[nsi] == -1) stack_length_[nsi] = stack_length_[si] + 1;
764 const auto ns = arc.nextstate;
765 VLOG(2) <<
"Open paren: " << s <<
"(" << state_table_.Tuple(s).state_id
766 <<
") to " << ns <<
"(" << state_table_.Tuple(ns).state_id <<
")";
767 bool proc_arc =
false;
768 auto fd = Weight::Zero();
769 const auto paren_id = stack_.ParenId(arc.ilabel);
770 std::forward_list<StateId> sources;
772 balance_data_->Find(paren_id, state_table_.Tuple(ns).state_id);
773 !set_iter.Done(); set_iter.Next()) {
774 sources.push_front(set_iter.Element());
776 for (
const auto source : sources) {
777 VLOG(2) <<
"Close paren source: " << source;
779 for (
auto it = close_paren_multimap_.find(paren_state);
780 it != close_paren_multimap_.end() && paren_state == it->first; ++it) {
781 auto meta_arc = it->second;
783 meta_arc.nextstate = state_table_.FindState(tuple);
784 const auto state_id = state_table_.Tuple(ns).state_id;
785 const auto d = DistanceToDest(state_id, source);
786 VLOG(2) << state_id <<
", " << source;
787 VLOG(2) <<
"Meta arc weight = " << arc.weight <<
" Times " << d
788 <<
" Times " << meta_arc.weight;
789 meta_arc.weight =
Times(arc.weight,
Times(d, meta_arc.weight));
790 proc_arc |= ProcNonParen(s, meta_arc,
false);
793 Times(
Times(DistanceToDest(state_table_.Tuple(ns).state_id, source),
795 FinalDistance(meta_arc.nextstate)));
799 VLOG(2) <<
"Proc open paren " << s <<
" to " << arc.nextstate;
801 s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
802 AddStateAndEnqueue(arc.nextstate);
803 const auto nd =
Times(Distance(s), arc.weight);
804 if (less_(nd, Distance(arc.nextstate))) SetDistance(arc.nextstate, nd);
807 if (less_(fd, FinalDistance(arc.nextstate)))
808 SetFinalDistance(arc.nextstate, fd);
809 SetFlags(arc.nextstate, kSourceState, kSourceState);
819 Times(Distance(s),
Times(arc.weight, FinalDistance(arc.nextstate)));
820 if (less_(limit_, weight))
return false;
822 keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
833 if (!(Flags(s) & kSourceState))
return;
834 if (si != current_stack_id_) {
836 current_stack_id_ = si;
837 current_paren_id_ = stack_.Top(current_stack_id_);
838 VLOG(2) <<
"StackID " << si <<
" dequeued for first time";
842 SetSourceState(s, state_table_.Tuple(s).state_id);
843 const auto paren_id = stack_.Top(si);
845 balance_data_->Find(paren_id, state_table_.Tuple(s).state_id);
846 !set_iter.Done(); set_iter.Next()) {
847 const auto dest_state = set_iter.Element();
848 if (dest_map_.find(dest_state) != dest_map_.end())
continue;
849 auto dest_weight = Weight::Zero();
851 for (
auto it = close_paren_multimap_.find(paren_state);
852 it != close_paren_multimap_.end() && paren_state == it->first; ++it) {
853 const auto &arc = it->second;
858 Times(arc.weight, FinalDistance(state_table_.FindState(tuple))));
860 dest_map_[dest_state] = dest_weight;
861 VLOG(2) <<
"State " << dest_state <<
" is a dest state for stack ID " << si
862 <<
" with weight " << dest_weight;
869 const typename Arc::Weight &threshold) {
875 ofst_->DeleteStates();
876 ofst_->SetInputSymbols(ifst_->InputSymbols());
877 ofst_->SetOutputSymbols(ifst_->OutputSymbols());
878 limit_ =
Times(DistanceToDest(ifst_->Start(), rfst_.Start() - 1), threshold);
881 while (!queue_.Empty()) {
882 const auto s = queue_.Head();
885 VLOG(2) << s <<
" dequeued!";
887 StackId stack_id = state_table_.Tuple(s).stack_id;
888 ProcDestStates(s, stack_id);
891 const auto &arc = aiter.Value();
892 const auto nextstack_id = state_table_.Tuple(arc.nextstate).stack_id;
893 if (stack_id == nextstack_id) {
894 ProcNonParen(s, arc,
true);
895 }
else if (stack_id == stack_.Pop(nextstack_id)) {
896 ProcOpenParen(s, arc, stack_id, nextstack_id);
898 ProcCloseParen(s, arc);
901 VLOG(2) <<
"d[" << s <<
"] = " << Distance(s) <<
", fd[" << s
902 <<
"] = " << FinalDistance(s);
917 Weight weight_threshold = Weight::Zero())
919 keep_parentheses(keep_parentheses),
920 weight_threshold(std::move(weight_threshold)) {}
932 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
935 using Weight =
typename Arc::Weight;
945 FSTERROR() <<
"Expand: non-Zero weight_threshold with non-idempotent" 946 <<
" Weight " << Weight::Type();
961 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
963 MutableFst<Arc> *ofst,
bool connect =
true,
bool keep_parentheses =
false) {
965 Expand(ifst, parens, ofst, opts);
970 #endif // FST_EXTENSIONS_PDT_EXPAND_H_ ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
typename Arc::StateId StateId
const PdtStateTable< StateId, StackId > & GetStateTable() const
void InitStateIterator(StateIteratorData< Arc > *data) const override
typename Arc::Label Label
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
ArcIterator(const PdtExpandFst< Arc > &fst, StateId s)
CacheOptions(bool gc=FST_FLAGS_fst_default_cache_gc, size_t gc_limit=FST_FLAGS_fst_default_cache_gc_limit)
virtual uint64_t Properties(uint64_t mask, bool test) const =0
typename Store::State State
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
size_t NumOutputEpsilons(StateId s)
size_t NumInputEpsilons(StateId s)
PdtExpandFst(const PdtExpandFst< Arc > &fst, bool safe=false)
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
const SymbolTable * OutputSymbols() const
void SetFlags(const char *usage, int *argc, char ***argv, bool remove_flags, const char *src="")
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
PdtExpandFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens)
constexpr uint64_t kError
typename Arc::StateId StateId
constexpr uint64_t kInitialAcyclic
typename Arc::StateId StateId
PdtStack< typename Arc::StateId, typename Arc::Label > * stack
typename PdtExpandFst< Arc >::Arc Arc
void Connect(MutableFst< Arc > *fst)
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
std::bool_constant<(W::Properties()&kIdempotent)!=0 > IsIdempotent
typename internal::PdtBalanceData< Arc >::SetIterator SetIterator
~PdtExpandFstImpl() override
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
typename Arc::Weight Weight
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
void ExpandState(StateId s)
const PdtStack< StackId, Label > & GetStack() const
PdtExpandFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, const PdtExpandFstOptions< Arc > &opts)
typename Arc::Weight Weight
typename Collection< ssize_t, StateId >::SetIterator SetIterator
size_t NumArcs(StateId s)
PdtStateTable< typename Arc::StateId, typename Arc::StateId > * state_table
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
PdtExpandFstOptions(const CacheOptions &opts=CacheOptions(), bool keep_parentheses=false, PdtStack< typename Arc::StateId, typename Arc::Label > *stack=nullptr, PdtStateTable< typename Arc::StateId, typename Arc::StateId > *state_table=nullptr)
virtual uint64_t Properties() const
typename Arc::StateId StateId
constexpr uint64_t kCopyProperties
constexpr uint64_t kAcyclic
virtual void SetProperties(uint64_t props, uint64_t mask)=0
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
PdtExpandOptions(bool connect=true, bool keep_parentheses=false, Weight weight_threshold=Weight::Zero())
void Expand(MutableFst< Arc > *ofst, const Weight &threshold)
std::unique_ptr< StateIteratorBase< Arc > > base
typename Arc::Weight Weight
PdtExpandFst< Arc > * Copy(bool safe=false) const override
PdtExpandFstImpl(const PdtExpandFstImpl &impl)
StateIterator(const PdtExpandFst< Arc > &fst)
constexpr uint64_t kFstProperties
constexpr uint64_t kUnweighted
const PdtStateTable< StateId, StackId > & GetStateTable() const
typename PdtExpandFst< Arc >::Arc Arc
virtual const SymbolTable * InputSymbols() const =0
const SymbolTable * InputSymbols() const
PdtExpandFstImpl(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, const PdtExpandFstOptions< Arc > &opts)
const PdtStack< StackId, Label > & GetStack() const
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
typename Arc::Weight Weight
typename CacheState< Arc >::Arc Arc
Impl * GetMutableImpl() const
typename Arc::Label Label
constexpr uint64_t kExpanded
typename Arc::Label Label
uint64_t Properties(uint64_t mask, bool test) const override
const Impl * GetImpl() const
PdtPrunedExpand(const Fst< Arc > &ifst, const std::vector< std::pair< Label, Label >> &parens, bool keep_parentheses=false, const CacheOptions &opts=CacheOptions())
constexpr uint64_t kAcceptor
virtual const SymbolTable * OutputSymbols() const =0