6 #ifndef FST_EXTENSIONS_PDT_EXPAND_H_ 7 #define FST_EXTENSIONS_PDT_EXPAND_H_ 9 #include <forward_list> 38 keep_parentheses(keep_parentheses),
40 state_table(state_table) {}
49 using Label =
typename Arc::Label;
71 const std::vector<std::pair<Label, Label>> &parens,
76 state_table_(opts.state_table ? opts.state_table
78 own_stack_(opts.stack == 0),
79 own_state_table_(opts.state_table == 0),
80 keep_parentheses_(opts.keep_parentheses) {
90 fst_(impl.fst_->Copy(true)),
94 own_state_table_(true),
95 keep_parentheses_(impl.keep_parentheses_) {
103 if (own_stack_)
delete stack_;
104 if (own_state_table_)
delete state_table_;
109 const auto s = fst_->Start();
112 const auto start = state_table_->FindState(tuple);
120 const auto &tuple = state_table_->Tuple(s);
121 const auto weight = fst_->Final(tuple.state_id);
122 if (weight != Weight::Zero() && tuple.stack_id == 0)
125 SetFinal(s, Weight::Zero());
131 if (!HasArcs(s)) ExpandState(s);
136 if (!HasArcs(s)) ExpandState(s);
141 if (!HasArcs(s)) ExpandState(s);
146 if (!HasArcs(s)) ExpandState(s);
156 auto arc = aiter.Value();
157 const auto stack_id = stack_->Find(tuple.
stack_id, arc.ilabel);
158 if (stack_id == -1) {
160 }
else if ((stack_id != tuple.
stack_id) && !keep_parentheses_) {
166 arc.nextstate = state_table_->FindState(ntuple);
175 return *state_table_;
184 std::unique_ptr<const Fst<Arc>> fst_;
188 bool own_state_table_;
189 bool keep_parentheses_;
221 const std::vector<std::pair<Label, Label>> &parens)
226 const std::vector<std::pair<Label, Label>> &parens,
242 GetMutableImpl()->InitArcIterator(s, data);
246 return GetImpl()->GetStack();
250 return GetImpl()->GetStateTable();
319 const std::vector<std::pair<Label, Label>> &parens,
320 bool keep_parentheses =
false,
322 : ifst_(ifst.Copy()),
323 keep_parentheses_(keep_parentheses),
327 queue_(state_table_, stack_, stack_length_, distance_, fdistance_),
329 Reverse(*ifst_, parens, &rfst_);
334 reverse_shortest_path_->ShortestPath(&path);
336 balance_data_.reset(reverse_shortest_path_->GetBalanceData()->Reverse(
337 rfst_.NumStates(), 10, -1));
338 InitCloseParenMultimap(parens);
341 bool Error()
const {
return error_; }
348 static constexpr
uint8 kEnqueued = 0x01;
350 static constexpr
uint8 kSourceState = 0x04;
360 const std::vector<StackId> &stack_length,
361 const std::vector<Weight> &distance,
362 const std::vector<Weight> &fdistance)
363 : state_table_(state_table),
365 stack_length_(stack_length),
367 fdistance_(fdistance) {}
370 auto si1 = state_table_.Tuple(s1).stack_id;
371 auto si2 = state_table_.Tuple(s2).stack_id;
372 if (stack_length_[si1] < stack_length_[si2])
return true;
373 if (stack_length_[si1] > stack_length_[si2])
return false;
376 return less_(Distance(s1), Distance(s2));
379 for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) {
380 if (stack_.Top(si1) < stack_.Top(si2))
return true;
381 if (stack_.Top(si1) > stack_.Top(si2))
return false;
388 return (s < distance_.size()) && (s < fdistance_.size())
389 ?
Times(distance_[s], fdistance_[s])
395 const std::vector<StackId> &stack_length_;
396 const std::vector<Weight> &distance_;
397 const std::vector<Weight> &fdistance_;
401 class ShortestStackFirstQueue
406 const std::vector<StackId> &stack_length,
407 const std::vector<Weight> &distance,
408 const std::vector<Weight> &fdistance)
410 state_table, stack, stack_length, distance, fdistance)) {}
413 void InitCloseParenMultimap(
414 const std::vector<std::pair<Label, Label>> &parens);
434 void AddStateAndEnqueue(
StateId s);
438 bool PruneArc(
StateId s,
const Arc &arc);
444 bool ProcNonParen(
StateId s,
const Arc &arc,
bool add_arc);
448 bool ProcCloseParen(
StateId s,
const Arc &arc);
453 std::unique_ptr<Fst<Arc>> ifst_;
457 const bool keep_parentheses_;
465 std::vector<StackId> stack_length_;
467 std::vector<Weight> distance_;
469 std::vector<Weight> fdistance_;
471 ShortestStackFirstQueue queue_;
475 std::vector<uint8> flags_;
477 std::vector<StateId> sources_;
479 std::unique_ptr<PdtShortestPath<Arc, FifoQueue<StateId>>>
480 reverse_shortest_path_;
481 std::unique_ptr<internal::PdtBalanceData<Arc>> balance_data_;
484 close_paren_multimap_;
490 std::unordered_map<StateId, Weight> dest_map_;
494 ssize_t current_paren_id_;
495 ssize_t cached_stack_id_;
502 std::forward_list<std::pair<StateId, Weight>> cached_dest_list_;
510 const std::vector<std::pair<Label, Label>> &parens) {
511 std::unordered_map<Label, Label> paren_map;
512 for (
size_t i = 0; i < parens.size(); ++i) {
513 const auto &pair = parens[i];
514 paren_map[pair.first] = i;
515 paren_map[pair.second] = i;
518 const auto s = siter.Value();
520 const auto &arc = aiter.Value();
521 const auto it = paren_map.find(arc.ilabel);
522 if (it == paren_map.end())
continue;
523 if (arc.ilabel == parens[it->second].second) {
525 close_paren_multimap_.emplace(key, arc);
538 const SearchState ss(source + 1, dest + 1);
539 const auto distance =
540 reverse_shortest_path_->GetShortestPathData().Distance(ss);
541 VLOG(2) <<
"D(" << source <<
", " << dest <<
") =" << distance;
548 return s < flags_.size() ? flags_[s] : 0;
554 while (flags_.size() <= s) flags_.push_back(0);
556 flags_[s] |= flags & mask;
562 return s < distance_.size() ? distance_[s] : Weight::Zero();
568 while (distance_.size() <= s) distance_.push_back(Weight::Zero());
569 distance_[s] = std::move(weight);
575 return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
581 while (fdistance_.size() <= s) fdistance_.push_back(Weight::Zero());
582 fdistance_[s] = std::move(weight);
588 return s < sources_.size() ? sources_[s] :
kNoStateId;
594 while (sources_.size() <= s) sources_.push_back(
kNoStateId);
602 if (!(Flags(s) & (kEnqueued |
kExpanded))) {
603 while (ofst_->NumStates() <= s) ofst_->AddState();
606 }
else if (Flags(s) & kEnqueued) {
621 const auto nd =
Times(Distance(s), arc.weight);
622 if (less_(nd, Distance(arc.nextstate))) {
623 SetDistance(arc.nextstate, nd);
624 SetSourceState(arc.nextstate, SourceState(s));
626 if (less_(fd, FinalDistance(arc.nextstate))) {
627 SetFinalDistance(arc.nextstate, fd);
629 VLOG(2) <<
"Relax: " << s <<
", d[s] = " << Distance(s) <<
", to " 630 << arc.nextstate <<
", d[ns] = " << Distance(arc.nextstate)
637 VLOG(2) <<
"Prune ?";
638 auto fd = Weight::Zero();
639 if ((cached_source_ != SourceState(s)) ||
640 (cached_stack_id_ != current_stack_id_)) {
641 cached_source_ = SourceState(s);
642 cached_stack_id_ = current_stack_id_;
643 cached_dest_list_.clear();
644 if (cached_source_ != ifst_->Start()) {
646 balance_data_->Find(current_paren_id_, cached_source_);
647 !set_iter.Done(); set_iter.Next()) {
648 auto dest = set_iter.Element();
649 const auto it = dest_map_.find(dest);
650 cached_dest_list_.push_front(*it);
656 cached_dest_list_.push_front(
657 std::make_pair(rfst_.Start() - 1, Weight::One()));
660 for (
auto it = cached_dest_list_.begin(); it != cached_dest_list_.end();
663 DistanceToDest(state_table_.Tuple(arc.nextstate).state_id, it->first);
667 return less_(limit_,
Times(Distance(s),
Times(arc.weight, fd)));
674 const auto s = efst_.Start();
675 AddStateAndEnqueue(s);
677 SetSourceState(s, ifst_->Start());
678 current_stack_id_ = 0;
679 current_paren_id_ = -1;
680 stack_length_.push_back(0);
681 const auto r = rfst_.Start() - 1;
682 cached_source_ = ifst_->Start();
683 cached_stack_id_ = 0;
684 cached_dest_list_.push_front(std::make_pair(r, Weight::One()));
686 SetFinalDistance(state_table_.FindState(tuple), Weight::One());
687 SetDistance(s, Weight::One());
688 const auto d = DistanceToDest(ifst_->Start(), r);
689 SetFinalDistance(s, d);
697 const auto weight = efst_.Final(s);
698 if (weight == Weight::Zero())
return;
699 if (less_(limit_,
Times(Distance(s), weight)))
return;
700 ofst_->SetFinal(s, weight);
708 VLOG(2) <<
"ProcNonParen: " << s <<
" to " << arc.nextstate <<
", " 709 << arc.ilabel <<
":" << arc.olabel <<
" / " << arc.weight
710 <<
", add_arc = " << (add_arc ?
"true" :
"false");
711 if (PruneArc(s, arc))
return false;
712 if (add_arc) ofst_->AddArc(s, arc);
713 AddStateAndEnqueue(arc.nextstate);
731 while (stack_length_.size() <= nsi) stack_length_.push_back(-1);
732 if (stack_length_[nsi] == -1) stack_length_[nsi] = stack_length_[si] + 1;
733 const auto ns = arc.nextstate;
734 VLOG(2) <<
"Open paren: " << s <<
"(" << state_table_.Tuple(s).state_id
735 <<
") to " << ns <<
"(" << state_table_.Tuple(ns).state_id <<
")";
736 bool proc_arc =
false;
737 auto fd = Weight::Zero();
738 const auto paren_id = stack_.ParenId(arc.ilabel);
739 std::forward_list<StateId> sources;
741 balance_data_->Find(paren_id, state_table_.Tuple(ns).state_id);
742 !set_iter.Done(); set_iter.Next()) {
743 sources.push_front(set_iter.Element());
745 for (
const auto source : sources) {
746 VLOG(2) <<
"Close paren source: " << source;
748 for (
auto it = close_paren_multimap_.find(paren_state);
749 it != close_paren_multimap_.end() && paren_state == it->first; ++it) {
750 auto meta_arc = it->second;
752 meta_arc.nextstate = state_table_.FindState(tuple);
753 const auto state_id = state_table_.Tuple(ns).state_id;
754 const auto d = DistanceToDest(state_id, source);
755 VLOG(2) << state_id <<
", " << source;
756 VLOG(2) <<
"Meta arc weight = " << arc.weight <<
" Times " << d
757 <<
" Times " << meta_arc.weight;
758 meta_arc.weight =
Times(arc.weight,
Times(d, meta_arc.weight));
759 proc_arc |= ProcNonParen(s, meta_arc,
false);
762 Times(
Times(DistanceToDest(state_table_.Tuple(ns).state_id, source),
764 FinalDistance(meta_arc.nextstate)));
768 VLOG(2) <<
"Proc open paren " << s <<
" to " << arc.nextstate;
770 s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
771 AddStateAndEnqueue(arc.nextstate);
772 const auto nd =
Times(Distance(s), arc.weight);
773 if (less_(nd, Distance(arc.nextstate))) SetDistance(arc.nextstate, nd);
776 if (less_(fd, FinalDistance(arc.nextstate)))
777 SetFinalDistance(arc.nextstate, fd);
778 SetFlags(arc.nextstate, kSourceState, kSourceState);
788 Times(Distance(s),
Times(arc.weight, FinalDistance(arc.nextstate)));
789 if (less_(limit_, weight))
return false;
791 keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
802 if (!(Flags(s) & kSourceState))
return;
803 if (si != current_stack_id_) {
805 current_stack_id_ = si;
806 current_paren_id_ = stack_.Top(current_stack_id_);
807 VLOG(2) <<
"StackID " << si <<
" dequeued for first time";
811 SetSourceState(s, state_table_.Tuple(s).state_id);
812 const auto paren_id = stack_.Top(si);
814 balance_data_->Find(paren_id, state_table_.Tuple(s).state_id);
815 !set_iter.Done(); set_iter.Next()) {
816 const auto dest_state = set_iter.Element();
817 if (dest_map_.find(dest_state) != dest_map_.end())
continue;
818 auto dest_weight = Weight::Zero();
820 for (
auto it = close_paren_multimap_.find(paren_state);
821 it != close_paren_multimap_.end() && paren_state == it->first; ++it) {
822 const auto &arc = it->second;
827 Times(arc.weight, FinalDistance(state_table_.FindState(tuple))));
829 dest_map_[dest_state] = dest_weight;
830 VLOG(2) <<
"State " << dest_state <<
" is a dest state for stack ID " << si
831 <<
" with weight " << dest_weight;
838 const typename Arc::Weight &threshold) {
844 ofst_->DeleteStates();
845 ofst_->SetInputSymbols(ifst_->InputSymbols());
846 ofst_->SetOutputSymbols(ifst_->OutputSymbols());
847 limit_ =
Times(DistanceToDest(ifst_->Start(), rfst_.Start() - 1), threshold);
850 while (!queue_.Empty()) {
851 const auto s = queue_.Head();
854 VLOG(2) << s <<
" dequeued!";
856 StackId stack_id = state_table_.Tuple(s).stack_id;
857 ProcDestStates(s, stack_id);
860 const auto &arc = aiter.Value();
861 const auto nextstack_id = state_table_.Tuple(arc.nextstate).stack_id;
862 if (stack_id == nextstack_id) {
863 ProcNonParen(s, arc,
true);
864 }
else if (stack_id == stack_.Pop(nextstack_id)) {
865 ProcOpenParen(s, arc, stack_id, nextstack_id);
867 ProcCloseParen(s, arc);
870 VLOG(2) <<
"d[" << s <<
"] = " << Distance(s) <<
", fd[" << s
871 <<
"] = " << FinalDistance(s);
886 Weight weight_threshold = Weight::Zero())
888 keep_parentheses(keep_parentheses),
889 weight_threshold(
std::move(weight_threshold)) {}
901 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
924 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
926 bool keep_parentheses =
false) {
928 Expand(ifst, parens, ofst, opts);
933 #endif // FST_EXTENSIONS_PDT_EXPAND_H_ ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
typename Arc::StateId StateId
const PdtStateTable< StateId, StackId > & GetStateTable() const
void InitStateIterator(StateIteratorData< Arc > *data) const override
typename Arc::Label Label
constexpr uint64 kInitialAcyclic
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
ArcIterator(const PdtExpandFst< Arc > &fst, StateId s)
typename Store::State State
size_t NumOutputEpsilons(StateId s)
size_t NumInputEpsilons(StateId s)
PdtExpandFst(const PdtExpandFst< Arc > &fst, bool safe=false)
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
const SymbolTable * OutputSymbols() const
void SetFlags(const char *usage, int *argc, char ***argv, bool remove_flags, const char *src="")
PdtExpandFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens)
typename Arc::StateId StateId
typename Arc::StateId StateId
PdtStack< typename Arc::StateId, typename Arc::Label > * stack
ExpectationWeight< X1, X2 > Times(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
typename PdtExpandFst< Arc >::Arc Arc
void Connect(MutableFst< Arc > *fst)
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
typename internal::PdtBalanceData< Arc >::SetIterator SetIterator
~PdtExpandFstImpl() override
constexpr uint64 kFstProperties
constexpr uint64 kCopyProperties
virtual uint64 Properties() const
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
typename Arc::Weight Weight
constexpr uint64 kExpanded
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
virtual uint64 Properties(uint64 mask, bool test) const =0
void ExpandState(StateId s)
const PdtStack< StackId, Label > & GetStack() const
PdtExpandFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, const PdtExpandFstOptions< Arc > &opts)
typename Arc::Weight Weight
constexpr uint64 kUnweighted
typename Collection< ssize_t, StateId >::SetIterator SetIterator
size_t NumArcs(StateId s)
StateIteratorBase< Arc > * base
PdtStateTable< typename Arc::StateId, typename Arc::StateId > * state_table
CacheOptions(bool gc=FLAGS_fst_default_cache_gc, size_t gc_limit=FLAGS_fst_default_cache_gc_limit)
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
PdtExpandFstOptions(const CacheOptions &opts=CacheOptions(), bool keep_parentheses=false, PdtStack< typename Arc::StateId, typename Arc::Label > *stack=nullptr, PdtStateTable< typename Arc::StateId, typename Arc::StateId > *state_table=nullptr)
typename Arc::StateId StateId
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
PdtExpandOptions(bool connect=true, bool keep_parentheses=false, Weight weight_threshold=Weight::Zero())
void Expand(MutableFst< Arc > *ofst, const Weight &threshold)
ExpectationWeight< X1, X2 > Plus(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
constexpr uint64 kAcceptor
typename Arc::Weight Weight
PdtExpandFst< Arc > * Copy(bool safe=false) const override
PdtExpandFstImpl(const PdtExpandFstImpl &impl)
StateIterator(const PdtExpandFst< Arc > &fst)
const PdtStateTable< StateId, StackId > & GetStateTable() const
typename PdtExpandFst< Arc >::Arc Arc
virtual const SymbolTable * InputSymbols() const =0
const SymbolTable * InputSymbols() const
constexpr uint64 kAcyclic
PdtExpandFstImpl(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, const PdtExpandFstOptions< Arc > &opts)
const PdtStack< StackId, Label > & GetStack() const
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
typename Arc::Weight Weight
typename CacheState< Arc >::Arc Arc
Impl * GetMutableImpl() const
typename Arc::Label Label
uint64 Properties(uint64 mask, bool test) const override
typename Arc::Label Label
const Impl * GetImpl() const
PdtPrunedExpand(const Fst< Arc > &ifst, const std::vector< std::pair< Label, Label >> &parens, bool keep_parentheses=false, const CacheOptions &opts=CacheOptions())
virtual void SetProperties(uint64 props, uint64 mask)=0
virtual const SymbolTable * OutputSymbols() const =0