21 #ifndef FST_LABEL_REACHABLE_H_ 22 #define FST_LABEL_REACHABLE_H_ 24 #include <sys/types.h> 43 #include <unordered_map> 48 template <
typename Label>
// Constructor initializer list (fragment): stores the reach_input and
// keep_relabel_data arguments and marks relabel data as present. The list
// continues on lines that are not part of this listing.
55 : reach_input_(reach_input),
56 keep_relabel_data_(keep_relabel_data),
57 have_relabel_data_(true),
// Returns the address of the interval_sets_ member.
65 return &interval_sets_;
// Returns the interval set at index s using std::vector::operator[]; no range
// check is performed on this line.
69 return interval_sets_[s];
// Guard: reports an error when have_relabel_data_ is false. What the function
// returns after reporting is on lines not shown in this listing.
75 if (!have_relabel_data_) {
76 FSTERROR() <<
"LabelReachableData: No relabeling data";
// Same guard as in the preceding accessor: reports an error when
// have_relabel_data_ is false. The rest of the body is not shown here.
82 if (!have_relabel_data_) {
83 FSTERROR() <<
"LabelReachableData: No relabeling data";
// Deserializes the fields from istrm in a fixed order: reach_input_,
// keep_relabel_data_, (conditionally) label2index_, final_label_,
// interval_sets_.
// NOTE(review): no stream-state check is visible in this fragment; confirm
// whether read failures are detected on the lines not shown.
96 ReadType(istrm, &data->reach_input_);
97 ReadType(istrm, &data->keep_relabel_data_);
// Relabel data counts as available exactly when the flag just read is set.
98 data->have_relabel_data_ = data->keep_relabel_data_;
// label2index_ is only read when keep_relabel_data_ is set.
99 if (data->keep_relabel_data_)
ReadType(istrm, &data->label2index_);
100 ReadType(istrm, &data->final_label_);
101 ReadType(istrm, &data->interval_sets_);
// release() hands the object back as a raw pointer; `data` is presumably a
// std::unique_ptr declared on a line not shown, so the caller takes ownership.
102 return data.release();
// label2index_ is serialized only when keep_relabel_data_ is set, which
// matches the conditional ReadType of label2index_ on the reading side.
108 if (keep_relabel_data_)
WriteType(ostrm, label2index_);
// Controls whether label2index_ is written to and read from streams.
118 bool keep_relabel_data_;
// False when label2index_ is unavailable; the accessors report an error then.
119 bool have_relabel_data_;
// Label -> index table used for relabeling.
121 std::unordered_map<Label, Label> label2index_;
// Interval sets, addressed by index via interval_sets_[s].
122 std::vector<LabelIntervalSet> interval_sets_;
// Permutes *interval_sets according to `order`: the set at position s ends up
// at position order[s]. The first line of the signature and the return
// statements are on lines not shown in this listing.
129 template <
typename Label,
typename StateId>
131 const std::vector<StateId> &order) {
// The order vector must have exactly one entry per interval set.
132 if (order.size() != interval_sets->size()) {
133 FSTERROR() <<
"StateSort: Bad order vector size: " << order.size()
134 <<
", expected: " << interval_sets->size();
// Scratch vector of the same size that receives the permuted sets.
137 std::vector<IntervalSet<Label>> reordered_interval_sets(
138 interval_sets->size());
// NOTE(review): s has type StateId while order.size() is unsigned; if StateId
// is a signed type this is a signed/unsigned comparison.
141 for (StateId s = 0; s < order.size(); ++s) {
// Sets are moved rather than copied. NOTE(review): order[s] is used as an
// index without a range check; confirm callers pass a valid permutation.
142 reordered_interval_sets[order[s]] = std::move((*interval_sets)[s]);
// Replace the caller's vector with the permuted one.
144 *interval_sets = std::move(reordered_interval_sets);
// A second function template over <Label, StateId> that also takes an `order`
// vector. Its name, first parameter and body fall on lines that are missing
// from this listing.
149 template <
typename Label,
typename StateId>
151 const std::vector<StateId> &order) {
// Remembers which arc label (input when true, output when false) the search
// in operator() compares against.
171 reach_input_ = reach_input;
// Search over arc positions between aiter_begin and aiter_end for
// match_label. The low/high/mid variables indicate a binary search; the loop
// header, iterator positioning and branch bodies are on lines not shown.
181 template <
class ArcIterator>
183 Label match_label)
const {
188 ssize_t low = aiter_begin;
189 ssize_t high = aiter_end;
// Midpoint written as low + (high - low) / 2 rather than (low + high) / 2.
191 const ssize_t mid = low + (high - low) / 2;
// Compare on the input or the output label depending on reach_input_.
193 auto label = reach_input_ ? aiter->
Value().ilabel : aiter->
Value().olabel;
194 if (label < match_label) {
// Selects ilabel (true) or olabel (false) in operator(); defaults to false.
206 bool reach_input_ =
false;
// Template parameters:
//   Arc         - arc type.
//   Accumulator - defaults to DefaultAccumulator<Arc>.
//   D           - data type, defaults to LabelReachableData<Arc::Label>.
//   LB          - lower-bound functor, defaults to LabelLowerBound<Arc>.
237 template <
class Arc,
class Accumulator = DefaultAccumulator<Arc>,
238 class D = LabelReachableData<
typename Arc::Label>,
239 class LB = LabelLowerBound<Arc>>
// Constructor from an FST (fragment; the leading parameters are on lines not
// shown). Both trailing parameters are optional.
253 std::unique_ptr<Accumulator> accumulator =
nullptr,
254 bool keep_relabel_data =
true)
// The object keeps its own VectorFst built from the fst argument.
255 : fst_(std::make_unique<
VectorFst<Arc>>(fst)),
// A new shared Data object is created from reach_input / keep_relabel_data.
257 data_(std::make_shared<
Data>(reach_input, keep_relabel_data)),
// Use the supplied accumulator, or default-construct one when it is null.
258 accumulator_(accumulator ? std::move(accumulator)
259 : std::make_unique<Accumulator>()) {
// State count of fst_ at this point, captured as `ins`.
260 const auto ins = fst_->NumStates();
// Constructor from existing data (fragment): the data argument is moved into
// data_, and the accumulator is the supplied one or a default-constructed one
// when the argument is null.
267 std::unique_ptr<Accumulator> accumulator =
nullptr)
269 data_(std::move(data)),
270 accumulator_(accumulator ? std::move(accumulator)
271 : std::make_unique<Accumulator>()) {}
// Copy construction (fragment): data_ is copied as a shared_ptr, so the Data
// object is shared with the source; the accumulator is a new object built
// from the source's accumulator and the `safe` flag; lower_bound_,
// reach_fst_input_ and error_ are copied by value.
275 data_(reachable.data_),
277 std::make_unique<Accumulator>(*reachable.accumulator_, safe)),
278 lower_bound_(reachable.lower_bound_),
279 reach_fst_input_(reachable.reach_fst_input_),
280 error_(reachable.error_) {}
// Verbose logging of usage statistics: the call count and the average number
// of intervals per call.
284 VLOG(2) <<
"# of calls: " << ncalls_;
// NOTE(review): this divides by ncalls_; any guard against ncalls_ == 0 would
// be on a line not shown in this listing - confirm.
285 VLOG(2) <<
"# of intervals/call: " << (nintervals_ / ncalls_);
// Maps a single label to its relabeled value. Label 0, or any label while
// error_ is set, is returned unchanged.
291 if (label == 0 || error_)
return label;
// NOTE(review): the pointer from Label2Index() is dereferenced without a
// check; confirm it cannot be null when relabel data is missing.
292 const auto &label2index = *data_->Label2Index();
// Look the label up in the shared label -> index table first.
293 if (
auto iter = label2index.find(label); iter != label2index.end()) {
// Presumably reached only when the lookup above missed (the hit branch is on
// lines not shown): fetch or create the entry in the out-of-vocabulary map.
296 auto &relabel = oov_label2index_[label];
// The assigned value lies past the table size plus the OOV count (which
// already includes the entry created above).
299 relabel = label2index.size() + oov_label2index_.size() + 1;
// Arc loop (fragment): each arc value is copied, and its ilabel and/or olabel
// is passed through Relabel(Label). The conditions that choose which label is
// rewritten, and the write-back of the arc, are on lines not shown.
309 !aiter.Done(); aiter.Next()) {
310 auto arc = aiter.Value();
312 arc.ilabel =
Relabel(arc.ilabel);
314 arc.olabel =
Relabel(arc.olabel);
// Appends (label, new label) pairs to *pairs (fragment; the first parameter
// is on a line not shown). avoid_collisions defaults to false.
332 bool avoid_collisions =
false) {
// NOTE(review): the pointer from Label2Index() is dereferenced without a
// check; confirm it cannot be null here.
334 const auto &label2index = *data_->Label2Index();
// Emit every table entry except those whose value equals the final label.
336 for (
const auto &kv : label2index) {
337 if (kv.second != data_->FinalLabel()) {
338 pairs->emplace_back(kv);
// Append all out-of-vocabulary entries as well.
342 pairs->insert(pairs->end(), oov_label2index_.begin(),
343 oov_label2index_.end());
344 if (avoid_collisions) {
// Examine each candidate label i in 1..label2index.size().
347 for (
size_t i = 1; i <= label2index.size(); ++i) {
348 const auto it = label2index.find(i);
349 bool unmapped = it == label2index.end();
// Refines `unmapped` using the OOV map; the condition guarding this
// assignment is on a line not shown.
352 unmapped = oov_label2index_.find(i) == oov_label2index_.end();
// Labels with no mapping, or mapped to the final label, are sent to
// label2index.size() + 1.
354 if (unmapped || it->second == data_->FinalLabel()) {
355 pairs->emplace_back(i, label2index.size() + 1);
// Forwards aiter_s to the accumulator and to the lower-bound helper. An
// accumulator error sets error_.
366 accumulator_->SetState(aiter_s);
367 if (accumulator_->Error()) error_ =
true;
368 lower_bound_.SetState(aiter_s);
// Label 0 is never reported as reachable, and nothing is once error_ is set.
375 if (label == 0 || error_)
return false;
// Otherwise: membership test in the interval set of the current state s_.
376 return data_->GetIntervalSet(s_).Member(label);
// Returns false once error_ is set; otherwise tests whether the final label
// is a member of the interval set of the current state s_.
381 if (error_)
return false;
382 return data_->GetIntervalSet(s_).Member(data_->FinalLabel());
// Records which side of the FST (input or output labels) is examined.
392 reach_fst_input_ = reach_input;
// Error report for an unsorted FST; the sortedness test that guards it is on
// lines not shown in this listing.
395 FSTERROR() <<
"LabelReachable::ReachInit: Fst is not sorted";
// Initialize the accumulator and the lower-bound helper for this FST; an
// accumulator error sets error_.
398 accumulator_->Init(fst, copy);
399 if (accumulator_->Error()) error_ =
true;
400 lower_bound_.Init(fst, reach_input, copy);
// Checks the arcs at positions [aiter_begin, aiter_end) of *aiter against the
// interval set of the current state s_. Returns true when at least one arc was
// found reachable (reach_begin_ >= 0), and false immediately if error_ is set.
// When compute_weight is true, arc weights are summed into reach_weight_ via
// the accumulator. Several lines of the original body (resetting of
// reach_begin_/reach_end_, iterator flag handling, the definition of
// reach_label, closing braces) are not part of this listing.
407 template <
class Iterator>
408 bool Reach(Iterator *aiter, ssize_t aiter_begin, ssize_t aiter_end,
409 bool compute_weight) {
410 if (error_)
return false;
411 const auto &interval_set = data_->GetIntervalSet(s_);
// Statistics: running total of interval counts.
413 nintervals_ += interval_set.Size();
// Start the weight sum from the semiring zero.
416 reach_weight_ = Weight::Zero();
// Iterator flags captured as `flags`; how they are used is on lines not shown.
417 const auto flags = aiter->Flags();
419 aiter->Seek(aiter_begin);
// Strategy choice: when twice the number of arcs is smaller than the number
// of intervals, walk the arcs one by one and test each label.
420 if (2 * (aiter_end - aiter_begin) < interval_set.Size()) {
427 for (
auto aiter_pos = aiter_begin; aiter_pos < aiter_end;
428 aiter->Next(), ++aiter_pos) {
429 const auto &arc = aiter->Value();
430 const auto label = reach_fst_input_ ? arc.ilabel : arc.olabel;
// reach_label is defined on a line not shown; a label equal to it is
// accepted without consulting the interval set.
431 if (label == reach_label || Reach(label)) {
// Track the first reachable position and one past the last.
433 if (reach_begin_ < 0) reach_begin_ = aiter_pos;
434 reach_end_ = aiter_pos + 1;
435 if (compute_weight) {
// Two accumulation sites follow (via arcb and via arc); the condition that
// selects between them is on lines not shown.
441 const auto &arcb = aiter->Value();
443 reach_weight_ = accumulator_->Sum(reach_weight_, arcb.weight);
450 reach_weight_ = accumulator_->Sum(reach_weight_, arc.weight);
// Second strategy (presumably the else branch): for each interval, locate
// the arc sub-range [begin_low, end_low) with the lower-bound helper. Each
// search starts where the previous one ended.
457 auto begin_low = aiter_begin;
458 auto end_low = aiter_begin;
459 for (
const auto &interval : interval_set) {
460 begin_low = lower_bound_(aiter, end_low, aiter_end, interval.begin);
461 end_low = lower_bound_(aiter, begin_low, aiter_end, interval.end);
// A non-empty sub-range means some arcs fall inside this interval.
462 if (end_low - begin_low > 0) {
463 if (reach_begin_ < 0) reach_begin_ = begin_low;
464 reach_end_ = end_low;
// With compute_weight set, the whole sub-range is passed to the
// accumulator's range Sum.
465 if (compute_weight) {
468 accumulator_->Sum(reach_weight_, aiter, begin_low, end_low);
474 return reach_begin_ >= 0;
// Returns the label -> index table by dereferencing the pointer obtained from
// data_. NOTE(review): no null check here; confirm Label2Index() on the data
// object cannot return null on its error path.
491 return *data_->Label2Index();
// True when this object's own error_ flag is set or when the accumulator
// reports an error; the accumulator is only queried if error_ is false.
498 bool Error()
const {
return error_ || accumulator_->Error(); }
// Rewrites fst_ in place: arcs are redirected to one state per label, final
// states are handled through a kNoLabel entry, and a new start state is
// connected to every state with no incoming arcs. Several lines of the
// original body (declaration of `ons`, the arc iterator construction, the
// bodies of the insert branches, arc write-back, closing braces) are not part
// of this listing.
506 void TransformFst() {
// Number of states before any states are added.
507 auto ins = fst_->NumStates();
// In-degree counter, one zero-initialized slot per original state.
509 std::vector<ssize_t> indeg(ins, 0);
511 for (
StateId s = 0; s < ins; ++s) {
513 !aiter.Done(); aiter.Next()) {
514 auto arc = aiter.Value();
// Use the input or the output label, as configured in data_.
515 const auto label = data_->ReachInput() ? arc.ilabel : arc.olabel;
// emplace only inserts when the label is new; insert_result.second is
// true in that case, and the label is then associated with state `ons`.
517 if (
auto insert_result = label2state_.emplace(label, ons);
518 insert_result.second) {
// Redirect the arc to the state associated with its label.
522 arc.nextstate = label2state_[label];
// Count the incoming arc at its destination state.
525 ++indeg[arc.nextstate];
// Final states: a non-Zero final weight is handled through the kNoLabel
// entry of label2state_, and the state's own final weight is cleared.
528 auto final_weight = fst_->Final(s);
529 if (final_weight != Weight::Zero()) {
530 if (
auto insert_result = label2state_.emplace(
kNoLabel, ons);
531 insert_result.second) {
535 const auto nextstate = label2state_[
kNoLabel];
539 fst_->SetFinal(s, Weight::Zero());
// Loop until fst_ has `ons` states; the loop body is on lines not shown.
543 while (fst_->NumStates() < ons) {
// Add a new start state and give it an arc with both labels 0 to every
// earlier state whose in-degree is zero.
548 const auto start = fst_->AddState();
549 fst_->SetStart(start);
550 for (
StateId s = 0; s < start; ++s) {
551 if (indeg[s] == 0) fst_->EmplaceArc(start, 0, 0, s);
// Fills the shared data object: sizes the interval sets to `ins` entries,
// builds the label -> index table from label2state_, and logs statistics.
// The construction of state_reachable, the definition of state2index, the
// error-branch body and closing braces are on lines not shown in this listing.
555 void FindIntervals(
StateId ins) {
557 if (state_reachable.
Error()) {
562 auto &interval_sets = *data_->MutableIntervalSets();
// Keep exactly `ins` interval sets.
564 interval_sets.resize(ins);
565 auto &label2index = *data_->MutableLabel2Index();
// For every (label, state) pair, the label's index is the index recorded for
// that state; the kNoLabel entry also defines the final label.
566 for (
const auto &kv : label2state_) {
567 Label i = state2index[kv.second];
568 label2index[kv.first] = i;
569 if (kv.first ==
kNoLabel) data_->SetFinalLabel(i);
// The label -> state map is emptied once the table has been built.
571 label2state_.clear();
// Statistics for verbose logging only.
572 double nintervals = 0;
573 ssize_t non_intervals = 0;
574 for (
StateId s = 0; s < ins; ++s) {
575 nintervals += interval_sets[s].Size();
// States with more than one interval are logged individually at level 3.
576 if (interval_sets[s].Size() > 1) {
578 VLOG(3) <<
"state: " << s
579 <<
" # of intervals: " << interval_sets[s].Size();
582 VLOG(2) <<
"# of states: " << ins;
583 VLOG(2) <<
"# of intervals: " << nintervals;
// NOTE(review): nintervals is a double, so this is a floating-point division
// even when ins is 0.
584 VLOG(2) <<
"# of intervals/state: " << nintervals / ins;
585 VLOG(2) <<
"# of non-interval states: " << non_intervals;
// Owned VectorFst built from the constructor's fst argument and rewritten by
// TransformFst.
588 std::unique_ptr<VectorFst<Arc>> fst_;
// Label -> state map built by TransformFst and cleared by FindIntervals.
592 std::unordered_map<Label, StateId> label2state_;
// First reachable arc position found by Reach; negative means none found.
594 ssize_t reach_begin_;
// Reachability data; shared between copies of this object.
600 std::shared_ptr<Data> data_;
// Owned accumulator used to sum arc weights.
602 std::unique_ptr<Accumulator> accumulator_;
// Relabeling entries for labels that are absent from the shared table.
606 std::unordered_map<Label, Label> oov_label2index_;
// Running total of interval counts, used for statistics logging.
608 double nintervals_ = 0;
// Selects ilabel (true) or olabel (false) when scanning arcs; defaults to
// false.
609 bool reach_fst_input_ =
false;
615 #endif // FST_LABEL_REACHABLE_H_ void Relabel(MutableFst< Arc > *fst, bool relabel_input)
typename Data::LabelIntervalSet LabelIntervalSet
constexpr uint8_t kArcValueFlags
typename Arc::Weight Weight
const std::unordered_map< Label, Label > * Label2Index() const
const std::vector< ISet > & IntervalSets()
constexpr uint8_t kArcNoCache
std::shared_ptr< Data > GetSharedData() const
std::vector< Index > & State2Index()
LabelReachable(const LabelReachable &reachable, bool safe=false)
void SetFinalLabel(Label final_label)
typename Arc::Label Label
constexpr uint8_t kArcFlags
virtual void SetInputSymbols(const SymbolTable *isyms)=0
bool Reach(Iterator *aiter, ssize_t aiter_begin, ssize_t aiter_end, bool compute_weight)
constexpr uint8_t kArcILabelValue
std::unordered_map< Label, Label > * MutableLabel2Index()
std::unique_ptr< T > WrapUnique(T *ptr)
void SetState(StateId aiter_s)
Label Relabel(Label label)
const Arc & Value() const
LabelReachableData(bool reach_input, bool keep_relabel_data=true)
void Relabel(MutableFst< Arc > *fst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &ipairs, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &opairs)
~LabelReachableData()=default
void RelabelPairs(std::vector< std::pair< Label, Label >> *pairs, bool avoid_collisions=false)
std::ostream & WriteType(std::ostream &strm, const T t)
void SetFlags(uint8_t flags, uint8_t mask)
int NumIntervalSets() const
static LabelReachableData * Read(std::istream &istrm, const FstReadOptions &opts)
void ArcSort(MutableFst< Arc > *fst, Compare comp)
constexpr uint64_t kOLabelSorted
typename Arc::StateId StateId
LabelReachable(const Fst< Arc > &fst, bool reach_input, std::unique_ptr< Accumulator > accumulator=nullptr, bool keep_relabel_data=true)
constexpr uint8_t kArcOLabelValue
ssize_t operator()(ArcIterator *aiter, ssize_t aiter_begin, ssize_t aiter_end, Label match_label) const
bool Reach(Label label) const
typename Arc::Label Label
typename LabelIntervalSet::Interval Interval
const Data * GetData() const
Weight ReachWeight() const
constexpr uint8_t kArcWeightValue
constexpr uint64_t kILabelSorted
LabelReachable(std::shared_ptr< Data > data, std::unique_ptr< Accumulator > accumulator=nullptr)
ssize_t ReachBegin() const
const std::unordered_map< Label, Label > & Label2Index() const
std::istream & ReadType(std::istream &strm, T *t)
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
IntInterval< T > Interval
void SetState(StateId s, StateId aiter_s=kNoStateId)
typename Arc::StateId StateId
std::vector< LabelIntervalSet > * MutableIntervalSets()
void ReachInit(const FST &fst, bool reach_input, bool copy=false)
bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const
bool StateSort(std::vector< IntervalSet< Label >> *interval_sets, const std::vector< StateId > &order)
const LabelIntervalSet & GetIntervalSet(int s) const
typename LabelIntervalSet::Interval Interval
void Init(const FST &fst, bool reach_input, bool is_copy)