20 #ifndef FST_EXTENSIONS_PDT_REPLACE_H_ 21 #define FST_EXTENSIONS_PDT_REPLACE_H_ 30 #include <type_traits> 44 #include <unordered_map> 52 size_t operator()(
const std::pair<size_t, S> &paren)
const {
53 static constexpr
auto prime = 7853;
54 return paren.first + paren.second * prime;
92 using Label =
typename Arc::Label;
97 std::string left_paren_prefix =
"(_",
98 std::string right_paren_prefix =
")_")
101 start_paren_labels(start_paren_labels),
102 left_paren_prefix(std::move(left_paren_prefix)),
103 right_paren_prefix(std::move(right_paren_prefix)) {}
131 start_paren_labels_(opts.start_paren_labels),
132 left_paren_prefix_(std::move(opts.left_paren_prefix)),
133 right_paren_prefix_(std::move(opts.right_paren_prefix)),
135 for (
size_t i = 0; i < fst_array.size(); ++i) {
137 fst_array[i].second->InputSymbols())) {
138 FSTERROR() <<
"PdtParser: Input symbol table of input FST " << i
139 <<
" does not match input symbol table of 0th input FST";
143 fst_array[i].second->OutputSymbols())) {
144 FSTERROR() <<
"PdtParser: Output symbol table of input FST " << i
145 <<
" does not match output symbol table of 0th input FST";
148 fst_array_.emplace_back(fst_array[i].first, fst_array[i].second->Copy());
150 label2id_[fst_array[i].first] = i;
155 for (
auto &pair : fst_array_)
delete pair.second;
160 std::vector<LabelPair> *parens) = 0;
// Read-only accessor for the parser's (label, FST) pairs: returns a const
// reference to the fst_array_ member, so no copy is made and the reference is
// only valid while this object is alive.
163 const std::vector<LabelFstPair> &
FstArray()
const {
return fst_array_; }
170 auto it = label2id_.find(l);
171 return it == label2id_.end() ?
kNoStateId : it->second;
177 if (os >= label_state_pairs_.size()) {
181 return label_state_pairs_[os];
188 auto it = state_map_.find(lsp);
189 if (it == state_map_.end()) {
200 std::vector<std::vector<StateWeightPair>> *close_src);
205 for (
size_t paren_id = 0; paren_id < total_nparens; ++paren_id) {
206 const auto open_paren = start_paren_labels_ + paren_id;
207 const auto close_paren = open_paren + total_nparens;
208 parens->emplace_back(open_paren, close_paren);
213 virtual size_t AssignParenIds(
const Fst<Arc> &ofst,
230 const std::vector<LabelPair> &parens,
const ParenMap &paren_map,
231 const std::vector<StateId> &open_dest,
232 const std::vector<std::vector<StateWeightPair>> &close_src,
233 const std::vector<bool> &close_non_term_weight,
MutableFst<Arc> *ofst);
236 void AddParensToSymbolTables(
const std::vector<LabelPair> &parens,
// (label, FST) pairs. The constructor loop stores a Copy() of every input FST
// here, and a later loop over fst_array_ deletes each pair.second, so the
// pointers held in this vector are owned by the parser.
240 std::vector<LabelFstPair> fst_array_;
// First label used for open parentheses: paren i gets the open label
// start_paren_labels_ + i. Initialised from opts.start_paren_labels; when it
// is still kNoLabel after the output FST is built it is set to max_label + 1.
243 Label start_paren_labels_;
// Prefixes taken from the options' left_paren_prefix / right_paren_prefix.
// NOTE(review): presumably used to name the parenthesis symbols added to the
// symbol tables; the code that consumes them is not visible here -- confirm.
244 const std::string left_paren_prefix_;
245 const std::string right_paren_prefix_;
// Maps a non-terminal label to the index of its FST in fst_array_ (filled as
// label2id_[fst_array[i].first] = i).
247 std::unordered_map<Label, StateId> label2id_;
// Indexed by output-FST state id; one (label, input state) pair is appended
// for each output state as it is created.
249 std::vector<LabelStatePair> label_state_pairs_;
// Inverse lookup: (label, input state) pair -> output-FST state id.
251 std::map<LabelStatePair, StateId> state_map_;
258 std::vector<std::vector<StateWeightPair>> *close_src) {
264 open_dest->resize(fst_array_.size(),
kNoStateId);
265 close_src->resize(fst_array_.size());
267 std::deque<Label> non_term_queue;
268 non_term_queue.push_back(root_);
270 std::vector<bool> enqueued(fst_array_.size(),
false);
271 enqueued[label2id_[root_]] =
true;
274 const auto label = non_term_queue.front();
275 non_term_queue.pop_front();
276 StateId fst_id = Label2Id(label);
277 const auto *ifst = fst_array_[fst_id].second;
279 const auto is = siter.Value();
282 label_state_pairs_.push_back(lsp);
283 state_map_[lsp] = os;
284 if (is == ifst->Start()) {
285 (*open_dest)[fst_id] = os;
286 if (label == root_) ofst->
SetStart(os);
288 if (ifst->Final(is) != Weight::Zero()) {
289 if (label == root_) ofst->
SetFinal(os, ifst->Final(is));
290 (*close_src)[fst_id].emplace_back(os, ifst->Final(is));
294 auto arc = aiter.Value();
295 arc.nextstate += soff;
296 if (max_label ==
kNoLabel || arc.olabel > max_label)
297 max_label = arc.olabel;
298 const auto nfst_id = Label2Id(arc.olabel);
300 if (fst_array_[nfst_id].second->Start() ==
kNoStateId)
continue;
301 if (!enqueued[nfst_id]) {
302 non_term_queue.push_back(arc.olabel);
303 enqueued[nfst_id] =
true;
310 if (start_paren_labels_ ==
kNoLabel) start_paren_labels_ = max_label + 1;
315 const std::vector<LabelPair> &parens,
const ParenMap &paren_map,
316 const std::vector<StateId> &open_dest,
317 const std::vector<std::vector<StateWeightPair>> &close_src,
318 const std::vector<bool> &close_non_term_weight,
MutableFst<Arc> *ofst) {
323 std::unique_ptr<MIter> aiter(
new MIter(ofst, os));
324 for (
auto n = 0; !aiter->Done(); aiter->Next(), ++n) {
325 const auto arc = aiter->Value();
326 StateId nfst_id = Label2Id(arc.olabel);
329 const ParenKey paren_key(nfst_id, arc.nextstate);
330 auto it = paren_map.find(paren_key);
331 Label open_paren = 0;
332 Label close_paren = 0;
333 if (it != paren_map.end()) {
334 const auto paren_id = it->second;
335 open_paren = parens[paren_id].first;
336 close_paren = parens[paren_id].second;
339 if (open_paren != 0 || !close_non_term_weight[nfst_id]) {
340 const auto open_weight =
341 close_non_term_weight[nfst_id] ? Weight::One() : arc.weight;
342 const Arc sarc(open_paren, open_paren, open_weight,
344 aiter->SetValue(sarc);
349 const Arc sarc(0, 0, Weight::One(), dead_state);
350 aiter->SetValue(sarc);
353 if (close_paren != 0 || close_non_term_weight[nfst_id]) {
354 for (
size_t i = 0; i < close_src[nfst_id].size(); ++i) {
355 const auto &pair = close_src[nfst_id][i];
356 const auto close_weight = close_non_term_weight[nfst_id]
357 ?
Times(arc.weight, pair.second)
359 const Arc farc(close_paren, close_paren, close_weight,
362 ofst->
AddArc(pair.first, farc);
363 if (os == pair.first) {
364 aiter.reset(
new MIter(ofst, os));
377 auto size = parens.size();
434 std::vector<LabelPair> *parens)
override;
439 size_t AssignParenIds(
const Fst<Arc> &ofst,
440 ParenMap *paren_map)
const override;
445 std::vector<LabelPair> *parens) {
448 const auto &fst_array = FstArray();
454 std::vector<StateId> open_dest(fst_array.size(),
kNoStateId);
458 std::vector<std::vector<StateWeightPair>> close_src(fst_array.size());
462 std::vector<bool> close_non_term_weight(fst_array.size(),
false);
463 CreateFst(ofst, &open_dest, &close_src);
464 auto total_nparens = AssignParenIds(*ofst, &paren_map);
465 AssignParenLabels(total_nparens, parens);
466 AddParensToFst(*parens, paren_map, open_dest, close_src,
467 close_non_term_weight, ofst);
468 if (!fst_array.empty()) {
472 AddParensToSymbolTables(*parens, ofst);
479 std::vector<size_t> nparens(FstArray().size(), 0);
481 size_t total_nparens = 0;
483 const auto os = siter.Value();
485 const auto &arc = aiter.Value();
486 const auto nfst_id = Label2Id(arc.olabel);
488 const ParenKey paren_key(nfst_id, arc.nextstate);
489 auto it = paren_map->find(paren_key);
490 if (it == paren_map->end()) {
492 (*paren_map)[paren_key] = nparens[nfst_id]++;
493 if (nparens[nfst_id] > total_nparens)
494 total_nparens = nparens[nfst_id];
499 return total_nparens;
537 std::vector<LabelPair> *parens)
override;
545 size_t AssignParenIds(
const Fst<Arc> &ofst,
546 ParenMap *paren_map)
const override;
// Returns the SCC identifier that replace_util_ reports for `label`. This is
// a thin forwarding wrapper around replace_util_.SCC(); no checking of the
// label is done here.
549 size_t SCC(
Label label)
const {
return replace_util_.SCC(label); }
554 const auto scc_props = replace_util_.SCCProperties(scc_id);
555 return (scc_props & ll_props) == ll_props;
561 const auto scc_props = replace_util_.SCCProperties(scc_id);
562 return (scc_props & lr_props) == lr_props;
// Returns the list stored in scc_comps_ for the SCC `scc_id` (a list of
// indices into the FST array, see GetSCCComps()).
// The result is a reference into the mutable scc_comps_ member.
566 const std::vector<size_t> &
SCCComps(
size_t scc_id)
const {
// Populate scc_comps_ on demand: an empty vector is treated as "not yet
// computed".
567 if (scc_comps_.empty()) GetSCCComps();
// scc_id is used as a direct index; no bounds check is performed here.
568 return scc_comps_[scc_id];
575 if (SCCComps(scc_id).empty())
return kNoStateId;
576 const auto fst_id = SCCComps(scc_id).front();
577 const auto &fst_array = FstArray();
578 const auto label = fst_array[fst_id].first;
579 const auto *ifst = fst_array[fst_id].second;
580 if (SCCLeftLinear(scc_id)) {
582 return GetState(lsp);
585 return GetState(lsp);
594 std::vector<std::vector<StateWeightPair>> *close_src,
595 std::vector<bool> *close_non_term_weight)
const;
// Fills the mutable scc_comps_ member: for every FST index i in FstArray(),
// looks up the SCC of that FST's label and, if the SCC is left- or
// right-linear, records i under that SCC id.
598 void GetSCCComps()
const {
599 const std::vector<LabelFstPair> &fst_array = FstArray();
600 for (
size_t i = 0; i < fst_array.size(); ++i) {
601 const auto label = fst_array[i].first;
602 const auto scc_id = SCC(label);
// Grow the outer vector so that scc_id is a valid index. This happens for
// every SCC seen, linear or not, so non-linear SCCs keep an empty list.
603 if (scc_comps_.size() <= scc_id) scc_comps_.resize(scc_id + 1);
// Only FSTs belonging to a left-linear or right-linear SCC are recorded.
604 if (SCCLeftLinear(scc_id) || SCCRightLinear(scc_id)) {
605 scc_comps_[scc_id].push_back(i);
// Returns the set of states stored in non_term_dests_ for the FST with index
// `fst_id`. The result is a reference into the mutable non_term_dests_ member.
610 const std::set<StateId> &NonTermDests(
StateId fst_id)
const {
// Populate non_term_dests_ on demand: an empty vector is treated as "not yet
// computed".
611 if (non_term_dests_.empty()) GetNonTermDests();
// fst_id is used as a direct index; no bounds check is performed here.
612 return non_term_dests_[fst_id];
617 void GetNonTermDests()
const;
// Indexed by SCC id: the FST-array indices of the FSTs in that SCC, recorded
// only when the SCC is left- or right-linear. Declared mutable because it is
// filled on demand from const member functions.
622 mutable std::vector<std::vector<size_t>> scc_comps_;
// Indexed by FST-array index: a set of input-FST states collected from arc
// destinations (arc.nextstate) for FSTs in right-linear SCCs.
// NOTE(review): the condition selecting which arcs contribute is not visible
// here (presumably arcs labelled with a non-terminal) -- confirm.
// Declared mutable because it is filled on demand from const member functions.
624 mutable std::vector<std::set<StateId>> non_term_dests_;
629 std::vector<LabelPair> *parens) {
632 const auto &fst_array = FstArray();
637 std::vector<StateId> open_dest(fst_array.size(),
kNoStateId);
641 std::vector<std::vector<StateWeightPair>> close_src(fst_array.size());
645 std::vector<bool> close_non_term_weight(fst_array.size(),
false);
646 CreateFst(ofst, &open_dest, &close_src);
647 ProcSCCs(ofst, &open_dest, &close_src, &close_non_term_weight);
648 const auto total_nparens = AssignParenIds(*ofst, &paren_map);
649 AssignParenLabels(total_nparens, parens);
650 AddParensToFst(*parens, paren_map, open_dest, close_src,
651 close_non_term_weight, ofst);
652 if (!fst_array.empty()) {
656 AddParensToSymbolTables(*parens, ofst);
663 std::vector<std::vector<StateWeightPair>> *close_src,
664 std::vector<bool> *close_non_term_weight)
const {
665 const auto &fst_array = FstArray();
667 const auto os = siter.Value();
668 const auto label = GetLabelStatePair(os).first;
669 const auto is = GetLabelStatePair(os).second;
670 const auto fst_id = Label2Id(label);
671 const auto scc_id = SCC(label);
672 const auto rs = RepState(scc_id);
673 const auto *ifst = fst_array[fst_id].second;
676 if (SCCLeftLinear(scc_id)) {
677 (*close_non_term_weight)[fst_id] =
true;
678 if (is == ifst->Start() && os != rs) {
681 const auto &arc = aiter.Value();
686 (*open_dest)[fst_id] = rs;
692 if (SCCRightLinear(scc_id)) {
695 auto arc = aiter.Value();
696 const auto idest = GetLabelStatePair(arc.nextstate).second;
697 if (NonTermDests(fst_id).count(idest) > 0) {
698 if (ofst->
Final(arc.nextstate) != Weight::Zero()) {
699 ofst->
SetFinal(arc.nextstate, Weight::Zero());
702 arc.weight =
Times(arc.weight, ifst->Final(idest));
707 const auto final_weight = ifst->Final(is);
708 if (final_weight != Weight::Zero() &&
709 NonTermDests(fst_id).count(is) == 0) {
710 ofst->
AddArc(os, Arc(0, 0, final_weight, rs));
711 if (ofst->
Final(os) != Weight::Zero()) {
716 if (is == ifst->Start()) {
717 (*close_src)[fst_id].clear();
718 (*close_src)[fst_id].emplace_back(rs, Weight::One());
726 const auto &fst_array = FstArray();
727 non_term_dests_.resize(fst_array.size());
728 for (
size_t fst_id = 0; fst_id < fst_array.size(); ++fst_id) {
729 const auto label = fst_array[fst_id].first;
730 const auto scc_id = SCC(label);
731 if (SCCRightLinear(scc_id)) {
732 const auto *ifst = fst_array[fst_id].second;
734 const auto is = siter.Value();
737 const auto &arc = aiter.Value();
739 non_term_dests_[fst_id].insert(arc.nextstate);
750 const auto &fst_array = FstArray();
752 std::vector<size_t> nparens(fst_array.size(), 0);
754 size_t total_nparens = 0;
756 const auto os = siter.Value();
757 const auto label = GetLabelStatePair(os).first;
758 const auto scc_id = SCC(label);
760 const auto &arc = aiter.Value();
761 const auto nfst_id = Label2Id(arc.olabel);
763 size_t nscc_id = SCC(arc.olabel);
764 bool nscc_linear = !SCCComps(nscc_id).empty();
768 if (!nscc_linear || scc_id != nscc_id) {
773 nscc_linear ? SCCComps(nscc_id).front() : nfst_id;
774 ParenKey paren_key(pfst_id, arc.nextstate);
775 const auto it = paren_map->find(paren_key);
776 if (it == paren_map->end()) {
781 const ParenKey nparen_key(nfst_id, arc.nextstate);
782 (*paren_map)[nparen_key] = nparens[pfst_id];
784 (*paren_map)[paren_key] = nparens[pfst_id]++;
785 if (nparens[pfst_id] > total_nparens) {
786 total_nparens = nparens[pfst_id];
793 return total_nparens;
803 const std::vector<std::pair<
typename Arc::Label,
const Fst<Arc> *>>
806 std::vector<std::pair<typename Arc::Label, typename Arc::Label>> *parens,
820 FSTERROR() <<
"Replace: Unknown PDT parser type: " 821 <<
static_cast<std::underlying_type_t<PdtParserType>
>(
833 const std::vector<std::pair<
typename Arc::Label,
const Fst<Arc> *>>
836 std::vector<std::pair<typename Arc::Label, typename Arc::Label>> *parens,
837 typename Arc::Label root) {
839 Replace(ifst_array, ofst, parens, opts);
844 #endif // FST_EXTENSIONS_PDT_REPLACE_H_
const std::string right_paren_prefix
void GetParser(MutableFst< Arc > *ofst, std::vector< LabelPair > *parens) override
std::unordered_map< ParenKey, size_t, internal::ReplaceParenHash< StateId >> ParenMap
virtual SymbolTable * MutableOutputSymbols()=0
PdtLeftSRParser(const std::vector< LabelFstPair > &fst_array, const PdtReplaceOptions< Arc > &opts)
size_t AssignParenIds(const Fst< Arc > &ofst, ParenMap *paren_map) const override
StateId RepState(size_t scc_id) const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
const SymbolTable * InputSymbols() const override=0
constexpr uint64_t kError
virtual void SetInputSymbols(const SymbolTable *isyms)=0
StateId Label2Id(Label l) const
size_t AssignParenIds(const Fst< Arc > &ofst, ParenMap *paren_map) const override
virtual Weight Final(StateId) const =0
virtual void SetStart(StateId)=0
bool SCCLeftLinear(size_t scc_id) const
void Connect(MutableFst< Arc > *fst)
constexpr uint8_t kReplaceSCCLeftLinear
void Replace(const std::vector< std::pair< typename Arc::Label, const Fst< Arc > * >> &ifst_array, MutableFst< Arc > *ofst, std::vector< std::pair< typename Arc::Label, typename Arc::Label >> *parens, const PdtReplaceOptions< Arc > &opts)
size_t operator()(const std::pair< size_t, S > &paren) const
const std::string left_paren_prefix
void AssignParenLabels(size_t total_nparens, std::vector< LabelPair > *parens)
std::pair< Label, StateId > LabelStatePair
bool SCCRightLinear(size_t scc_id) const
const Arc & Value() const
typename Arc::StateId StateId
PdtReplaceOptions(Label root, PdtParserType type=PdtParserType::LEFT, Label start_paren_labels=kNoLabel, std::string left_paren_prefix="(_", std::string right_paren_prefix=")_")
typename Arc::Label Label
void CreateFst(MutableFst< Arc > *ofst, std::vector< StateId > *open_dest, std::vector< std::vector< StateWeightPair >> *close_src)
const SymbolTable * OutputSymbols() const override=0
std::pair< StateId, Weight > StateWeightPair
constexpr uint8_t kReplaceSCCNonTrivial
const std::vector< size_t > & SCCComps(size_t scc_id) const
virtual void SetProperties(uint64_t props, uint64_t mask)=0
void GetParser(MutableFst< Arc > *ofst, std::vector< LabelPair > *parens) override
virtual void DeleteArcs(StateId, size_t)=0
virtual StateId Start() const =0
StateId GetState(const LabelStatePair &lsp) const
PdtParser(const std::vector< LabelFstPair > &fst_array, const PdtReplaceOptions< Arc > &opts)
std::pair< size_t, StateId > ParenKey
const std::vector< LabelFstPair > & FstArray() const
std::pair< Label, Label > LabelPair
void AddParensToSymbolTables(const std::vector< LabelPair > &parens, MutableFst< Arc > *ofst)
typename Arc::Label Label
virtual SymbolTable * MutableInputSymbols()=0
virtual void AddArc(StateId, const Arc &)=0
constexpr uint8_t kReplaceSCCRightLinear
virtual StateId AddState()=0
std::pair< Label, const Fst< Arc > * > LabelFstPair
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
void AddParensToFst(const std::vector< LabelPair > &parens, const ParenMap &paren_map, const std::vector< StateId > &open_dest, const std::vector< std::vector< StateWeightPair >> &close_src, const std::vector< bool > &close_non_term_weight, MutableFst< Arc > *ofst)
virtual void DeleteStates(const std::vector< StateId > &)=0
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
typename Arc::Weight Weight
virtual StateId NumStates() const =0
typename Arc::StateId StateId
LabelStatePair GetLabelStatePair(StateId os) const
PdtLeftParser(const std::vector< LabelFstPair > &fst_array, const PdtReplaceOptions< Arc > &opts)
bool AddAuxiliarySymbols(const std::string &prefix, int64_t start_label, int64_t nlabels, SymbolTable *syms)
size_t SCC(Label label) const