20 #ifndef FST_REPLACE_UTIL_H_ 21 #define FST_REPLACE_UTIL_H_ 40 #include <unordered_map> 41 #include <unordered_set> 72 int64_t return_label = 0)
74 call_label_type(call_label_type),
75 return_label_type(return_label_type),
76 return_label(return_label) {}
99 const std::vector<std::pair<
typename Arc::Label,
const Fst<Arc> *>> &,
115 using FstPair = std::pair<Label, const Fst<Arc> *>;
120 ReplaceUtil(
const std::vector<MutableFstPair> &fst_pairs,
133 for (
Label i = 0; i < fst_array_.size(); ++i)
delete fst_array_[i];
139 GetDependencies(
false);
146 GetDependencies(
false);
147 if (
const auto it = nonterminal_hash_.find(label);
148 it != nonterminal_hash_.end()) {
149 return depscc_[it->second];
161 return depsccprops_[scc_id];
167 GetDependencies(
false);
169 for (
Label i = 0; i < fst_array_.size(); ++i) {
170 if (!fst_array_[i])
continue;
171 if (fst_array_[i]->Properties(props,
true) != props || !depaccess_[i]) {
182 void ReplaceLabels(
const std::vector<Label> &labels);
187 void ReplaceBySize(
size_t nstates,
size_t narcs,
size_t nnonterms);
194 void ReplaceByInstances(
size_t ninstances);
201 void GetFstPairs(std::vector<FstPair> *fst_pairs);
204 void GetMutableFstPairs(std::vector<MutableFstPair> *mutable_fst_pairs);
208 struct ReplaceStats {
215 std::map<Label, size_t> inref;
217 std::map<Label, size_t> outref;
219 ReplaceStats() : nstates(0), nfinal(0), narcs(0), nnonterms(0), nref(0) {}
223 void CheckMutableFsts();
227 void GetDependencies(
bool stats)
const;
229 void ClearDependencies()
const {
230 depfst_.DeleteStates();
233 depsccprops_.clear();
238 bool GetTopOrder(
const Fst<Arc> &
fst, std::vector<Label> *toporder)
const;
241 void UpdateStats(
Label j);
245 void GetSCCProperties()
const;
251 int64_t return_label_;
252 std::vector<const Fst<Arc> *> fst_array_;
253 std::vector<MutableFst<Arc> *> mutable_fst_array_;
254 std::vector<Label> nonterminal_array_;
257 mutable std::vector<StateId> depscc_;
258 mutable std::vector<bool> depaccess_;
259 mutable uint64_t depprops_;
260 mutable bool have_stats_;
261 mutable std::vector<ReplaceStats> stats_;
262 mutable std::vector<uint8_t> depsccprops_;
270 : root_label_(opts.
root),
276 fst_array_.push_back(
nullptr);
277 mutable_fst_array_.push_back(
nullptr);
278 nonterminal_array_.push_back(
kNoLabel);
279 for (
const auto &fst_pair : fst_pairs) {
280 const auto label = fst_pair.first;
281 auto *
fst = fst_pair.second;
282 nonterminal_hash_[label] = fst_array_.size();
283 nonterminal_array_.push_back(label);
284 fst_array_.push_back(
fst);
285 mutable_fst_array_.push_back(
fst);
287 root_fst_ = nonterminal_hash_[root_label_];
289 FSTERROR() <<
"ReplaceUtil: No root FST for label: " << root_label_;
296 : root_label_(opts.root),
297 call_label_type_(opts.call_label_type),
298 return_label_type_(opts.return_label_type),
299 return_label_(opts.return_label),
302 fst_array_.push_back(
nullptr);
303 nonterminal_array_.push_back(
kNoLabel);
304 for (
const auto &fst_pair : fst_pairs) {
305 const auto label = fst_pair.first;
306 const auto *
fst = fst_pair.second;
307 nonterminal_hash_[label] = fst_array_.size();
308 nonterminal_array_.push_back(label);
309 fst_array_.push_back(
fst->Copy());
311 root_fst_ = nonterminal_hash_[root_label_];
313 FSTERROR() <<
"ReplaceUtil: No root FST for label: " << root_label_;
319 const std::vector<std::unique_ptr<
const Fst<Arc>>> &fst_array,
321 : root_fst_(opts.root),
322 call_label_type_(opts.call_label_type),
323 return_label_type_(opts.return_label_type),
324 return_label_(opts.return_label),
325 nonterminal_array_(fst_array.size()),
326 nonterminal_hash_(nonterminal_hash),
329 fst_array_.push_back(
nullptr);
330 for (
size_t i = 1; i < fst_array.size(); ++i) {
331 fst_array_.push_back(fst_array[i]->Copy());
333 for (
auto it = nonterminal_hash.begin(); it != nonterminal_hash.end(); ++it) {
334 nonterminal_array_[it->second] = it->first;
336 root_label_ = nonterminal_array_[root_fst_];
342 if (stats && !have_stats_) {
349 if (have_stats_) stats_.reserve(fst_array_.size());
350 for (
Label ilabel = 0; ilabel < fst_array_.size(); ++ilabel) {
353 if (have_stats_) stats_.push_back(ReplaceStats());
358 for (
Label ilabel = 0; ilabel < fst_array_.size(); ++ilabel) {
359 const auto *ifst = fst_array_[ilabel];
362 const auto s = siter.Value();
364 ++stats_[ilabel].nstates;
365 if (ifst->Final(s) != Weight::Zero()) ++stats_[ilabel].nfinal;
368 if (have_stats_) ++stats_[ilabel].narcs;
369 const auto &arc = aiter.Value();
370 if (
auto it = nonterminal_hash_.find(arc.olabel);
371 it != nonterminal_hash_.end()) {
372 const auto nextstate = it->second;
373 depfst_.
EmplaceArc(ilabel, arc.olabel, arc.olabel, nextstate);
375 ++stats_[ilabel].nnonterms;
376 ++stats_[nextstate].nref;
377 ++stats_[nextstate].inref[ilabel];
378 ++stats_[ilabel].outref[nextstate];
385 SccVisitor<Arc> scc_visitor(&depscc_, &depaccess_,
nullptr, &depprops_);
392 FSTERROR() <<
"ReplaceUtil::UpdateStats: Stats not available";
395 if (j == root_fst_)
return;
396 for (
auto in = stats_[j].inref.begin(); in != stats_[j].inref.end(); ++in) {
397 const auto i = in->first;
398 const auto ni = in->second;
399 stats_[i].nstates += stats_[j].nstates * ni;
400 stats_[i].narcs += (stats_[j].narcs + 1) * ni;
401 stats_[i].nnonterms += (stats_[j].nnonterms - 1) * ni;
402 stats_[i].outref.erase(j);
403 for (
auto out = stats_[j].outref.begin(); out != stats_[j].outref.end();
405 const auto k = out->first;
406 const auto nk = out->second;
407 stats_[i].outref[k] += ni * nk;
410 for (
auto out = stats_[j].outref.begin(); out != stats_[j].outref.end();
412 const auto k = out->first;
413 const auto nk = out->second;
414 stats_[k].nref -= nk;
415 stats_[k].inref.erase(j);
416 for (
auto in = stats_[j].inref.begin(); in != stats_[j].inref.end(); ++in) {
417 const auto i = in->first;
418 const auto ni = in->second;
419 stats_[k].inref[i] += ni * nk;
420 stats_[k].nref += ni * nk;
427 if (mutable_fst_array_.empty()) {
428 for (
Label i = 0; i < fst_array_.size(); ++i) {
429 if (!fst_array_[i]) {
430 mutable_fst_array_.push_back(
nullptr);
433 delete fst_array_[i];
434 fst_array_[i] = mutable_fst_array_[i];
444 for (
auto *mutable_fst : mutable_fst_array_) {
445 if (!mutable_fst)
continue;
446 if (mutable_fst->Properties(props,
false) != props) {
450 GetDependencies(
false);
451 for (
Label i = 0; i < mutable_fst_array_.size(); ++i) {
452 auto *
fst = mutable_fst_array_[i];
453 if (
fst && !depaccess_[i]) {
455 fst_array_[i] =
nullptr;
456 mutable_fst_array_[i] =
nullptr;
464 std::vector<Label> *toporder)
const {
466 std::vector<StateId> order;
467 bool acyclic =
false;
471 LOG(WARNING) <<
"ReplaceUtil::GetTopOrder: Cyclical label dependencies";
474 toporder->resize(order.size());
475 for (
Label i = 0; i < order.size(); ++i) (*toporder)[order[i]] = i;
482 std::unordered_set<Label> label_set;
483 for (
const auto label : labels) {
485 if (label != root_label_) label_set.insert(label);
488 GetDependencies(
false);
491 std::vector<Arc> arcs;
494 const auto &arc = aiter.Value();
495 const auto label = nonterminal_array_[arc.nextstate];
496 if (label_set.count(label) > 0) arcs.push_back(arc);
499 for (
auto &arc : arcs) pfst.
AddArc(i, std::move(arc));
501 std::vector<Label> toporder;
502 if (!GetTopOrder(pfst, &toporder)) {
508 for (
Label o = toporder.size() - 1; o >= 0; --o) {
509 std::vector<FstPair> fst_pairs;
510 auto s = toporder[o];
513 const auto &arc = aiter.Value();
514 const auto label = nonterminal_array_[arc.nextstate];
515 const auto *fst = fst_array_[arc.nextstate];
516 fst_pairs.emplace_back(label, fst);
518 if (fst_pairs.empty())
continue;
519 const auto label = nonterminal_array_[s];
520 const auto *fst = fst_array_[s];
521 fst_pairs.emplace_back(label, fst);
524 Replace(fst_pairs, mutable_fst_array_[s], opts);
532 std::vector<Label> labels;
533 GetDependencies(
true);
534 std::vector<Label> toporder;
535 if (!GetTopOrder(depfst_, &toporder)) {
539 for (
Label o = toporder.size() - 1; o >= 0; --o) {
540 const auto j = toporder[o];
541 if (stats_[j].nstates <= nstates && stats_[j].narcs <= narcs &&
542 stats_[j].nnonterms <= nnonterms) {
543 labels.push_back(nonterminal_array_[j]);
552 std::vector<Label> labels;
553 GetDependencies(
true);
554 std::vector<Label> toporder;
555 if (!GetTopOrder(depfst_, &toporder)) {
559 for (
Label o = 0; o < toporder.size(); ++o) {
560 const auto j = toporder[o];
561 if (stats_[j].nref <= ninstances) {
562 labels.push_back(nonterminal_array_[j]);
573 for (
Label i = 0; i < fst_array_.size(); ++i) {
574 const auto label = nonterminal_array_[i];
575 const auto *fst = fst_array_[i];
577 fst_pairs->emplace_back(label, fst);
583 std::vector<MutableFstPair> *mutable_fst_pairs) {
585 mutable_fst_pairs->clear();
586 for (
Label i = 0; i < mutable_fst_array_.size(); ++i) {
587 const auto label = nonterminal_array_[i];
588 const auto *fst = mutable_fst_array_[i];
590 mutable_fst_pairs->emplace_back(label, fst->
Copy());
596 if (!depsccprops_.empty())
return;
597 GetDependencies(
false);
598 if (depscc_.empty())
return;
599 for (
StateId scc = 0; scc < depscc_.size(); ++scc) {
600 depsccprops_.push_back(kReplaceSCCLeftLinear | kReplaceSCCRightLinear);
602 if (!(depprops_ &
kCyclic))
return;
604 for (
StateId scc = 0; scc < depscc_.size(); ++scc) {
607 const auto &arc = aiter.Value();
608 if (arc.nextstate == scc) {
613 std::vector<bool> depscc_visited(depscc_.size(),
false);
614 for (
Label i = 0; i < fst_array_.size(); ++i) {
615 const auto *fst = fst_array_[i];
617 const auto depscc = depscc_[i];
618 if (depscc_visited[depscc]) {
621 depscc_visited[depscc] =
true;
622 std::vector<StateId> fstscc;
627 const auto s = siter.Value();
629 const auto &arc = aiter.Value();
630 auto it = nonterminal_hash_.find(arc.olabel);
631 if (it == nonterminal_hash_.end() || depscc_[it->second] != depscc) {
634 const bool arc_in_cycle = fstscc[s] == fstscc[arc.nextstate];
636 if (s != fst->
Start() || arc_in_cycle) {
637 depsccprops_[depscc] &= ~kReplaceSCCLeftLinear;
640 if (fst->
Final(arc.nextstate) == Weight::Zero() || arc_in_cycle) {
641 depsccprops_[depscc] &= ~kReplaceSCCRightLinear;
650 #endif // FST_REPLACE_UTIL_H_ constexpr uint64_t kCyclic
uint8_t SCCProperties(StateId scc_id)
void ReplaceByInstances(size_t ninstances)
void DeleteArcs(StateId s, size_t n) override
void ReplaceBySize(size_t nstates, size_t narcs, size_t nnonterms)
void GetFstPairs(std::vector< FstPair > *fst_pairs)
constexpr uint64_t kCoAccessible
typename Arc::Weight Weight
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
void ReplaceLabels(const std::vector< Label > &labels)
virtual Weight Final(StateId) const =0
void Connect(MutableFst< Arc > *fst)
std::pair< Label, const Fst< Arc > * > FstPair
constexpr uint8_t kReplaceSCCLeftLinear
void Replace(const std::vector< std::pair< typename Arc::Label, const Fst< Arc > * >> &ifst_array, MutableFst< Arc > *ofst, std::vector< std::pair< typename Arc::Label, typename Arc::Label >> *parens, const PdtReplaceOptions< Arc > &opts)
constexpr uint8_t kReplaceSCCNonTrivial
void GetMutableFstPairs(std::vector< MutableFstPair > *mutable_fst_pairs)
virtual Fst * Copy(bool safe=false) const =0
StateId NumStates() const override
constexpr uint64_t kAccessible
virtual StateId Start() const =0
std::pair< Label, MutableFst< Arc > * > MutableFstPair
void EmplaceArc(StateId state, T &&...ctor_args)
StateId SCC(Label label) const
StateId AddState() override
ReplaceUtil(const std::vector< MutableFstPair > &fst_pairs, const ReplaceUtilOptions &opts)
void AddArc(StateId s, const Arc &arc) override
ReplaceLabelType return_label_type
void SetFinal(StateId s, Weight weight=Weight::One()) override
bool CyclicDependencies() const
constexpr uint8_t kReplaceSCCRightLinear
std::unordered_map< Label, Label > NonTerminalHash
typename Arc::StateId StateId
typename Arc::Label Label
ReplaceLabelType call_label_type
ReplaceUtilOptions(int64_t root=kNoLabel, ReplaceLabelType call_label_type=REPLACE_LABEL_INPUT, ReplaceLabelType return_label_type=REPLACE_LABEL_NEITHER, int64_t return_label=0)
void SetStart(StateId s) override
ReplaceUtilOptions(int64_t root, bool epsilon_replace_arc)