FST  openfst-1.7.1
OpenFst Library
label-reachable.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Class to determine if a non-epsilon label can be read as the first
5 // non-epsilon symbol along some path from a given state.
6 
7 #ifndef FST_LABEL_REACHABLE_H_
8 #define FST_LABEL_REACHABLE_H_
9 
10 #include <unordered_map>
11 #include <utility>
12 #include <vector>
13 
14 #include <fst/log.h>
15 
16 #include <fst/accumulator.h>
17 #include <fst/arcsort.h>
18 #include <fst/interval-set.h>
19 #include <fst/state-reachable.h>
20 #include <fst/util.h>
21 #include <fst/vector-fst.h>
22 
23 
24 namespace fst {
25 
26 // Stores shareable data for label reachable class copies.
27 template <typename Label>
29  public:
32 
33  explicit LabelReachableData(bool reach_input, bool keep_relabel_data = true)
34  : reach_input_(reach_input),
35  keep_relabel_data_(keep_relabel_data),
36  have_relabel_data_(true),
37  final_label_(kNoLabel) {}
38 
40 
41  bool ReachInput() const { return reach_input_; }
42 
43  std::vector<LabelIntervalSet> *MutableIntervalSets() {
44  return &interval_sets_;
45  }
46 
47  const LabelIntervalSet &GetIntervalSet(int s) const {
48  return interval_sets_[s];
49  }
50 
51  int NumIntervalSets() const { return interval_sets_.size(); }
52 
53  std::unordered_map<Label, Label> *Label2Index() {
54  if (!have_relabel_data_) {
55  FSTERROR() << "LabelReachableData: No relabeling data";
56  }
57  return &label2index_;
58  }
59 
60  void SetFinalLabel(Label final_label) { final_label_ = final_label; }
61 
62  Label FinalLabel() const { return final_label_; }
63 
64  static LabelReachableData<Label> *Read(std::istream &istrm,
65  const FstReadOptions &opts) {
66  auto *data = new LabelReachableData<Label>();
67  ReadType(istrm, &data->reach_input_);
68  ReadType(istrm, &data->keep_relabel_data_);
69  data->have_relabel_data_ = data->keep_relabel_data_;
70  if (data->keep_relabel_data_) ReadType(istrm, &data->label2index_);
71  ReadType(istrm, &data->final_label_);
72  ReadType(istrm, &data->interval_sets_);
73  return data;
74  }
75 
76  bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const {
77  WriteType(ostrm, reach_input_);
78  WriteType(ostrm, keep_relabel_data_);
79  if (keep_relabel_data_) WriteType(ostrm, label2index_);
80  WriteType(ostrm, FinalLabel());
81  WriteType(ostrm, interval_sets_);
82  return true;
83  }
84 
85  private:
87 
88  bool reach_input_; // Input labels considered?
89  bool keep_relabel_data_; // Save label2index_ to file?
90  bool have_relabel_data_; // Using label2index_?
91  Label final_label_; // Final label.
92  std::unordered_map<Label, Label> label2index_; // Finds index for a label.
93  std::vector<LabelIntervalSet> interval_sets_; // Interval sets per state.
94 };
95 
96 // Tests reachability of labels from a given state. If reach_input is true, then
97 // input labels are considered, o.w. output labels are considered. To test for
98 // reachability from a state s, first do SetState(s), then a label l can be
99 // reached from state s of FST f iff Reach(r) is true where r = Relabel(l). The
100 // relabeling is required to ensure a compact representation of the reachable
101 // labels.
102 
103 // The whole FST can be relabeled instead with Relabel(&f, reach_input) so that
104 // the test Reach(r) applies directly to the labels of the transformed FST f.
105 // The relabeled FST will also be sorted appropriately for composition.
106 //
107 // Reachability of a final state from state s (via an epsilon path) can be
108 // tested with ReachFinal().
109 //
110 // Reachability can also be tested on the set of labels specified by an arc
111 // iterator, useful for FST composition. In particular, Reach(aiter, ...) is
112 // true if labels on the input (output) side of the transitions of the arc
113 // iterator, when iter_input is true (false), can be reached from the state s.
114 // The iterator labels must have already been relabeled.
115 //
116 // With the arc iterator test of reachability, the begin position, end position
117 // and accumulated arc weight of the matches can be returned. The optional
118 // template argument controls how reachable arc weights are accumulated. The
119 // default uses semiring Plus(). Alternative ones can be used to distribute the
120 // weights in composition in various ways.
121 template <class Arc, class Accumulator = DefaultAccumulator<Arc>,
122  class D = LabelReachableData<typename Arc::Label>>
124  public:
125  using Label = typename Arc::Label;
126  using StateId = typename Arc::StateId;
127  using Weight = typename Arc::Weight;
128  using Data = D;
129 
130  using LabelIntervalSet = typename Data::LabelIntervalSet;
131 
133 
134  LabelReachable(const Fst<Arc> &fst, bool reach_input,
135  Accumulator *accumulator = nullptr,
136  bool keep_relabel_data = true)
137  : fst_(new VectorFst<Arc>(fst)),
138  s_(kNoStateId),
139  data_(std::make_shared<Data>(reach_input, keep_relabel_data)),
140  accumulator_(accumulator ? accumulator : new Accumulator()),
141  ncalls_(0),
142  nintervals_(0),
143  reach_fst_input_(false),
144  error_(false) {
145  const auto ins = fst_->NumStates();
146  TransformFst();
147  FindIntervals(ins);
148  fst_.reset();
149  }
150 
151  explicit LabelReachable(std::shared_ptr<Data> data,
152  Accumulator *accumulator = nullptr)
153  : s_(kNoStateId),
154  data_(std::move(data)),
155  accumulator_(accumulator ? accumulator : new Accumulator()),
156  ncalls_(0),
157  nintervals_(0),
158  reach_fst_input_(false),
159  error_(false) {}
160 
162  bool safe = false)
163  : s_(kNoStateId),
164  data_(reachable.data_),
165  accumulator_(new Accumulator(*reachable.accumulator_, safe)),
166  ncalls_(0),
167  nintervals_(0),
168  reach_fst_input_(reachable.reach_fst_input_),
169  error_(reachable.error_) {}
170 
172  if (ncalls_ > 0) {
173  VLOG(2) << "# of calls: " << ncalls_;
174  VLOG(2) << "# of intervals/call: " << (nintervals_ / ncalls_);
175  }
176  }
177 
178  // Relabels w.r.t labels that give compact label sets.
179  Label Relabel(Label label) {
180  if (label == 0 || error_) return label;
181  auto &label2index = *data_->Label2Index();
182  auto &relabel = label2index[label];
183  if (!relabel) relabel = label2index.size() + 1; // Adds new label.
184  return relabel;
185  }
186 
187 // Relabels FST w.r.t. labels that give compact label sets.
188  void Relabel(MutableFst<Arc> *fst, bool relabel_input) {
189  for (StateIterator<MutableFst<Arc>> siter(*fst); !siter.Done();
190  siter.Next()) {
191  for (MutableArcIterator<MutableFst<Arc>> aiter(fst, siter.Value());
192  !aiter.Done(); aiter.Next()) {
193  auto arc = aiter.Value();
194  if (relabel_input) {
195  arc.ilabel = Relabel(arc.ilabel);
196  } else {
197  arc.olabel = Relabel(arc.olabel);
198  }
199  aiter.SetValue(arc);
200  }
201  }
202  if (relabel_input) {
203  ArcSort(fst, ILabelCompare<Arc>());
204  fst->SetInputSymbols(nullptr);
205  } else {
206  ArcSort(fst, OLabelCompare<Arc>());
207  fst->SetOutputSymbols(nullptr);
208  }
209  }
210 
211  // Returns relabeling pairs (cf. relabel.h::Relabel()). If avoid_collisions is
212  // true, extra pairs are added to ensure no collisions when relabeling
213  // automata that have labels unseen here.
214  void RelabelPairs(std::vector<std::pair<Label, Label>> *pairs,
215  bool avoid_collisions = false) {
216  pairs->clear();
217  const auto &label2index = *data_->Label2Index();
218  // Maps labels to their new values in [1, label2index().size()].
219  for (auto it = label2index.begin(); it != label2index.end(); ++it) {
220  if (it->second != data_->FinalLabel()) {
221  pairs->push_back(std::make_pair(it->first, it->second));
222  }
223  }
224  if (avoid_collisions) {
225  // Ensures any label in [1, label2index().size()] is mapped either
226  // by the above step or to label2index() + 1 (to avoid collisions).
227  for (size_t i = 1; i <= label2index.size(); ++i) {
228  const auto it = label2index.find(i);
229  if (it == label2index.end() || it->second == data_->FinalLabel()) {
230  pairs->push_back(std::make_pair(i, label2index.size() + 1));
231  }
232  }
233  }
234  }
235 
236  // Set current state. Optionally set state associated
237  // with arc iterator to be passed to Reach.
238  void SetState(StateId s, StateId aiter_s = kNoStateId) {
239  s_ = s;
240  if (aiter_s != kNoStateId) {
241  accumulator_->SetState(aiter_s);
242  if (accumulator_->Error()) error_ = true;
243  }
244  }
245 
246  // Can reach this label from current state?
247  // Original labels must be transformed by the Relabel methods above.
248  bool Reach(Label label) const {
249  if (label == 0 || error_) return false;
250  return data_->GetIntervalSet(s_).Member(label);
251  }
252 
253  // Can reach final state (via epsilon transitions) from this state?
254  bool ReachFinal() const {
255  if (error_) return false;
256  return data_->GetIntervalSet(s_).Member(data_->FinalLabel());
257  }
258 
259  // Initialize with secondary FST to be used with Reach(Iterator,...).
260  // If reach_input = true, then arc input labels are considered in
261  // Reach(aiter, ...), o.w. output labels are considered. If copy is true, then
262  // the FST is a copy of the FST used in the previous call to this method
263  // (useful to avoid unnecessary updates).
264  template <class FST>
265  void ReachInit(const FST &fst, bool reach_input, bool copy = false) {
266  reach_fst_input_ = reach_input;
267  if (!fst.Properties(reach_fst_input_ ? kILabelSorted : kOLabelSorted,
268  true)) {
269  FSTERROR() << "LabelReachable::ReachInit: Fst is not sorted";
270  error_ = true;
271  }
272  accumulator_->Init(fst, copy);
273  if (accumulator_->Error()) error_ = true;
274  }
275 
276  // Can reach any arc iterator label between iterator positions
277  // aiter_begin and aiter_end?
278  // Arc iterator labels must be transformed by the Relabel methods
279  // above. If compute_weight is true, user may call ReachWeight().
// Returns true iff at least one arc in [aiter_begin, aiter_end) has a label in
// the current state's interval set; returns false at once in the error state.
// On return, reach_begin_ is the position of the first matching arc and
// reach_end_ one past the last matching arc (both -1 if none matched), and,
// when compute_weight is true, reach_weight_ holds the accumulated weight of
// the matching arcs. The iterator's flags are restored before returning.
280 template <class Iterator>
281  bool Reach(Iterator *aiter, ssize_t aiter_begin, ssize_t aiter_end,
282  bool compute_weight) {
283  if (error_) return false;
284  const auto &interval_set = data_->GetIntervalSet(s_);
 // Statistics counters only.
285  ++ncalls_;
286  nintervals_ += interval_set.Size();
 // Resets the results of any previous call.
287  reach_begin_ = -1;
288  reach_end_ = -1;
289  reach_weight_ = Weight::Zero();
290  const auto flags = aiter->Flags(); // Save flags to restore them on exit.
291  aiter->SetFlags(kArcNoCache, kArcNoCache); // Makes caching optional.
292  aiter->Seek(aiter_begin);
 // Strategy choice: if twice the number of arcs is less than the number of
 // intervals, scan the arcs one by one; otherwise binary-search the arcs once
 // per interval.
293  if (2 * (aiter_end - aiter_begin) < interval_set.Size()) {
294  // Checks each arc against intervals, setting arc iterator flags to only
295  // compute the ilabel or olabel values, since they are the only values
296  // required for most of the arcs processed.
297  aiter->SetFlags(reach_fst_input_ ? kArcILabelValue : kArcOLabelValue,
298  kArcValueFlags);
 // Remembers the last label found reachable so that a run of arcs with the
 // same label skips the repeated interval-set membership test.
299  Label reach_label = kNoLabel;
300  for (auto aiter_pos = aiter_begin; aiter_pos < aiter_end;
301  aiter->Next(), ++aiter_pos) {
302  const auto &arc = aiter->Value();
303  const auto label = reach_fst_input_ ? arc.ilabel : arc.olabel;
304  if (label == reach_label || Reach(label)) {
305  reach_label = label;
306  if (reach_begin_ < 0) reach_begin_ = aiter_pos;
307  reach_end_ = aiter_pos + 1;
308  if (compute_weight) {
309  if (!(aiter->Flags() & kArcWeightValue)) {
310  // If arc.weight wasn't computed by the call to aiter->Value()
311  // above, we need to call aiter->Value() again after having set
312  // the arc iterator flags to compute the arc weight value.
313  aiter->SetFlags(kArcWeightValue, kArcValueFlags);
314  const auto &arcb = aiter->Value();
315  // Call the accumulator.
316  reach_weight_ = accumulator_->Sum(reach_weight_, arcb.weight);
317  // Only ilabel or olabel required to process the following arcs.
318  aiter->SetFlags(
319  reach_fst_input_ ? kArcILabelValue : kArcOLabelValue,
320  kArcValueFlags);
321  } else {
322  // Calls the accumulator.
323  reach_weight_ = accumulator_->Sum(reach_weight_, arc.weight);
324  }
325  }
326  }
327  }
328  } else {
329  // Checks each interval against arcs.
 // Each search starts where the previous one ended (end_low), so the arc
 // range is swept left to right as the intervals are visited in order.
330  auto begin_low = aiter_begin;
331  auto end_low = aiter_begin;
332  for (const auto &interval : interval_set) {
333  begin_low = LowerBound(aiter, end_low, aiter_end, interval.begin);
334  end_low = LowerBound(aiter, begin_low, aiter_end, interval.end);
 // Arc positions [begin_low, end_low) carry labels inside this interval.
335  if (end_low - begin_low > 0) {
336  if (reach_begin_ < 0) reach_begin_ = begin_low;
337  reach_end_ = end_low;
338  if (compute_weight) {
339  aiter->SetFlags(kArcWeightValue, kArcValueFlags);
340  reach_weight_ =
341  accumulator_->Sum(reach_weight_, aiter, begin_low, end_low);
342  }
343  }
344  }
345  }
346  aiter->SetFlags(flags, kArcFlags); // Restores original flag values.
347  return reach_begin_ >= 0;
348  }
349 
350  // Returns iterator position of first matching arc.
351  ssize_t ReachBegin() const { return reach_begin_; }
352 
353  // Returns iterator position one past last matching arc.
354  ssize_t ReachEnd() const { return reach_end_; }
355 
356  // Return the sum of the weights for matching arcs. Valid only if
357  // compute_weight was true in Reach() call.
358  Weight ReachWeight() const { return reach_weight_; }
359 
360  // Access to the relabeling map. Excludes epsilon (0) label but
361  // includes kNoLabel that is used internally for super-final
362 // transitions.
363  const std::unordered_map<Label, Label> &Label2Index() const {
364  return *data_->Label2Index();
365  }
366 
367  const Data *GetData() const { return data_.get(); }
368 
369  std::shared_ptr<Data> GetSharedData() const { return data_; }
370 
371  bool Error() const { return error_ || accumulator_->Error(); }
372 
373  private:
374  // Redirects labeled arcs (input or output labels determined by ReachInput())
375  // to new label-specific final states. Each original final state is
376  // redirected via a transition labeled with kNoLabel to a new
377  // kNoLabel-specific final state. Creates super-initial state for all states
378  // with zero in-degree.
379  void TransformFst() {
380  auto ins = fst_->NumStates();
381  auto ons = ins;
382  std::vector<ssize_t> indeg(ins, 0);
383  // Redirects labeled arcs to new final states.
384  for (StateId s = 0; s < ins; ++s) {
385  for (MutableArcIterator<VectorFst<Arc>> aiter(fst_.get(), s);
386  !aiter.Done(); aiter.Next()) {
387  auto arc = aiter.Value();
388  const auto label = data_->ReachInput() ? arc.ilabel : arc.olabel;
389  if (label) {
390  auto insert_result = label2state_.insert(std::make_pair(label, ons));
391  if (insert_result.second) {
392  indeg.push_back(0);
393  ++ons;
394  }
395  arc.nextstate = label2state_[label];
396  aiter.SetValue(arc);
397  }
398  ++indeg[arc.nextstate]; // Finds in-degrees for next step.
399  }
400  // Redirects final weights to new final state.
401  auto final_weight = fst_->Final(s);
402  if (final_weight != Weight::Zero()) {
403  auto insert_result = label2state_.insert(std::make_pair(kNoLabel, ons));
404  if (insert_result.second) {
405  indeg.push_back(0);
406  ++ons;
407  }
408  const auto nextstate = label2state_[kNoLabel];
409  fst_->EmplaceArc(s, kNoLabel, kNoLabel, std::move(final_weight),
410  nextstate);
411  ++indeg[nextstate]; // Finds in-degrees for next step.
412  fst_->SetFinal(s, Weight::Zero());
413  }
414  }
415  // Adds new final states to the FST.
416  while (fst_->NumStates() < ons) {
417  StateId s = fst_->AddState();
418  fst_->SetFinal(s, Weight::One());
419  }
420  // Creates a super-initial state for all states with zero in-degree.
421  const auto start = fst_->AddState();
422  fst_->SetStart(start);
423  for (StateId s = 0; s < start; ++s) {
424  if (indeg[s] == 0) {
425  fst_->EmplaceArc(start, 0, 0, Weight::One(), s);
426  }
427  }
428  }
429 
430  void FindIntervals(StateId ins) {
431  StateReachable<Arc, Label, LabelIntervalSet> state_reachable(*fst_);
432  if (state_reachable.Error()) {
433  error_ = true;
434  return;
435  }
436  auto &state2index = state_reachable.State2Index();
437  auto &interval_sets = *data_->MutableIntervalSets();
438  interval_sets = state_reachable.IntervalSets();
439  interval_sets.resize(ins);
440  auto &label2index = *data_->Label2Index();
441  for (const auto &kv : label2state_) {
442  Label i = state2index[kv.second];
443  label2index[kv.first] = i;
444  if (kv.first == kNoLabel) data_->SetFinalLabel(i);
445  }
446  label2state_.clear();
447  double nintervals = 0;
448  ssize_t non_intervals = 0;
449  for (StateId s = 0; s < ins; ++s) {
450  nintervals += interval_sets[s].Size();
451  if (interval_sets[s].Size() > 1) {
452  ++non_intervals;
453  VLOG(3) << "state: " << s
454  << " # of intervals: " << interval_sets[s].Size();
455  }
456  }
457  VLOG(2) << "# of states: " << ins;
458  VLOG(2) << "# of intervals: " << nintervals;
459  VLOG(2) << "# of intervals/state: " << nintervals / ins;
460  VLOG(2) << "# of non-interval states: " << non_intervals;
461  }
462 
463  template <class Iterator>
464  ssize_t LowerBound(Iterator *aiter, ssize_t aiter_begin, ssize_t aiter_end,
465  Label match_label) const {
466  // Only needs to compute the ilabel or olabel of arcs when performing the
467  // binary search.
468  aiter->SetFlags(reach_fst_input_ ? kArcILabelValue : kArcOLabelValue,
469  kArcValueFlags);
470  ssize_t low = aiter_begin;
471  ssize_t high = aiter_end;
472  while (low < high) {
473  const ssize_t mid = low + (high - low) / 2;
474  aiter->Seek(mid);
475  auto label =
476  reach_fst_input_ ? aiter->Value().ilabel : aiter->Value().olabel;
477  if (label < match_label) {
478  low = mid + 1;
479  } else {
480  high = mid;
481  }
482  }
483  aiter->Seek(low);
484  aiter->SetFlags(kArcValueFlags, kArcValueFlags);
485  return low;
486  }
487 
488  std::unique_ptr<VectorFst<Arc>> fst_;
489  // Current state
490  StateId s_;
491  // Finds final state for a label
492  std::unordered_map<Label, StateId> label2state_;
493  // Iterator position of first match.
494  ssize_t reach_begin_;
495  // Iterator position after last match.
496  ssize_t reach_end_;
497  // Gives weight sum of arc iterator arcs with reachable labels.
498  Weight reach_weight_;
499  // Shareable data between copies.
500  std::shared_ptr<Data> data_;
501  // Sums arc weights.
502  std::unique_ptr<Accumulator> accumulator_;
503  double ncalls_;
504  double nintervals_;
505  bool reach_fst_input_;
506  bool error_;
507 };
508 
509 } // namespace fst
510 
511 #endif // FST_LABEL_REACHABLE_H_
constexpr int kNoLabel
Definition: fst.h:179
const std::vector< ISet > & IntervalSets()
typename Arc::Label Label
std::vector< Index > & State2Index()
Label Relabel(Label label)
LabelReachable(const LabelReachable< Arc, Accumulator, Data > &reachable, bool safe=false)
bool ReachFinal() const
void SetFinalLabel(Label final_label)
constexpr uint64 kILabelSorted
Definition: properties.h:75
virtual void SetInputSymbols(const SymbolTable *isyms)=0
ssize_t ReachEnd() const
LabelReachable(std::shared_ptr< Data > data, Accumulator *accumulator=nullptr)
constexpr int kNoStateId
Definition: fst.h:180
static LabelReachableData< Label > * Read(std::istream &istrm, const FstReadOptions &opts)
LabelReachableData(bool reach_input, bool keep_relabel_data=true)
void Relabel(MutableFst< Arc > *fst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &ipairs, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &opairs)
Definition: relabel.h:29
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:155
#define FSTERROR()
Definition: util.h:35
int NumIntervalSets() const
void RelabelPairs(std::vector< std::pair< Label, Label >> *pairs, bool avoid_collisions=false)
typename Arc::StateId StateId
void Relabel(MutableFst< Arc > *fst, bool relabel_input)
const std::unordered_map< Label, Label > & Label2Index() const
void ArcSort(MutableFst< Arc > *fst, Compare comp)
Definition: arcsort.h:87
std::unordered_map< Label, Label > * Label2Index()
constexpr uint64 kOLabelSorted
Definition: properties.h:80
const Data * GetData() const
typename Arc::Weight Weight
#define VLOG(level)
Definition: log.h:49
typename Data::LabelIntervalSet LabelIntervalSet
bool Reach(Label label) const
bool Reach(Iterator *aiter, ssize_t aiter_begin, ssize_t aiter_end, bool compute_weight)
Label FinalLabel() const
void ReachInit(const FST &fst, bool reach_input, bool copy=false)
Weight ReachWeight() const
ssize_t ReachBegin() const
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:47
void SetState(StateId s, StateId aiter_s=kNoStateId)
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
IntInterval< T > Interval
Definition: interval-set.h:108
std::vector< LabelIntervalSet > * MutableIntervalSets()
bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const
LabelReachable(const Fst< Arc > &fst, bool reach_input, Accumulator *accumulator=nullptr, bool keep_relabel_data=true)
std::shared_ptr< Data > GetSharedData() const
const LabelIntervalSet & GetIntervalSet(int s) const
typename LabelIntervalSet::Interval Interval
typename LabelIntervalSet::Interval Interval