20 #ifndef FST_EXTENSIONS_PDT_REPLACE_H_ 21 #define FST_EXTENSIONS_PDT_REPLACE_H_ 28 #include <type_traits> 35 #include <unordered_map> 43 size_t operator()(
const std::pair<size_t, S> &paren)
const {
44 static constexpr
auto prime = 7853;
45 return paren.first + paren.second * prime;
83 using Label =
typename Arc::Label;
88 std::string left_paren_prefix =
"(_",
89 std::string right_paren_prefix =
")_")
92 start_paren_labels(start_paren_labels),
93 left_paren_prefix(std::move(left_paren_prefix)),
94 right_paren_prefix(std::move(right_paren_prefix)) {}
122 start_paren_labels_(opts.start_paren_labels),
123 left_paren_prefix_(std::move(opts.left_paren_prefix)),
124 right_paren_prefix_(std::move(opts.right_paren_prefix)),
126 for (
size_t i = 0; i < fst_array.size(); ++i) {
128 fst_array[i].second->InputSymbols())) {
129 FSTERROR() <<
"PdtParser: Input symbol table of input FST " << i
130 <<
" does not match input symbol table of 0th input FST";
134 fst_array[i].second->OutputSymbols())) {
135 FSTERROR() <<
"PdtParser: Output symbol table of input FST " << i
136 <<
" does not match output symbol table of 0th input FST";
139 fst_array_.emplace_back(fst_array[i].first, fst_array[i].second->Copy());
141 label2id_[fst_array[i].first] = i;
146 for (
auto &pair : fst_array_)
delete pair.second;
151 std::vector<LabelPair> *parens) = 0;
154 const std::vector<LabelFstPair> &
FstArray()
const {
return fst_array_; }
161 auto it = label2id_.find(l);
162 return it == label2id_.end() ?
kNoStateId : it->second;
168 if (os >= label_state_pairs_.size()) {
172 return label_state_pairs_[os];
179 auto it = state_map_.find(lsp);
180 if (it == state_map_.end()) {
191 std::vector<std::vector<StateWeightPair>> *close_src);
196 for (
size_t paren_id = 0; paren_id < total_nparens; ++paren_id) {
197 const auto open_paren = start_paren_labels_ + paren_id;
198 const auto close_paren = open_paren + total_nparens;
199 parens->emplace_back(open_paren, close_paren);
204 virtual size_t AssignParenIds(
const Fst<Arc> &ofst,
221 const std::vector<LabelPair> &parens,
const ParenMap &paren_map,
222 const std::vector<StateId> &open_dest,
223 const std::vector<std::vector<StateWeightPair>> &close_src,
224 const std::vector<bool> &close_non_term_weight,
MutableFst<Arc> *ofst);
227 void AddParensToSymbolTables(
const std::vector<LabelPair> &parens,
231 std::vector<LabelFstPair> fst_array_;
234 Label start_paren_labels_;
235 const std::string left_paren_prefix_;
236 const std::string right_paren_prefix_;
238 std::unordered_map<Label, StateId> label2id_;
240 std::vector<LabelStatePair> label_state_pairs_;
242 std::map<LabelStatePair, StateId> state_map_;
249 std::vector<std::vector<StateWeightPair>> *close_src) {
255 open_dest->resize(fst_array_.size(),
kNoStateId);
256 close_src->resize(fst_array_.size());
258 std::deque<Label> non_term_queue;
259 non_term_queue.push_back(root_);
261 std::vector<bool> enqueued(fst_array_.size(),
false);
262 enqueued[label2id_[root_]] =
true;
265 const auto label = non_term_queue.front();
266 non_term_queue.pop_front();
267 StateId fst_id = Label2Id(label);
268 const auto *ifst = fst_array_[fst_id].second;
270 const auto is = siter.Value();
273 label_state_pairs_.push_back(lsp);
274 state_map_[lsp] = os;
275 if (is == ifst->Start()) {
276 (*open_dest)[fst_id] = os;
277 if (label == root_) ofst->
SetStart(os);
279 if (ifst->Final(is) != Weight::Zero()) {
280 if (label == root_) ofst->
SetFinal(os, ifst->Final(is));
281 (*close_src)[fst_id].emplace_back(os, ifst->Final(is));
285 auto arc = aiter.Value();
286 arc.nextstate += soff;
287 if (max_label ==
kNoLabel || arc.olabel > max_label)
288 max_label = arc.olabel;
289 const auto nfst_id = Label2Id(arc.olabel);
291 if (fst_array_[nfst_id].second->Start() ==
kNoStateId)
continue;
292 if (!enqueued[nfst_id]) {
293 non_term_queue.push_back(arc.olabel);
294 enqueued[nfst_id] =
true;
301 if (start_paren_labels_ ==
kNoLabel) start_paren_labels_ = max_label + 1;
306 const std::vector<LabelPair> &parens,
const ParenMap &paren_map,
307 const std::vector<StateId> &open_dest,
308 const std::vector<std::vector<StateWeightPair>> &close_src,
309 const std::vector<bool> &close_non_term_weight,
MutableFst<Arc> *ofst) {
314 std::unique_ptr<MIter> aiter(
new MIter(ofst, os));
315 for (
auto n = 0; !aiter->Done(); aiter->Next(), ++n) {
316 const auto arc = aiter->Value();
317 StateId nfst_id = Label2Id(arc.olabel);
320 const ParenKey paren_key(nfst_id, arc.nextstate);
321 auto it = paren_map.find(paren_key);
322 Label open_paren = 0;
323 Label close_paren = 0;
324 if (it != paren_map.end()) {
325 const auto paren_id = it->second;
326 open_paren = parens[paren_id].first;
327 close_paren = parens[paren_id].second;
330 if (open_paren != 0 || !close_non_term_weight[nfst_id]) {
331 const auto open_weight =
332 close_non_term_weight[nfst_id] ? Weight::One() : arc.weight;
333 const Arc sarc(open_paren, open_paren, open_weight,
335 aiter->SetValue(sarc);
340 const Arc sarc(0, 0, Weight::One(), dead_state);
341 aiter->SetValue(sarc);
344 if (close_paren != 0 || close_non_term_weight[nfst_id]) {
345 for (
size_t i = 0; i < close_src[nfst_id].size(); ++i) {
346 const auto &pair = close_src[nfst_id][i];
347 const auto close_weight = close_non_term_weight[nfst_id]
348 ?
Times(arc.weight, pair.second)
350 const Arc farc(close_paren, close_paren, close_weight,
353 ofst->
AddArc(pair.first, farc);
354 if (os == pair.first) {
355 aiter.reset(
new MIter(ofst, os));
368 auto size = parens.size();
425 std::vector<LabelPair> *parens)
override;
430 size_t AssignParenIds(
const Fst<Arc> &ofst,
431 ParenMap *paren_map)
const override;
436 std::vector<LabelPair> *parens) {
439 const auto &fst_array = FstArray();
445 std::vector<StateId> open_dest(fst_array.size(),
kNoStateId);
449 std::vector<std::vector<StateWeightPair>> close_src(fst_array.size());
453 std::vector<bool> close_non_term_weight(fst_array.size(),
false);
454 CreateFst(ofst, &open_dest, &close_src);
455 auto total_nparens = AssignParenIds(*ofst, &paren_map);
456 AssignParenLabels(total_nparens, parens);
457 AddParensToFst(*parens, paren_map, open_dest, close_src,
458 close_non_term_weight, ofst);
459 if (!fst_array.empty()) {
463 AddParensToSymbolTables(*parens, ofst);
470 std::vector<size_t> nparens(FstArray().size(), 0);
472 size_t total_nparens = 0;
474 const auto os = siter.Value();
476 const auto &arc = aiter.Value();
477 const auto nfst_id = Label2Id(arc.olabel);
479 const ParenKey paren_key(nfst_id, arc.nextstate);
480 auto it = paren_map->find(paren_key);
481 if (it == paren_map->end()) {
483 (*paren_map)[paren_key] = nparens[nfst_id]++;
484 if (nparens[nfst_id] > total_nparens)
485 total_nparens = nparens[nfst_id];
490 return total_nparens;
528 std::vector<LabelPair> *parens)
override;
536 size_t AssignParenIds(
const Fst<Arc> &ofst,
537 ParenMap *paren_map)
const override;
540 size_t SCC(
Label label)
const {
return replace_util_.SCC(label); }
545 const auto scc_props = replace_util_.SCCProperties(scc_id);
546 return (scc_props & ll_props) == ll_props;
552 const auto scc_props = replace_util_.SCCProperties(scc_id);
553 return (scc_props & lr_props) == lr_props;
557 const std::vector<size_t> &
SCCComps(
size_t scc_id)
const {
558 if (scc_comps_.empty()) GetSCCComps();
559 return scc_comps_[scc_id];
566 if (SCCComps(scc_id).empty())
return kNoStateId;
567 const auto fst_id = SCCComps(scc_id).front();
568 const auto &fst_array = FstArray();
569 const auto label = fst_array[fst_id].first;
570 const auto *ifst = fst_array[fst_id].second;
571 if (SCCLeftLinear(scc_id)) {
573 return GetState(lsp);
576 return GetState(lsp);
585 std::vector<std::vector<StateWeightPair>> *close_src,
586 std::vector<bool> *close_non_term_weight)
const;
589 void GetSCCComps()
const {
590 const std::vector<LabelFstPair> &fst_array = FstArray();
591 for (
size_t i = 0; i < fst_array.size(); ++i) {
592 const auto label = fst_array[i].first;
593 const auto scc_id = SCC(label);
594 if (scc_comps_.size() <= scc_id) scc_comps_.resize(scc_id + 1);
595 if (SCCLeftLinear(scc_id) || SCCRightLinear(scc_id)) {
596 scc_comps_[scc_id].push_back(i);
601 const std::set<StateId> &NonTermDests(
StateId fst_id)
const {
602 if (non_term_dests_.empty()) GetNonTermDests();
603 return non_term_dests_[fst_id];
608 void GetNonTermDests()
const;
613 mutable std::vector<std::vector<size_t>> scc_comps_;
615 mutable std::vector<std::set<StateId>> non_term_dests_;
620 std::vector<LabelPair> *parens) {
623 const auto &fst_array = FstArray();
628 std::vector<StateId> open_dest(fst_array.size(),
kNoStateId);
632 std::vector<std::vector<StateWeightPair>> close_src(fst_array.size());
636 std::vector<bool> close_non_term_weight(fst_array.size(),
false);
637 CreateFst(ofst, &open_dest, &close_src);
638 ProcSCCs(ofst, &open_dest, &close_src, &close_non_term_weight);
639 const auto total_nparens = AssignParenIds(*ofst, &paren_map);
640 AssignParenLabels(total_nparens, parens);
641 AddParensToFst(*parens, paren_map, open_dest, close_src,
642 close_non_term_weight, ofst);
643 if (!fst_array.empty()) {
647 AddParensToSymbolTables(*parens, ofst);
654 std::vector<std::vector<StateWeightPair>> *close_src,
655 std::vector<bool> *close_non_term_weight)
const {
656 const auto &fst_array = FstArray();
658 const auto os = siter.Value();
659 const auto label = GetLabelStatePair(os).first;
660 const auto is = GetLabelStatePair(os).second;
661 const auto fst_id = Label2Id(label);
662 const auto scc_id = SCC(label);
663 const auto rs = RepState(scc_id);
664 const auto *ifst = fst_array[fst_id].second;
667 if (SCCLeftLinear(scc_id)) {
668 (*close_non_term_weight)[fst_id] =
true;
669 if (is == ifst->Start() && os != rs) {
672 const auto &arc = aiter.Value();
677 (*open_dest)[fst_id] = rs;
683 if (SCCRightLinear(scc_id)) {
686 auto arc = aiter.Value();
687 const auto idest = GetLabelStatePair(arc.nextstate).second;
688 if (NonTermDests(fst_id).count(idest) > 0) {
689 if (ofst->
Final(arc.nextstate) != Weight::Zero()) {
690 ofst->
SetFinal(arc.nextstate, Weight::Zero());
693 arc.weight =
Times(arc.weight, ifst->Final(idest));
698 const auto final_weight = ifst->Final(is);
699 if (final_weight != Weight::Zero() &&
700 NonTermDests(fst_id).count(is) == 0) {
701 ofst->
AddArc(os, Arc(0, 0, final_weight, rs));
702 if (ofst->
Final(os) != Weight::Zero()) {
707 if (is == ifst->Start()) {
708 (*close_src)[fst_id].clear();
709 (*close_src)[fst_id].emplace_back(rs, Weight::One());
717 const auto &fst_array = FstArray();
718 non_term_dests_.resize(fst_array.size());
719 for (
size_t fst_id = 0; fst_id < fst_array.size(); ++fst_id) {
720 const auto label = fst_array[fst_id].first;
721 const auto scc_id = SCC(label);
722 if (SCCRightLinear(scc_id)) {
723 const auto *ifst = fst_array[fst_id].second;
725 const auto is = siter.Value();
728 const auto &arc = aiter.Value();
730 non_term_dests_[fst_id].insert(arc.nextstate);
741 const auto &fst_array = FstArray();
743 std::vector<size_t> nparens(fst_array.size(), 0);
745 size_t total_nparens = 0;
747 const auto os = siter.Value();
748 const auto label = GetLabelStatePair(os).first;
749 const auto scc_id = SCC(label);
751 const auto &arc = aiter.Value();
752 const auto nfst_id = Label2Id(arc.olabel);
754 size_t nscc_id = SCC(arc.olabel);
755 bool nscc_linear = !SCCComps(nscc_id).empty();
759 if (!nscc_linear || scc_id != nscc_id) {
764 nscc_linear ? SCCComps(nscc_id).front() : nfst_id;
765 ParenKey paren_key(pfst_id, arc.nextstate);
766 const auto it = paren_map->find(paren_key);
767 if (it == paren_map->end()) {
772 const ParenKey nparen_key(nfst_id, arc.nextstate);
773 (*paren_map)[nparen_key] = nparens[pfst_id];
775 (*paren_map)[paren_key] = nparens[pfst_id]++;
776 if (nparens[pfst_id] > total_nparens) {
777 total_nparens = nparens[pfst_id];
784 return total_nparens;
794 const std::vector<std::pair<
typename Arc::Label,
const Fst<Arc> *>>
797 std::vector<std::pair<typename Arc::Label, typename Arc::Label>> *parens,
811 FSTERROR() <<
"Replace: Unknown PDT parser type: " 812 <<
static_cast<std::underlying_type_t<PdtParserType>
>(
824 const std::vector<std::pair<
typename Arc::Label,
const Fst<Arc> *>>
827 std::vector<std::pair<typename Arc::Label, typename Arc::Label>> *parens,
828 typename Arc::Label root) {
830 Replace(ifst_array, ofst, parens, opts);
835 #endif // FST_EXTENSIONS_PDT_REPLACE_H_
const std::string right_paren_prefix
void GetParser(MutableFst< Arc > *ofst, std::vector< LabelPair > *parens) override
std::unordered_map< ParenKey, size_t, internal::ReplaceParenHash< StateId >> ParenMap
virtual SymbolTable * MutableOutputSymbols()=0
PdtLeftSRParser(const std::vector< LabelFstPair > &fst_array, const PdtReplaceOptions< Arc > &opts)
size_t AssignParenIds(const Fst< Arc > &ofst, ParenMap *paren_map) const override
StateId RepState(size_t scc_id) const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
const SymbolTable * InputSymbols() const override=0
constexpr uint64_t kError
virtual void SetInputSymbols(const SymbolTable *isyms)=0
StateId Label2Id(Label l) const
size_t AssignParenIds(const Fst< Arc > &ofst, ParenMap *paren_map) const override
virtual Weight Final(StateId) const =0
virtual void SetStart(StateId)=0
bool SCCLeftLinear(size_t scc_id) const
void Connect(MutableFst< Arc > *fst)
constexpr uint8_t kReplaceSCCLeftLinear
void Replace(const std::vector< std::pair< typename Arc::Label, const Fst< Arc > * >> &ifst_array, MutableFst< Arc > *ofst, std::vector< std::pair< typename Arc::Label, typename Arc::Label >> *parens, const PdtReplaceOptions< Arc > &opts)
size_t operator()(const std::pair< size_t, S > &paren) const
const std::string left_paren_prefix
void AssignParenLabels(size_t total_nparens, std::vector< LabelPair > *parens)
std::pair< Label, StateId > LabelStatePair
bool SCCRightLinear(size_t scc_id) const
const Arc & Value() const
typename Arc::StateId StateId
PdtReplaceOptions(Label root, PdtParserType type=PdtParserType::LEFT, Label start_paren_labels=kNoLabel, std::string left_paren_prefix="(_", std::string right_paren_prefix=")_")
typename Arc::Label Label
void CreateFst(MutableFst< Arc > *ofst, std::vector< StateId > *open_dest, std::vector< std::vector< StateWeightPair >> *close_src)
const SymbolTable * OutputSymbols() const override=0
std::pair< StateId, Weight > StateWeightPair
constexpr uint8_t kReplaceSCCNonTrivial
const std::vector< size_t > & SCCComps(size_t scc_id) const
virtual void SetProperties(uint64_t props, uint64_t mask)=0
void GetParser(MutableFst< Arc > *ofst, std::vector< LabelPair > *parens) override
virtual void DeleteArcs(StateId, size_t)=0
virtual StateId Start() const =0
StateId GetState(const LabelStatePair &lsp) const
PdtParser(const std::vector< LabelFstPair > &fst_array, const PdtReplaceOptions< Arc > &opts)
std::pair< size_t, StateId > ParenKey
const std::vector< LabelFstPair > & FstArray() const
std::pair< Label, Label > LabelPair
void AddParensToSymbolTables(const std::vector< LabelPair > &parens, MutableFst< Arc > *ofst)
typename Arc::Label Label
virtual SymbolTable * MutableInputSymbols()=0
virtual void AddArc(StateId, const Arc &)=0
constexpr uint8_t kReplaceSCCRightLinear
virtual StateId AddState()=0
std::pair< Label, const Fst< Arc > * > LabelFstPair
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
void AddParensToFst(const std::vector< LabelPair > &parens, const ParenMap &paren_map, const std::vector< StateId > &open_dest, const std::vector< std::vector< StateWeightPair >> &close_src, const std::vector< bool > &close_non_term_weight, MutableFst< Arc > *ofst)
virtual void DeleteStates(const std::vector< StateId > &)=0
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
typename Arc::Weight Weight
virtual StateId NumStates() const =0
typename Arc::StateId StateId
LabelStatePair GetLabelStatePair(StateId os) const
PdtLeftParser(const std::vector< LabelFstPair > &fst_array, const PdtReplaceOptions< Arc > &opts)
bool AddAuxiliarySymbols(const std::string &prefix, int64_t start_label, int64_t nlabels, SymbolTable *syms)
size_t SCC(Label label) const