FST  openfst-1.7.2
OpenFst Library
accumulator.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Classes to accumulate arc weights. Useful for weight lookahead.
5 
6 #ifndef FST_ACCUMULATOR_H_
7 #define FST_ACCUMULATOR_H_
8 
9 #include <algorithm>
10 #include <functional>
11 #include <unordered_map>
12 #include <vector>
13 
14 #include <fst/log.h>
15 
16 #include <fst/arcfilter.h>
17 #include <fst/arcsort.h>
18 #include <fst/dfs-visit.h>
19 #include <fst/expanded-fst.h>
20 #include <fst/replace.h>
21 
22 namespace fst {
23 
24 // This class accumulates arc weights using the semiring Plus().
25 template <class A>
27  public:
28  using Arc = A;
29  using StateId = typename Arc::StateId;
30  using Weight = typename Arc::Weight;
31 
33 
34  DefaultAccumulator(const DefaultAccumulator &acc, bool safe = false) {}
35 
36  void Init(const Fst<Arc> &fst, bool copy = false) {}
37 
38  void SetState(StateId state) {}
39 
40  Weight Sum(Weight w, Weight v) { return Plus(w, v); }
41 
42  template <class ArcIter>
43  Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
44  Adder<Weight> adder(w); // maintains cumulative sum accurately
45  aiter->Seek(begin);
46  for (auto pos = begin; pos < end; aiter->Next(), ++pos)
47  adder.Add(aiter->Value().weight);
48  return adder.Sum();
49  }
50 
51  constexpr bool Error() const { return false; }
52 
53  private:
54  DefaultAccumulator &operator=(const DefaultAccumulator &) = delete;
55 };
56 
57 // This class accumulates arc weights using the log semiring Plus() assuming an
58 // arc weight has a WeightConvert specialization to and from log64 weights.
59 template <class A>
61  public:
62  using Arc = A;
63  using StateId = typename Arc::StateId;
64  using Weight = typename Arc::Weight;
65 
67 
68  LogAccumulator(const LogAccumulator &acc, bool safe = false) {}
69 
70  void Init(const Fst<Arc> &fst, bool copy = false) {}
71 
72  void SetState(StateId s) {}
73 
74  Weight Sum(Weight w, Weight v) { return LogPlus(w, v); }
75 
76  template <class ArcIter>
77  Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
78  auto sum = w;
79  aiter->Seek(begin);
80  for (auto pos = begin; pos < end; aiter->Next(), ++pos) {
81  sum = LogPlus(sum, aiter->Value().weight);
82  }
83  return sum;
84  }
85 
86  constexpr bool Error() const { return false; }
87 
88  private:
89  Weight LogPlus(Weight w, Weight v) {
90  if (w == Weight::Zero()) {
91  return v;
92  }
93  const auto f1 = to_log_weight_(w).Value();
94  const auto f2 = to_log_weight_(v).Value();
95  if (f1 > f2) {
96  return to_weight_(Log64Weight(f2 - internal::LogPosExp(f1 - f2)));
97  } else {
98  return to_weight_(Log64Weight(f1 - internal::LogPosExp(f2 - f1)));
99  }
100  }
101 
102  const WeightConvert<Weight, Log64Weight> to_log_weight_{};
103  const WeightConvert<Log64Weight, Weight> to_weight_{};
104 
105  LogAccumulator &operator=(const LogAccumulator &) = delete;
106 };
107 
108 // Interface for shareable data for fast log accumulator copies. Holds pointers
109 // to data only, storage is provided by derived classes.
class FastLogAccumulatorData {
 public:
  FastLogAccumulatorData(int arc_limit, int arc_period)
      : arc_limit_(arc_limit),
        arc_period_(arc_period),
        weights_ptr_(nullptr),
        num_weights_(0),
        weight_positions_ptr_(nullptr),
        num_positions_(0) {}

  // This is a polymorphic base that derived classes are used through (e.g. via
  // std::shared_ptr<FastLogAccumulatorData>), so the destructor is virtual.
  virtual ~FastLogAccumulatorData() {}

  // Cumulative weights (as Log64Weight values) for every state with at least
  // ArcLimit() arcs, one entry per ArcPeriod() arcs, in arc order. The first
  // element per state is Log64Weight::Zero().
  const double *Weights() const { return weights_ptr_; }

  int NumWeights() const { return num_weights_; }

  // Maps from state to the corresponding beginning position in Weights().
  // Position -1 means no pre-computed weights for that state.
  const int *WeightPositions() const { return weight_positions_ptr_; }

  int NumPositions() const { return num_positions_; }

  int ArcLimit() const { return arc_limit_; }

  int ArcPeriod() const { return arc_period_; }

  // Returns true if the data object is mutable and supports SetData().
  virtual bool IsMutable() const = 0;

  // Does not take ownership but may invalidate the contents of weights and
  // weight_positions.
  virtual void SetData(std::vector<double> *weights,
                       std::vector<int> *weight_positions) = 0;

 protected:
  // Points the accessors at storage owned by the derived class.
  void Init(int num_weights, const double *weights, int num_positions,
            const int *weight_positions) {
    weights_ptr_ = weights;
    num_weights_ = num_weights;
    weight_positions_ptr_ = weight_positions;
    num_positions_ = num_positions;
  }

 private:
  const int arc_limit_;
  const int arc_period_;
  const double *weights_ptr_;
  int num_weights_;
  const int *weight_positions_ptr_;
  int num_positions_;

  // Copying is disabled: the raw pointers alias storage owned by the source.
  FastLogAccumulatorData(const FastLogAccumulatorData &) = delete;
  FastLogAccumulatorData &operator=(const FastLogAccumulatorData &) = delete;
};
166 
167 // FastLogAccumulatorData with mutable storage; filled by
168 // FastLogAccumulator::Init.
170  public:
171  MutableFastLogAccumulatorData(int arc_limit, int arc_period)
172  : FastLogAccumulatorData(arc_limit, arc_period) {}
173 
174  bool IsMutable() const override { return true; }
175 
176  void SetData(std::vector<double> *weights,
177  std::vector<int> *weight_positions) override {
178  weights_.swap(*weights);
179  weight_positions_.swap(*weight_positions);
180  Init(weights_.size(), weights_.data(), weight_positions_.size(),
181  weight_positions_.data());
182  }
183 
184  private:
185  std::vector<double> weights_;
186  std::vector<int> weight_positions_;
187 
190  const MutableFastLogAccumulatorData &) = delete;
191 };
192 
// This class accumulates arc weights using the log semiring Plus() assuming an
// arc weight has a WeightConvert specialization to and from log64 weights. The
// member function Init(fst) has to be called to set up pre-computed weight
// information.
197 template <class A>
199  public:
200  using Arc = A;
201  using StateId = typename Arc::StateId;
202  using Weight = typename Arc::Weight;
203 
204  explicit FastLogAccumulator(ssize_t arc_limit = 20, ssize_t arc_period = 10)
205  : to_log_weight_(),
206  to_weight_(),
207  arc_limit_(arc_limit),
208  arc_period_(arc_period),
209  data_(std::make_shared<MutableFastLogAccumulatorData>(arc_limit,
210  arc_period)),
211  state_weights_(nullptr),
212  error_(false) {}
213 
214  explicit FastLogAccumulator(std::shared_ptr<FastLogAccumulatorData> data)
215  : to_log_weight_(),
216  to_weight_(),
217  arc_limit_(data->ArcLimit()),
218  arc_period_(data->ArcPeriod()),
219  data_(data),
220  state_weights_(nullptr),
221  error_(false) {}
222 
223  FastLogAccumulator(const FastLogAccumulator<Arc> &acc, bool safe = false)
224  : to_log_weight_(),
225  to_weight_(),
226  arc_limit_(acc.arc_limit_),
227  arc_period_(acc.arc_period_),
228  data_(acc.data_),
229  state_weights_(nullptr),
230  error_(acc.error_) {}
231 
232  void SetState(StateId s) {
233  const auto *weights = data_->Weights();
234  const auto *weight_positions = data_->WeightPositions();
235  state_weights_ = nullptr;
236  if (s < data_->NumPositions()) {
237  const auto pos = weight_positions[s];
238  if (pos >= 0) state_weights_ = &(weights[pos]);
239  }
240  }
241 
242  Weight Sum(Weight w, Weight v) const { return LogPlus(w, v); }
243 
244  template <class ArcIter>
245  Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) const {
246  if (error_) return Weight::NoWeight();
247  auto sum = w;
248  // Finds begin and end of pre-stored weights.
249  ssize_t index_begin = -1;
250  ssize_t index_end = -1;
251  ssize_t stored_begin = end;
252  ssize_t stored_end = end;
253  if (state_weights_) {
254  index_begin = begin > 0 ? (begin - 1) / arc_period_ + 1 : 0;
255  index_end = end / arc_period_;
256  stored_begin = index_begin * arc_period_;
257  stored_end = index_end * arc_period_;
258  }
259  // Computes sum before pre-stored weights.
260  if (begin < stored_begin) {
261  const auto pos_end = std::min(stored_begin, end);
262  aiter->Seek(begin);
263  for (auto pos = begin; pos < pos_end; aiter->Next(), ++pos) {
264  sum = LogPlus(sum, aiter->Value().weight);
265  }
266  }
267  // Computes sum between pre-stored weights.
268  if (stored_begin < stored_end) {
269  const auto f1 = state_weights_[index_end];
270  const auto f2 = state_weights_[index_begin];
271  if (f1 < f2) sum = LogPlus(sum, LogMinus(f1, f2));
272  // Commented out for efficiency; adds Zero().
273  /*
274  else {
275  // explicitly computes if cumulative sum lacks precision
276  aiter->Seek(stored_begin);
277  for (auto pos = stored_begin; pos < stored_end; aiter->Next(), ++pos)
278  sum = LogPlus(sum, aiter->Value().weight);
279  }
280  */
281  }
282  // Computes sum after pre-stored weights.
283  if (stored_end < end) {
284  const auto pos_start = std::max(stored_begin, stored_end);
285  aiter->Seek(pos_start);
286  for (auto pos = pos_start; pos < end; aiter->Next(), ++pos) {
287  sum = LogPlus(sum, aiter->Value().weight);
288  }
289  }
290  return sum;
291  }
292 
293  template <class FST>
294  void Init(const FST &fst, bool copy = false) {
295  if (copy || !data_->IsMutable()) return;
296  if (data_->NumPositions() != 0 || arc_limit_ < arc_period_) {
297  FSTERROR() << "FastLogAccumulator: Initialization error";
298  error_ = true;
299  return;
300  }
301  std::vector<double> weights;
302  std::vector<int> weight_positions;
303  weight_positions.reserve(CountStates(fst));
304  for (StateIterator<FST> siter(fst); !siter.Done(); siter.Next()) {
305  const auto s = siter.Value();
306  if (fst.NumArcs(s) >= arc_limit_) {
308  if (weight_positions.size() <= s) weight_positions.resize(s + 1, -1);
309  weight_positions[s] = weights.size();
310  weights.push_back(sum);
311  size_t narcs = 0;
312  ArcIterator<FST> aiter(fst, s);
313  aiter.SetFlags(kArcWeightValue | kArcNoCache, kArcFlags);
314  for (; !aiter.Done(); aiter.Next()) {
315  const auto &arc = aiter.Value();
316  sum = LogPlus(sum, arc.weight);
317  // Stores cumulative weight distribution per arc_period_.
318  if (++narcs % arc_period_ == 0) weights.push_back(sum);
319  }
320  }
321  }
322  data_->SetData(&weights, &weight_positions);
323  }
324 
325  bool Error() const { return error_; }
326 
327  std::shared_ptr<FastLogAccumulatorData> GetData() const { return data_; }
328 
329  private:
330  static double LogPosExp(double x) {
331  return x == FloatLimits<double>::PosInfinity() ? 0.0
332  : log(1.0F + exp(-x));
333  }
334 
335  static double LogMinusExp(double x) {
336  return x == FloatLimits<double>::PosInfinity() ? 0.0
337  : log(1.0F - exp(-x));
338  }
339 
340  Weight LogPlus(Weight w, Weight v) const {
341  if (w == Weight::Zero()) {
342  return v;
343  }
344  const auto f1 = to_log_weight_(w).Value();
345  const auto f2 = to_log_weight_(v).Value();
346  if (f1 > f2) {
347  return to_weight_(Log64Weight(f2 - LogPosExp(f1 - f2)));
348  } else {
349  return to_weight_(Log64Weight(f1 - LogPosExp(f2 - f1)));
350  }
351  }
352 
353  double LogPlus(double f1, Weight v) const {
354  const auto f2 = to_log_weight_(v).Value();
355  if (f1 == FloatLimits<double>::PosInfinity()) {
356  return f2;
357  } else if (f1 > f2) {
358  return f2 - LogPosExp(f1 - f2);
359  } else {
360  return f1 - LogPosExp(f2 - f1);
361  }
362  }
363 
364  // Assumes f1 < f2.
365  Weight LogMinus(double f1, double f2) const {
366  if (f2 == FloatLimits<double>::PosInfinity()) {
367  return to_weight_(Log64Weight(f1));
368  } else {
369  return to_weight_(Log64Weight(f1 - LogMinusExp(f2 - f1)));
370  }
371  }
372 
373  const WeightConvert<Weight, Log64Weight> to_log_weight_{};
374  const WeightConvert<Log64Weight, Weight> to_weight_{};
375  const ssize_t arc_limit_; // Minimum number of arcs to pre-compute state.
376  const ssize_t arc_period_; // Saves cumulative weights per arc_period_.
377  std::shared_ptr<FastLogAccumulatorData> data_;
378  const double *state_weights_;
379  bool error_;
380 
381  FastLogAccumulator &operator=(const FastLogAccumulator &) = delete;
382 };
383 
384 // Stores shareable data for cache log accumulator copies. All copies share the
385 // same cache.
template <class Arc>
class CacheLogAccumulatorData {
 public:
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  // gc enables garbage collection once gc_limit bytes are cached; a zero limit
  // with gc enabled disables caching altogether (see CacheDisabled()).
  CacheLogAccumulatorData(bool gc, size_t gc_limit)
      : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {}

  // Copies only the GC configuration; the copy starts with an empty cache.
  CacheLogAccumulatorData(const CacheLogAccumulatorData<Arc> &data)
      : cache_gc_(data.cache_gc_),
        cache_limit_(data.cache_limit_),
        cache_size_(0) {}

  bool CacheDisabled() const { return cache_gc_ && cache_limit_ == 0; }

  // Returns the cached weights for s (marking them recently used), or nullptr.
  std::vector<double> *GetWeights(StateId s) {
    auto it = cache_.find(s);
    if (it == cache_.end()) return nullptr;
    it->second.recent = true;
    return it->second.weights.get();
  }

  // Takes ownership of weights and caches them for s, first garbage
  // collecting if the size limit has been reached. Expects s not to be cached
  // already: the ownership-taking entry would otherwise be discarded.
  void AddWeights(StateId s, std::vector<double> *weights) {
    if (cache_gc_ && cache_size_ >= cache_limit_) GC(false);
    cache_.insert(std::make_pair(s, CacheState(weights, true)));
    if (cache_gc_) cache_size_ += weights->capacity() * sizeof(double);
  }

 private:
  // Cached information for a given state.
  struct CacheState {
    std::unique_ptr<std::vector<double>> weights;  // Accumulated weights.
    bool recent;  // Has this state been accessed since last GC?

    CacheState(std::vector<double> *weights, bool recent)
        : weights(weights), recent(recent) {}
  };

  // Garbage collect: Deletes from cache states that have not been accessed
  // since the last GC ('free_recent = false') until 'cache_size_' is 2/3 of
  // 'cache_limit_'. If it does not free enough memory, start deleting
  // recently accessed states.
  void GC(bool free_recent) {
    const auto cache_target = (2 * cache_limit_) / 3 + 1;
    auto it = cache_.begin();
    while (it != cache_.end() && cache_size_ > cache_target) {
      auto &cs = it->second;
      if (free_recent || !cs.recent) {
        cache_size_ -= cs.weights->capacity() * sizeof(double);
        it = cache_.erase(it);
      } else {
        cs.recent = false;
        ++it;
      }
    }
    if (!free_recent && cache_size_ > cache_target) GC(true);
  }

  std::unordered_map<StateId, CacheState> cache_;  // Cache.
  bool cache_gc_;                                  // Enables garbage collection.
  size_t cache_limit_;                             // # of bytes allowed before GC.
  size_t cache_size_;                              // # of bytes cached.

  CacheLogAccumulatorData &operator=(const CacheLogAccumulatorData &) = delete;
};
455 
// This class accumulates arc weights using the log semiring Plus(), assuming an
// arc weight has a WeightConvert specialization to and from log64 weights. It
// is similar to FastLogAccumulator; however, here the accumulated weights are
// computed on demand and cached only for the states that are visited. The
// member function Init(fst) has to be called to set up this accumulator.
461 template <class Arc>
463  public:
464  using StateId = typename Arc::StateId;
465  using Weight = typename Arc::Weight;
466 
467  explicit CacheLogAccumulator(ssize_t arc_limit = 10, bool gc = false,
468  size_t gc_limit = 10 * 1024 * 1024)
469  : arc_limit_(arc_limit),
470  data_(std::make_shared<CacheLogAccumulatorData<Arc>>(gc, gc_limit)),
471  s_(kNoStateId),
472  error_(false) {}
473 
474  CacheLogAccumulator(const CacheLogAccumulator<Arc> &acc, bool safe = false)
475  : arc_limit_(acc.arc_limit_),
476  fst_(acc.fst_ ? acc.fst_->Copy() : nullptr),
477  data_(safe ? std::make_shared<CacheLogAccumulatorData<Arc>>(*acc.data_)
478  : acc.data_),
479  s_(kNoStateId),
480  error_(acc.error_) {}
481 
482  // Argument arc_limit specifies the minimum number of arcs to pre-compute.
483  void Init(const Fst<Arc> &fst, bool copy = false) {
484  if (!copy && fst_) {
485  FSTERROR() << "CacheLogAccumulator: Initialization error";
486  error_ = true;
487  return;
488  }
489  fst_.reset(fst.Copy());
490  }
491 
492  void SetState(StateId s, int depth = 0) {
493  if (s == s_) return;
494  s_ = s;
495  if (data_->CacheDisabled() || error_) {
496  weights_ = nullptr;
497  return;
498  }
499  if (!fst_) {
500  FSTERROR() << "CacheLogAccumulator::SetState: Incorrectly initialized";
501  error_ = true;
502  weights_ = nullptr;
503  return;
504  }
505  weights_ = data_->GetWeights(s);
506  if ((weights_ == nullptr) && (fst_->NumArcs(s) >= arc_limit_)) {
507  weights_ = new std::vector<double>;
508  weights_->reserve(fst_->NumArcs(s) + 1);
509  weights_->push_back(FloatLimits<double>::PosInfinity());
510  data_->AddWeights(s, weights_);
511  }
512  }
513 
514  Weight Sum(Weight w, Weight v) { return LogPlus(w, v); }
515 
516  template <class ArcIter>
517  Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
518  if (weights_ == nullptr) {
519  auto sum = w;
520  aiter->Seek(begin);
521  for (auto pos = begin; pos < end; aiter->Next(), ++pos) {
522  sum = LogPlus(sum, aiter->Value().weight);
523  }
524  return sum;
525  } else {
526  Extend(end, aiter);
527  const auto &f1 = (*weights_)[end];
528  const auto &f2 = (*weights_)[begin];
529  if (f1 < f2) {
530  return LogPlus(w, LogMinus(f1, f2));
531  } else {
532  // Commented out for efficiency; adds Zero().
533  /*
534  auto sum = w;
535  // Explicitly computes if cumulative sum lacks precision.
536  aiter->Seek(begin);
537  for (auto pos = begin; pos < end; aiter->Next(), ++pos) {
538  sum = LogPlus(sum, aiter->Value().weight);
539  }
540  return sum;
541  */
542  return w;
543  }
544  }
545  }
546 
547  // Returns first position from aiter->Position() whose accumulated
548  // value is greater or equal to w (w.r.t. Zero() < One()). The
549  // iterator may be repositioned.
550  template <class ArcIter>
551  size_t LowerBound(Weight w, ArcIter *aiter) {
552  const auto f = to_log_weight_(w).Value();
553  auto pos = aiter->Position();
554  if (weights_) {
555  Extend(fst_->NumArcs(s_), aiter);
556  return std::lower_bound(weights_->begin() + pos + 1, weights_->end(),
557  f, std::greater<double>()) -
558  weights_->begin() - 1;
559  } else {
560  size_t n = 0;
562  for (aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) {
563  x = LogPlus(x, aiter->Value().weight);
564  if (n >= pos && x <= f) break;
565  }
566  return n;
567  }
568  }
569 
570  bool Error() const { return error_; }
571 
572  private:
573  double LogPosExp(double x) {
574  return x == FloatLimits<double>::PosInfinity() ? 0.0
575  : log(1.0F + exp(-x));
576  }
577 
578  double LogMinusExp(double x) {
579  return x == FloatLimits<double>::PosInfinity() ? 0.0
580  : log(1.0F - exp(-x));
581  }
582 
583  Weight LogPlus(Weight w, Weight v) {
584  if (w == Weight::Zero()) {
585  return v;
586  }
587  const auto f1 = to_log_weight_(w).Value();
588  const auto f2 = to_log_weight_(v).Value();
589  if (f1 > f2) {
590  return to_weight_(Log64Weight(f2 - LogPosExp(f1 - f2)));
591  } else {
592  return to_weight_(Log64Weight(f1 - LogPosExp(f2 - f1)));
593  }
594  }
595 
596  double LogPlus(double f1, Weight v) {
597  const auto f2 = to_log_weight_(v).Value();
598  if (f1 == FloatLimits<double>::PosInfinity()) {
599  return f2;
600  } else if (f1 > f2) {
601  return f2 - LogPosExp(f1 - f2);
602  } else {
603  return f1 - LogPosExp(f2 - f1);
604  }
605  }
606 
607  // Assumes f1 < f2.
608  Weight LogMinus(double f1, double f2) {
609  if (f2 == FloatLimits<double>::PosInfinity()) {
610  return to_weight_(Log64Weight(f1));
611  } else {
612  return to_weight_(Log64Weight(f1 - LogMinusExp(f2 - f1)));
613  }
614  }
615 
616  // Extends weights up to index 'end'.
617  template <class ArcIter>
618  void Extend(ssize_t end, ArcIter *aiter) {
619  if (weights_->size() <= end) {
620  for (aiter->Seek(weights_->size() - 1); weights_->size() <= end;
621  aiter->Next()) {
622  weights_->push_back(LogPlus(weights_->back(), aiter->Value().weight));
623  }
624  }
625  }
626 
627 
628  const WeightConvert<Weight, Log64Weight> to_log_weight_{};
629  const WeightConvert<Log64Weight, Weight> to_weight_{};
630  ssize_t arc_limit_; // Minimum # of arcs to cache a state.
631  std::vector<double> *weights_; // Accumulated weights for cur. state.
632  std::unique_ptr<const Fst<Arc>> fst_; // Input FST.
633  std::shared_ptr<CacheLogAccumulatorData<Arc>> data_; // Cache data.
634  StateId s_; // Current state.
635  bool error_;
636 };
637 
638 // Stores shareable data for replace accumulator copies.
639 template <class Accumulator, class T>
641  public:
642  using Arc = typename Accumulator::Arc;
643  using Label = typename Arc::Label;
644  using StateId = typename Arc::StateId;
645  using StateTable = T;
646  using StateTuple = typename StateTable::StateTuple;
647 
648  ReplaceAccumulatorData() : state_table_(nullptr) {}
649 
651  const std::vector<Accumulator *> &accumulators)
652  : state_table_(nullptr) {
653  accumulators_.reserve(accumulators.size());
654  for (const auto accumulator : accumulators) {
655  accumulators_.emplace_back(accumulator);
656  }
657  }
658 
659  void Init(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_tuples,
660  const StateTable *state_table) {
661  state_table_ = state_table;
662  accumulators_.resize(fst_tuples.size());
663  for (Label i = 0; i < accumulators_.size(); ++i) {
664  if (!accumulators_[i]) {
665  accumulators_[i].reset(new Accumulator());
666  accumulators_[i]->Init(*(fst_tuples[i].second));
667  }
668  fst_array_.emplace_back(fst_tuples[i].second->Copy());
669  }
670  }
671 
672  const StateTuple &GetTuple(StateId s) const { return state_table_->Tuple(s); }
673 
674  Accumulator *GetAccumulator(size_t i) { return accumulators_[i].get(); }
675 
676  const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i].get(); }
677 
678  private:
679  const StateTable *state_table_;
680  std::vector<std::unique_ptr<Accumulator>> accumulators_;
681  std::vector<std::unique_ptr<const Fst<Arc>>> fst_array_;
682 };
683 
684 // This class accumulates weights in a ReplaceFst. The 'Init' method takes as
685 // input the argument used to build the ReplaceFst and the ReplaceFst state
686 // table. It uses accumulators of type 'Accumulator' in the underlying FSTs.
687 template <class Accumulator,
690  public:
691  using Arc = typename Accumulator::Arc;
692  using Label = typename Arc::Label;
693  using StateId = typename Arc::StateId;
694  using StateTable = T;
695  using StateTuple = typename StateTable::StateTuple;
696  using Weight = typename Arc::Weight;
697 
699  : init_(false),
700  data_(std::make_shared<
701  ReplaceAccumulatorData<Accumulator, StateTable>>()),
702  error_(false) {}
703 
704  explicit ReplaceAccumulator(const std::vector<Accumulator *> &accumulators)
705  : init_(false),
706  data_(std::make_shared<ReplaceAccumulatorData<Accumulator, StateTable>>(
707  accumulators)),
708  error_(false) {}
709 
711  bool safe = false)
712  : init_(acc.init_), data_(acc.data_), error_(acc.error_) {
713  if (!init_) {
714  FSTERROR() << "ReplaceAccumulator: Can't copy unintialized accumulator";
715  }
716  if (safe) FSTERROR() << "ReplaceAccumulator: Safe copy not supported";
717  }
718 
719  // Does not take ownership of the state table, the state table is owned by
720  // the ReplaceFst.
721  void Init(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_tuples,
722  const StateTable *state_table) {
723  init_ = true;
724  data_->Init(fst_tuples, state_table);
725  }
726 
727  // Method required by LookAheadMatcher. However, ReplaceAccumulator needs to
728  // be initialized by calling the Init method above before being passed to
729  // LookAheadMatcher.
730  //
731  // TODO(allauzen): Revisit this. Consider creating a method
732  // Init(const ReplaceFst<A, T, C>&, bool) and using friendship to get access
733  // to the innards of ReplaceFst.
734  void Init(const Fst<Arc> &fst, bool copy = false) {
735  if (!init_) {
736  FSTERROR() << "ReplaceAccumulator::Init: Accumulator needs to be"
737  << " initialized before being passed to LookAheadMatcher";
738  error_ = true;
739  }
740  }
741 
742  void SetState(StateId s) {
743  if (!init_) {
744  FSTERROR() << "ReplaceAccumulator::SetState: Incorrectly initialized";
745  error_ = true;
746  return;
747  }
748  auto tuple = data_->GetTuple(s);
749  fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based.
750  data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state);
751  if ((tuple.prefix_id != 0) &&
752  (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) {
753  offset_ = 1;
754  offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state);
755  } else {
756  offset_ = 0;
757  offset_weight_ = Weight::Zero();
758  }
759  aiter_.reset(
760  new ArcIterator<Fst<Arc>>(*data_->GetFst(fst_id_), tuple.fst_state));
761  }
762 
764  if (error_) return Weight::NoWeight();
765  return data_->GetAccumulator(fst_id_)->Sum(w, v);
766  }
767 
768  template <class ArcIter>
769  Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
770  if (error_) return Weight::NoWeight();
771  auto sum = begin == end ? Weight::Zero()
772  : data_->GetAccumulator(fst_id_)->Sum(
773  w, aiter_.get(), begin ? begin - offset_ : 0,
774  end - offset_);
775  if (begin == 0 && end != 0 && offset_ > 0) sum = Sum(offset_weight_, sum);
776  return sum;
777  }
778 
779  bool Error() const { return error_; }
780 
781  private:
782  bool init_;
783  std::shared_ptr<ReplaceAccumulatorData<Accumulator, StateTable>> data_;
784  Label fst_id_;
785  size_t offset_;
786  Weight offset_weight_;
787  std::unique_ptr<ArcIterator<Fst<Arc>>> aiter_;
788  bool error_;
789 };
790 
791 // SafeReplaceAccumulator accumulates weights in a ReplaceFst and copies of it
792 // are always thread-safe copies.
793 template <class Accumulator, class T>
795  public:
796  using Arc = typename Accumulator::Arc;
797  using StateId = typename Arc::StateId;
798  using Label = typename Arc::Label;
799  using Weight = typename Arc::Weight;
800  using StateTable = T;
801  using StateTuple = typename StateTable::StateTuple;
802 
804 
806  : SafeReplaceAccumulator(copy) {}
807 
809  const std::vector<Accumulator> &accumulators) {
810  for (const auto &accumulator : accumulators) {
811  accumulators_.emplace_back(accumulator, true);
812  }
813  }
814 
815  void Init(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_tuples,
816  const StateTable *state_table) {
817  state_table_ = state_table;
818  for (Label i = 0; i < fst_tuples.size(); ++i) {
819  if (i == accumulators_.size()) {
820  accumulators_.resize(accumulators_.size() + 1);
821  accumulators_[i].Init(*(fst_tuples[i].second));
822  }
823  fst_array_.emplace_back(fst_tuples[i].second->Copy(true));
824  }
825  init_ = true;
826  }
827 
828  void Init(const Fst<Arc> &fst, bool copy = false) {
829  if (!init_) {
830  FSTERROR() << "SafeReplaceAccumulator::Init: Accumulator needs to be"
831  << " initialized before being passed to LookAheadMatcher";
832  error_ = true;
833  }
834  }
835 
836  void SetState(StateId s) {
837  auto tuple = state_table_->Tuple(s);
838  fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based
839  GetAccumulator(fst_id_)->SetState(tuple.fst_state);
840  offset_ = 0;
841  offset_weight_ = Weight::Zero();
842  const auto final_weight = GetFst(fst_id_)->Final(tuple.fst_state);
843  if ((tuple.prefix_id != 0) && (final_weight != Weight::Zero())) {
844  offset_ = 1;
845  offset_weight_ = final_weight;
846  }
847  aiter_.Set(*GetFst(fst_id_), tuple.fst_state);
848  }
849 
851  if (error_) return Weight::NoWeight();
852  return GetAccumulator(fst_id_)->Sum(w, v);
853  }
854 
855  template <class ArcIter>
856  Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
857  if (error_) return Weight::NoWeight();
858  if (begin == end) return Weight::Zero();
859  auto sum = GetAccumulator(fst_id_)->Sum(
860  w, aiter_.get(), begin ? begin - offset_ : 0, end - offset_);
861  if (begin == 0 && end != 0 && offset_ > 0) {
862  sum = Sum(offset_weight_, sum);
863  }
864  return sum;
865  }
866 
867  bool Error() const { return error_; }
868 
869  private:
870  class ArcIteratorPtr {
871  public:
872  ArcIteratorPtr() {}
873 
874  ArcIteratorPtr(const ArcIteratorPtr &copy) {}
875 
876  void Set(const Fst<Arc> &fst, StateId state_id) {
877  ptr_.reset(new ArcIterator<Fst<Arc>>(fst, state_id));
878  }
879 
880  ArcIterator<Fst<Arc>> *get() { return ptr_.get(); }
881 
882  private:
883  std::unique_ptr<ArcIterator<Fst<Arc>>> ptr_;
884  };
885 
886  Accumulator *GetAccumulator(size_t i) { return &accumulators_[i]; }
887 
888  const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i].get(); }
889 
890  const StateTable *state_table_;
891  std::vector<Accumulator> accumulators_;
892  std::vector<std::shared_ptr<Fst<Arc>>> fst_array_;
893  ArcIteratorPtr aiter_;
894  bool init_ = false;
895  bool error_ = false;
896  Label fst_id_;
897  size_t offset_;
898  Weight offset_weight_;
899 };
900 
901 } // namespace fst
902 
903 #endif // FST_ACCUMULATOR_H_
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
Definition: accumulator.h:77
FastLogAccumulator(const FastLogAccumulator< Arc > &acc, bool safe=false)
Definition: accumulator.h:223
SafeReplaceAccumulator(const std::vector< Accumulator > &accumulators)
Definition: accumulator.h:808
constexpr bool Error() const
Definition: accumulator.h:51
const Fst< Arc > * GetFst(size_t i) const
Definition: accumulator.h:676
typename Arc::Weight Weight
Definition: accumulator.h:799
void SetState(StateId s)
Definition: accumulator.h:232
CacheLogAccumulatorData(const CacheLogAccumulatorData< Arc > &data)
Definition: accumulator.h:395
Weight Sum(Weight w, Weight v)
Definition: accumulator.h:40
ReplaceAccumulatorData(const std::vector< Accumulator * > &accumulators)
Definition: accumulator.h:650
std::shared_ptr< FastLogAccumulatorData > GetData() const
Definition: accumulator.h:327
void SetState(StateId s)
Definition: accumulator.h:836
typename Arc::Weight Weight
Definition: accumulator.h:696
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
Definition: accumulator.h:517
MutableFastLogAccumulatorData(int arc_limit, int arc_period)
Definition: accumulator.h:171
typename Arc::Weight Weight
Definition: accumulator.h:30
void Init(int num_weights, const double *weights, int num_positions, const int *weight_positions)
Definition: accumulator.h:147
typename Arc::Label Label
Definition: accumulator.h:643
void SetData(std::vector< double > *weights, std::vector< int > *weight_positions) override
Definition: accumulator.h:176
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
Definition: accumulator.h:43
typename Arc::Weight Weight
Definition: accumulator.h:465
FastLogAccumulatorData(int arc_limit, int arc_period)
Definition: accumulator.h:112
const int * WeightPositions() const
Definition: accumulator.h:130
Weight Sum(Weight w, Weight v)
Definition: accumulator.h:850
typename StateTable::StateTuple StateTuple
Definition: accumulator.h:695
void Init(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_tuples, const StateTable *state_table)
Definition: accumulator.h:815
constexpr int kNoStateId
Definition: fst.h:180
void Init(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_tuples, const StateTable *state_table)
Definition: accumulator.h:659
const Arc & Value() const
Definition: fst.h:503
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
Definition: accumulator.h:769
static constexpr T PosInfinity()
Definition: float-weight.h:29
ReplaceAccumulator(const ReplaceAccumulator< Accumulator, StateTable > &acc, bool safe=false)
Definition: accumulator.h:710
void Init(const Fst< Arc > &fst, bool copy=false)
Definition: accumulator.h:483
#define FSTERROR()
Definition: util.h:35
void Init(const FST &fst, bool copy=false)
Definition: accumulator.h:294
CacheLogAccumulatorData(bool gc, size_t gc_limit)
Definition: accumulator.h:392
bool IsMutable() const override
Definition: accumulator.h:174
Weight Add(const Weight &w)
Definition: weight.h:216
LogWeightTpl< double > Log64Weight
Definition: float-weight.h:449
Weight Sum(Weight w, Weight v)
Definition: accumulator.h:763
typename Arc::StateId StateId
Definition: accumulator.h:201
Weight Sum()
Definition: weight.h:221
typename Arc::StateId StateId
Definition: accumulator.h:389
SafeReplaceAccumulator(const SafeReplaceAccumulator &copy, bool safe)
Definition: accumulator.h:805
const StateTuple & GetTuple(StateId s) const
Definition: accumulator.h:672
typename Arc::StateId StateId
Definition: accumulator.h:464
typename StateTable::StateTuple StateTuple
Definition: accumulator.h:646
void Init(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_tuples, const StateTable *state_table)
Definition: accumulator.h:721
typename Arc::Weight Weight
Definition: accumulator.h:390
void SetFlags(uint32 flags, uint32 mask)
Definition: fst.h:541
FastLogAccumulator(std::shared_ptr< FastLogAccumulatorData > data)
Definition: accumulator.h:214
typename Arc::Label Label
Definition: accumulator.h:692
ExpectationWeight< X1, X2 > Plus(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
void AddWeights(StateId s, std::vector< double > *weights)
Definition: accumulator.h:412
bool Done() const
Definition: fst.h:499
void Init(const Fst< Arc > &fst, bool copy=false)
Definition: accumulator.h:734
void SetState(StateId s)
Definition: accumulator.h:72
typename Accumulator::Arc Arc
Definition: accumulator.h:796
typename Accumulator::Arc Arc
Definition: accumulator.h:642
LogAccumulator(const LogAccumulator &acc, bool safe=false)
Definition: accumulator.h:68
double LogPosExp(double x)
Definition: float-weight.h:454
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) const
Definition: accumulator.h:245
void SetState(StateId s, int depth=0)
Definition: accumulator.h:492
typename StateTable::StateTuple StateTuple
Definition: accumulator.h:801
std::vector< double > * GetWeights(StateId s)
Definition: accumulator.h:402
Weight Sum(Weight w, Weight v)
Definition: accumulator.h:514
Arc::StateId CountStates(const Fst< Arc > &fst)
Definition: expanded-fst.h:154
typename Arc::StateId StateId
Definition: accumulator.h:644
void SetState(StateId state)
Definition: accumulator.h:38
const double * Weights() const
Definition: accumulator.h:124
bool Done() const
Definition: fst.h:383
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
Definition: accumulator.h:856
Accumulator * GetAccumulator(size_t i)
Definition: accumulator.h:674
typename Arc::StateId StateId
Definition: accumulator.h:29
void Init(const Fst< Arc > &fst, bool copy=false)
Definition: accumulator.h:828
typename Accumulator::Arc Arc
Definition: accumulator.h:691
Weight Sum(Weight w, Weight v) const
Definition: accumulator.h:242
typename Arc::StateId StateId
Definition: accumulator.h:693
FastLogAccumulator(ssize_t arc_limit=20, ssize_t arc_period=10)
Definition: accumulator.h:204
ReplaceAccumulator(const std::vector< Accumulator * > &accumulators)
Definition: accumulator.h:704
size_t LowerBound(Weight w, ArcIter *aiter)
Definition: accumulator.h:551
Weight Sum(Weight w, Weight v)
Definition: accumulator.h:74
typename Arc::StateId StateId
Definition: accumulator.h:797
CacheLogAccumulator(ssize_t arc_limit=10, bool gc=false, size_t gc_limit=10 *1024 *1024)
Definition: accumulator.h:467
DefaultAccumulator(const DefaultAccumulator &acc, bool safe=false)
Definition: accumulator.h:34
virtual Fst< Arc > * Copy(bool safe=false) const =0
void SetState(StateId s)
Definition: accumulator.h:742
typename Arc::Weight Weight
Definition: accumulator.h:64
typename Arc::Weight Weight
Definition: accumulator.h:202
void Init(const Fst< Arc > &fst, bool copy=false)
Definition: accumulator.h:36
void Init(const Fst< Arc > &fst, bool copy=false)
Definition: accumulator.h:70
typename Arc::Label Label
Definition: accumulator.h:798
CacheLogAccumulator(const CacheLogAccumulator< Arc > &acc, bool safe=false)
Definition: accumulator.h:474
constexpr bool Error() const
Definition: accumulator.h:86
void Next()
Definition: fst.h:507
typename Arc::StateId StateId
Definition: accumulator.h:63