21 #ifndef FST_LABEL_REACHABLE_H_ 22 #define FST_LABEL_REACHABLE_H_ 24 #include <sys/types.h> 43 #include <unordered_map> 48 template <
typename Label>
55 : reach_input_(reach_input),
56 keep_relabel_data_(keep_relabel_data),
57 have_relabel_data_(true),
65 return &interval_sets_;
69 return interval_sets_[s];
75 if (!have_relabel_data_) {
76 FSTERROR() <<
"LabelReachableData: No relabeling data";
82 if (!have_relabel_data_) {
83 FSTERROR() <<
"LabelReachableData: No relabeling data";
96 ReadType(istrm, &data->reach_input_);
97 ReadType(istrm, &data->keep_relabel_data_);
98 data->have_relabel_data_ = data->keep_relabel_data_;
99 if (data->keep_relabel_data_)
ReadType(istrm, &data->label2index_);
100 ReadType(istrm, &data->final_label_);
101 ReadType(istrm, &data->interval_sets_);
102 return data.release();
108 if (keep_relabel_data_)
WriteType(ostrm, label2index_);
118 bool keep_relabel_data_;
119 bool have_relabel_data_;
121 std::unordered_map<Label, Label> label2index_;
122 std::vector<LabelIntervalSet> interval_sets_;
129 template <
typename Label,
typename StateId>
131 const std::vector<StateId> &order) {
132 if (order.size() != interval_sets->size()) {
133 FSTERROR() <<
"StateSort: Bad order vector size: " << order.size()
134 <<
", expected: " << interval_sets->size();
137 std::vector<IntervalSet<Label>> reordered_interval_sets(
138 interval_sets->size());
141 for (StateId s = 0; s < order.size(); ++s) {
142 reordered_interval_sets[order[s]] = std::move((*interval_sets)[s]);
144 *interval_sets = std::move(reordered_interval_sets);
149 template <
typename Label,
typename StateId>
151 const std::vector<StateId> &order) {
171 reach_input_ = reach_input;
181 template <
class ArcIterator>
183 Label match_label)
const {
188 ssize_t low = aiter_begin;
189 ssize_t high = aiter_end;
191 const ssize_t mid = low + (high - low) / 2;
193 auto label = reach_input_ ? aiter->
Value().ilabel : aiter->
Value().olabel;
194 if (label < match_label) {
206 bool reach_input_ =
false;
237 template <
class Arc,
class Accumulator = DefaultAccumulator<Arc>,
238 class D = LabelReachableData<
typename Arc::Label>,
239 class LB = LabelLowerBound<Arc>>
253 std::unique_ptr<Accumulator> accumulator =
nullptr,
254 bool keep_relabel_data =
true)
255 : fst_(std::make_unique<
VectorFst<Arc>>(fst)),
257 data_(std::make_shared<
Data>(reach_input, keep_relabel_data)),
258 accumulator_(accumulator ? std::move(accumulator)
259 : std::make_unique<Accumulator>()) {
260 const auto ins = fst_->NumStates();
267 std::unique_ptr<Accumulator> accumulator =
nullptr)
269 data_(std::move(data)),
270 accumulator_(accumulator ? std::move(accumulator)
271 : std::make_unique<Accumulator>()) {}
275 data_(reachable.data_),
277 std::make_unique<Accumulator>(*reachable.accumulator_, safe)),
278 lower_bound_(reachable.lower_bound_),
279 reach_fst_input_(reachable.reach_fst_input_),
280 error_(reachable.error_) {}
284 VLOG(2) <<
"# of calls: " << ncalls_;
285 VLOG(2) <<
"# of intervals/call: " << (nintervals_ / ncalls_);
291 if (label == 0 || error_)
return label;
292 const auto &label2index = *data_->Label2Index();
293 if (
auto iter = label2index.find(label); iter != label2index.end()) {
296 auto &relabel = oov_label2index_[label];
299 relabel = label2index.size() + oov_label2index_.size() + 1;
309 !aiter.Done(); aiter.Next()) {
310 auto arc = aiter.Value();
312 arc.ilabel =
Relabel(arc.ilabel);
314 arc.olabel =
Relabel(arc.olabel);
332 bool avoid_collisions =
false) {
334 const auto &label2index = *data_->Label2Index();
336 for (
const auto &kv : label2index) {
337 if (kv.second != data_->FinalLabel()) {
338 pairs->emplace_back(kv);
342 pairs->insert(pairs->end(), oov_label2index_.begin(),
343 oov_label2index_.end());
344 if (avoid_collisions) {
347 for (
size_t i = 1; i <= label2index.size(); ++i) {
348 const auto it = label2index.find(i);
349 bool unmapped = it == label2index.end();
350 if (unmapped) unmapped = oov_label2index_.count(i) == 0;
351 if (unmapped || it->second == data_->FinalLabel()) {
352 pairs->emplace_back(i, label2index.size() + 1);
363 accumulator_->SetState(aiter_s);
364 if (accumulator_->Error()) error_ =
true;
365 lower_bound_.SetState(aiter_s);
372 if (label == 0 || error_)
return false;
373 return data_->GetIntervalSet(s_).Member(label);
378 if (error_)
return false;
379 return data_->GetIntervalSet(s_).Member(data_->FinalLabel());
389 reach_fst_input_ = reach_input;
392 FSTERROR() <<
"LabelReachable::ReachInit: Fst is not sorted";
395 accumulator_->Init(fst, copy);
396 if (accumulator_->Error()) error_ =
true;
397 lower_bound_.Init(fst, reach_input, copy);
404 template <
class Iterator>
405 bool Reach(Iterator *aiter, ssize_t aiter_begin, ssize_t aiter_end,
406 bool compute_weight) {
407 if (error_)
return false;
408 const auto &interval_set = data_->GetIntervalSet(s_);
410 nintervals_ += interval_set.Size();
413 reach_weight_ = Weight::Zero();
414 const auto flags = aiter->Flags();
416 aiter->Seek(aiter_begin);
417 if (2 * (aiter_end - aiter_begin) < interval_set.Size()) {
424 for (
auto aiter_pos = aiter_begin; aiter_pos < aiter_end;
425 aiter->Next(), ++aiter_pos) {
426 const auto &arc = aiter->Value();
427 const auto label = reach_fst_input_ ? arc.ilabel : arc.olabel;
428 if (label == reach_label || Reach(label)) {
430 if (reach_begin_ < 0) reach_begin_ = aiter_pos;
431 reach_end_ = aiter_pos + 1;
432 if (compute_weight) {
438 const auto &arcb = aiter->Value();
440 reach_weight_ = accumulator_->Sum(reach_weight_, arcb.weight);
447 reach_weight_ = accumulator_->Sum(reach_weight_, arc.weight);
454 auto begin_low = aiter_begin;
455 auto end_low = aiter_begin;
456 for (
const auto &interval : interval_set) {
457 begin_low = lower_bound_(aiter, end_low, aiter_end, interval.begin);
458 end_low = lower_bound_(aiter, begin_low, aiter_end, interval.end);
459 if (end_low - begin_low > 0) {
460 if (reach_begin_ < 0) reach_begin_ = begin_low;
461 reach_end_ = end_low;
462 if (compute_weight) {
465 accumulator_->Sum(reach_weight_, aiter, begin_low, end_low);
471 return reach_begin_ >= 0;
488 return *data_->Label2Index();
495 bool Error()
const {
return error_ || accumulator_->Error(); }
503 void TransformFst() {
504 auto ins = fst_->NumStates();
506 std::vector<ssize_t> indeg(ins, 0);
508 for (
StateId s = 0; s < ins; ++s) {
510 !aiter.Done(); aiter.Next()) {
511 auto arc = aiter.Value();
512 const auto label = data_->ReachInput() ? arc.ilabel : arc.olabel;
514 if (
auto insert_result = label2state_.emplace(label, ons);
515 insert_result.second) {
519 arc.nextstate = label2state_[label];
522 ++indeg[arc.nextstate];
525 auto final_weight = fst_->Final(s);
526 if (final_weight != Weight::Zero()) {
527 if (
auto insert_result = label2state_.emplace(
kNoLabel, ons);
528 insert_result.second) {
532 const auto nextstate = label2state_[
kNoLabel];
536 fst_->SetFinal(s, Weight::Zero());
540 while (fst_->NumStates() < ons) {
545 const auto start = fst_->AddState();
546 fst_->SetStart(start);
547 for (
StateId s = 0; s < start; ++s) {
548 if (indeg[s] == 0) fst_->EmplaceArc(start, 0, 0, s);
552 void FindIntervals(
StateId ins) {
554 if (state_reachable.
Error()) {
559 auto &interval_sets = *data_->MutableIntervalSets();
561 interval_sets.resize(ins);
562 auto &label2index = *data_->MutableLabel2Index();
563 for (
const auto &kv : label2state_) {
564 Label i = state2index[kv.second];
565 label2index[kv.first] = i;
566 if (kv.first ==
kNoLabel) data_->SetFinalLabel(i);
568 label2state_.clear();
569 double nintervals = 0;
570 ssize_t non_intervals = 0;
571 for (
StateId s = 0; s < ins; ++s) {
572 nintervals += interval_sets[s].Size();
573 if (interval_sets[s].Size() > 1) {
575 VLOG(3) <<
"state: " << s
576 <<
" # of intervals: " << interval_sets[s].Size();
579 VLOG(2) <<
"# of states: " << ins;
580 VLOG(2) <<
"# of intervals: " << nintervals;
581 VLOG(2) <<
"# of intervals/state: " << nintervals / ins;
582 VLOG(2) <<
"# of non-interval states: " << non_intervals;
585 std::unique_ptr<VectorFst<Arc>> fst_;
589 std::unordered_map<Label, StateId> label2state_;
591 ssize_t reach_begin_;
597 std::shared_ptr<Data> data_;
599 std::unique_ptr<Accumulator> accumulator_;
603 std::unordered_map<Label, Label> oov_label2index_;
605 double nintervals_ = 0;
606 bool reach_fst_input_ =
false;
612 #endif // FST_LABEL_REACHABLE_H_ void Relabel(MutableFst< Arc > *fst, bool relabel_input)
typename Data::LabelIntervalSet LabelIntervalSet
constexpr uint8_t kArcValueFlags
typename Arc::Weight Weight
const std::unordered_map< Label, Label > * Label2Index() const
const std::vector< ISet > & IntervalSets()
constexpr uint8_t kArcNoCache
std::shared_ptr< Data > GetSharedData() const
std::vector< Index > & State2Index()
LabelReachable(const LabelReachable &reachable, bool safe=false)
void SetFinalLabel(Label final_label)
typename Arc::Label Label
constexpr uint8_t kArcFlags
virtual void SetInputSymbols(const SymbolTable *isyms)=0
bool Reach(Iterator *aiter, ssize_t aiter_begin, ssize_t aiter_end, bool compute_weight)
constexpr uint8_t kArcILabelValue
std::unordered_map< Label, Label > * MutableLabel2Index()
std::unique_ptr< T > WrapUnique(T *ptr)
void SetState(StateId aiter_s)
Label Relabel(Label label)
const Arc & Value() const
LabelReachableData(bool reach_input, bool keep_relabel_data=true)
void Relabel(MutableFst< Arc > *fst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &ipairs, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &opairs)
~LabelReachableData()=default
void RelabelPairs(std::vector< std::pair< Label, Label >> *pairs, bool avoid_collisions=false)
std::ostream & WriteType(std::ostream &strm, const T t)
void SetFlags(uint8_t flags, uint8_t mask)
int NumIntervalSets() const
static LabelReachableData * Read(std::istream &istrm, const FstReadOptions &opts)
void ArcSort(MutableFst< Arc > *fst, Compare comp)
constexpr uint64_t kOLabelSorted
typename Arc::StateId StateId
LabelReachable(const Fst< Arc > &fst, bool reach_input, std::unique_ptr< Accumulator > accumulator=nullptr, bool keep_relabel_data=true)
constexpr uint8_t kArcOLabelValue
ssize_t operator()(ArcIterator *aiter, ssize_t aiter_begin, ssize_t aiter_end, Label match_label) const
bool Reach(Label label) const
typename Arc::Label Label
typename LabelIntervalSet::Interval Interval
const Data * GetData() const
Weight ReachWeight() const
constexpr uint8_t kArcWeightValue
constexpr uint64_t kILabelSorted
LabelReachable(std::shared_ptr< Data > data, std::unique_ptr< Accumulator > accumulator=nullptr)
ssize_t ReachBegin() const
const std::unordered_map< Label, Label > & Label2Index() const
std::istream & ReadType(std::istream &strm, T *t)
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
IntInterval< T > Interval
void SetState(StateId s, StateId aiter_s=kNoStateId)
typename Arc::StateId StateId
std::vector< LabelIntervalSet > * MutableIntervalSets()
void ReachInit(const FST &fst, bool reach_input, bool copy=false)
bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const
bool StateSort(std::vector< IntervalSet< Label >> *interval_sets, const std::vector< StateId > &order)
const LabelIntervalSet & GetIntervalSet(int s) const
typename LabelIntervalSet::Interval Interval
void Init(const FST &fst, bool reach_input, bool is_copy)