FST  openfst-1.8.2
OpenFst Library
accumulator.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Classes to accumulate arc weights. Useful for weight lookahead.
19 
20 #ifndef FST_ACCUMULATOR_H_
21 #define FST_ACCUMULATOR_H_
22 
23 #include <algorithm>
#include <cmath>
24 #include <functional>
25 #include <memory>
26 #include <vector>
27 
28 #include <fst/log.h>
29 
30 #include <fst/arcfilter.h>
31 #include <fst/arcsort.h>
32 #include <fst/dfs-visit.h>
33 #include <fst/expanded-fst.h>
34 #include <fst/replace.h>
35 
36 #include <unordered_map>
37 
38 namespace fst {
39 
40 // This class accumulates arc weights using the semiring Plus().
41 // Sum(w, aiter, begin, end) has time complexity O(end - begin).
42 template <class A>
44  public:
45  using Arc = A;
46  using StateId = typename Arc::StateId;
47  using Weight = typename Arc::Weight;
48 
50 
51  DefaultAccumulator(const DefaultAccumulator &acc, bool safe = false) {}
52 
53  void Init(const Fst<Arc> &fst, bool copy = false) {}
54 
55  void SetState(StateId state) {}
56 
57  Weight Sum(Weight w, Weight v) { return Plus(w, v); }
58 
59  template <class ArcIter>
60  Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
61  Adder<Weight> adder(w); // maintains cumulative sum accurately
62  aiter->Seek(begin);
63  for (auto pos = begin; pos < end; aiter->Next(), ++pos)
64  adder.Add(aiter->Value().weight);
65  return adder.Sum();
66  }
67 
68  constexpr bool Error() const { return false; }
69 
70  private:
71  DefaultAccumulator &operator=(const DefaultAccumulator &) = delete;
72 };
73 
74 // This class accumulates arc weights using the log semiring Plus() assuming an
75 // arc weight has a WeightConvert specialization to and from log64 weights.
76 // Sum(w, aiter, begin, end) has time complexity O(end - begin).
77 template <class A>
79  public:
80  using Arc = A;
81  using StateId = typename Arc::StateId;
82  using Weight = typename Arc::Weight;
83 
85 
86  LogAccumulator(const LogAccumulator &acc, bool safe = false) {}
87 
88  void Init(const Fst<Arc> &fst, bool copy = false) {}
89 
90  void SetState(StateId s) {}
91 
92  Weight Sum(Weight w, Weight v) { return LogPlus(w, v); }
93 
94  template <class ArcIter>
95  Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
96  auto sum = w;
97  aiter->Seek(begin);
98  for (auto pos = begin; pos < end; aiter->Next(), ++pos) {
99  sum = LogPlus(sum, aiter->Value().weight);
100  }
101  return sum;
102  }
103 
104  constexpr bool Error() const { return false; }
105 
106  private:
107  Weight LogPlus(Weight w, Weight v) {
108  if (w == Weight::Zero()) {
109  return v;
110  }
111  const auto f1 = to_log_weight_(w).Value();
112  const auto f2 = to_log_weight_(v).Value();
113  if (f1 > f2) {
114  return to_weight_(Log64Weight(f2 - internal::LogPosExp(f1 - f2)));
115  } else {
116  return to_weight_(Log64Weight(f1 - internal::LogPosExp(f2 - f1)));
117  }
118  }
119 
120  const WeightConvert<Weight, Log64Weight> to_log_weight_{};
121  const WeightConvert<Log64Weight, Weight> to_weight_{};
122 
123  LogAccumulator &operator=(const LogAccumulator &) = delete;
124 };
125 
126 // Interface for shareable data for fast log accumulator copies. Holds pointers
127 // to data only, storage is provided by derived classes.
129  public:
130  FastLogAccumulatorData(int arc_limit, int arc_period)
131  : arc_limit_(arc_limit),
132  arc_period_(arc_period),
133  weights_ptr_(nullptr),
134  num_weights_(0),
135  weight_positions_ptr_(nullptr),
136  num_positions_(0) {}
137 
139 
140  // Cumulative weight per state for all states s.t. # of arcs >= arc_limit_
141  // with arcs in order. The first element per state is Log64Weight::Zero().
142  const double *Weights() const { return weights_ptr_; }
143 
144  int NumWeights() const { return num_weights_; }
145 
146  // Maps from state to corresponding beginning weight position in weights_.
147  // Position -1 means no pre-computed weights for that state.
148  const int *WeightPositions() const { return weight_positions_ptr_; }
149 
150  int NumPositions() const { return num_positions_; }
151 
152  int ArcLimit() const { return arc_limit_; }
153 
154  int ArcPeriod() const { return arc_period_; }
155 
156  // Returns true if the data object is mutable and supports SetData().
157  virtual bool IsMutable() const = 0;
158 
159  // Does not take ownership but may invalidate the contents of weights and
160  // weight_positions.
161  virtual void SetData(std::vector<double> *weights,
162  std::vector<int> *weight_positions) = 0;
163 
164  protected:
165  void Init(int num_weights, const double *weights, int num_positions,
166  const int *weight_positions) {
167  weights_ptr_ = weights;
168  num_weights_ = num_weights;
169  weight_positions_ptr_ = weight_positions;
170  num_positions_ = num_positions;
171  }
172 
173  private:
174  const int arc_limit_;
175  const int arc_period_;
176  const double *weights_ptr_;
177  int num_weights_;
178  const int *weight_positions_ptr_;
179  int num_positions_;
180 
182  FastLogAccumulatorData &operator=(const FastLogAccumulatorData &) = delete;
183 };
184 
185 // FastLogAccumulatorData with mutable storage; filled by
186 // FastLogAccumulator::Init.
188  public:
189  MutableFastLogAccumulatorData(int arc_limit, int arc_period)
190  : FastLogAccumulatorData(arc_limit, arc_period) {}
191 
192  bool IsMutable() const override { return true; }
193 
194  void SetData(std::vector<double> *weights,
195  std::vector<int> *weight_positions) override {
196  weights_.swap(*weights);
197  weight_positions_.swap(*weight_positions);
198  Init(weights_.size(), weights_.data(), weight_positions_.size(),
199  weight_positions_.data());
200  }
201 
202  private:
203  std::vector<double> weights_;
204  std::vector<int> weight_positions_;
205 
208  const MutableFastLogAccumulatorData &) = delete;
209 };
210 
211 // This class accumulates arc weights using the log semiring Plus() assuming an
212 // arc weight has a WeightConvert specialization to and from log64 weights. The
213 // member function Init(fst) has to be called to setup pre-computed weight
214 // information.
215 // Sum(w, aiter, begin, end) has time complexity O(arc_limit_) or O(arc_period_)
216 // depending on whether the state has fewer than arc_limit_ arcs or not.
217 // Space complexity is O(CountStates(fst) + CountArcs(fst) / arc_period_).
218 template <class A>
220  public:
221  using Arc = A;
222  using StateId = typename Arc::StateId;
223  using Weight = typename Arc::Weight;
224 
225  explicit FastLogAccumulator(ssize_t arc_limit = 20, ssize_t arc_period = 10)
226  : to_log_weight_(),
227  to_weight_(),
228  arc_limit_(arc_limit),
229  arc_period_(arc_period),
230  data_(std::make_shared<MutableFastLogAccumulatorData>(arc_limit,
231  arc_period)),
232  state_weights_(nullptr),
233  error_(false) {}
234 
235  explicit FastLogAccumulator(std::shared_ptr<FastLogAccumulatorData> data)
236  : to_log_weight_(),
237  to_weight_(),
238  arc_limit_(data->ArcLimit()),
239  arc_period_(data->ArcPeriod()),
240  data_(data),
241  state_weights_(nullptr),
242  error_(false) {}
243 
244  FastLogAccumulator(const FastLogAccumulator &acc, bool safe = false)
245  : to_log_weight_(),
246  to_weight_(),
247  arc_limit_(acc.arc_limit_),
248  arc_period_(acc.arc_period_),
249  data_(acc.data_),
250  state_weights_(nullptr),
251  error_(acc.error_) {}
252 
253  void SetState(StateId s) {
254  const auto *weights = data_->Weights();
255  const auto *weight_positions = data_->WeightPositions();
256  state_weights_ = nullptr;
257  if (s < data_->NumPositions()) {
258  const auto pos = weight_positions[s];
259  if (pos >= 0) state_weights_ = &(weights[pos]);
260  }
261  }
262 
263  Weight Sum(Weight w, Weight v) const { return LogPlus(w, v); }
264 
// Log-semiring sum of w and the arcs [begin, end) of the current state. When
// SetState() found pre-computed cumulative weights (state_weights_ non-null),
// only the partial periods at either end of the range are scanned; the whole
// periods in between come from two stored cumulative sums via LogMinus().
265  template <class ArcIter>
266  Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) const {
267  if (error_) return Weight::NoWeight();
268  auto sum = w;
269  // Finds begin and end of pre-stored weights.
// index_begin = ceil(begin / arc_period_) and index_end = floor(end /
// arc_period_) select stored samples; stored_begin/stored_end are the matching
// arc positions. Without stored weights both positions stay at 'end', so the
// first loop below covers the whole range and the other two steps are skipped.
270  ssize_t index_begin = -1;
271  ssize_t index_end = -1;
272  ssize_t stored_begin = end;
273  ssize_t stored_end = end;
274  if (state_weights_) {
275  index_begin = begin > 0 ? (begin - 1) / arc_period_ + 1 : 0;
276  index_end = end / arc_period_;
277  stored_begin = index_begin * arc_period_;
278  stored_end = index_end * arc_period_;
279  }
280  // Computes sum before pre-stored weights.
281  if (begin < stored_begin) {
282  const auto pos_end = std::min(stored_begin, end);
283  aiter->Seek(begin);
284  for (auto pos = begin; pos < pos_end; aiter->Next(), ++pos) {
285  sum = LogPlus(sum, aiter->Value().weight);
286  }
287  }
288  // Computes sum between pre-stored weights.
289  if (stored_begin < stored_end) {
// The difference of the two cumulative -log sums is the log-sum of the arcs
// [stored_begin, stored_end). When f1 == f2 that contribution is Zero() (or
// was lost to rounding) and nothing is added.
290  const auto f1 = state_weights_[index_end];
291  const auto f2 = state_weights_[index_begin];
292  if (f1 < f2) sum = LogPlus(sum, LogMinus(f1, f2));
293  // Commented out for efficiency; adds Zero().
294  /*
295  else {
296  // explicitly computes if cumulative sum lacks precision
297  aiter->Seek(stored_begin);
298  for (auto pos = stored_begin; pos < stored_end; aiter->Next(), ++pos)
299  sum = LogPlus(sum, aiter->Value().weight);
300  }
301  */
302  }
303  // Computes sum after pre-stored weights.
// When [begin, end) lies inside a single period, stored_begin > end and the
// first loop already reached 'end'. Starting at the max makes this loop empty
// in that case, although aiter is still repositioned by Seek().
304  if (stored_end < end) {
305  const auto pos_start = std::max(stored_begin, stored_end);
306  aiter->Seek(pos_start);
307  for (auto pos = pos_start; pos < end; aiter->Next(), ++pos) {
308  sum = LogPlus(sum, aiter->Value().weight);
309  }
310  }
311  return sum;
312  }
313 
// Pre-computes, for every state with at least arc_limit_ arcs, the cumulative
// -log sum of its arc weights sampled after every arc_period_ arcs, and hands
// the resulting tables to data_ via SetData(). Does nothing for copies or when
// the shared data is not mutable. Sets error_ if the data was already filled or
// if arc_limit_ < arc_period_.
314  template <class FST>
315  void Init(const FST &fst, bool copy = false) {
316  if (copy || !data_->IsMutable()) return;
317  if (data_->NumPositions() != 0 || arc_limit_ < arc_period_) {
318  FSTERROR() << "FastLogAccumulator: Initialization error";
319  error_ = true;
320  return;
321  }
322  std::vector<double> weights;
323  std::vector<int> weight_positions;
324  weight_positions.reserve(CountStates(fst));
325  for (StateIterator<FST> siter(fst); !siter.Done(); siter.Next()) {
326  const auto s = siter.Value();
327  if (fst.NumArcs(s) >= arc_limit_) {
// NOTE(review): listing line 328 is missing from this copy. `sum` is used
// below but not declared in what is visible. The data class documents the first
// stored element per state as Log64Weight::Zero(), so presumably this was a
// double initialized to FloatLimits<double>::PosInfinity(). Confirm upstream.
// States never reached here keep position -1 (no pre-computed weights).
329  if (weight_positions.size() <= s) weight_positions.resize(s + 1, -1);
330  weight_positions[s] = weights.size();
331  weights.push_back(sum);
332  size_t narcs = 0;
333  ArcIterator<FST> aiter(fst, s);
// NOTE(review): listing line 334 is also missing. It may configure aiter
// (e.g. via SetFlags). Confirm against upstream.
335  for (; !aiter.Done(); aiter.Next()) {
336  const auto &arc = aiter.Value();
337  sum = LogPlus(sum, arc.weight);
338  // Stores cumulative weight distribution per arc_period_.
339  if (++narcs % arc_period_ == 0) weights.push_back(sum);
340  }
341  }
342  }
343  data_->SetData(&weights, &weight_positions);
344  }
345 
346  bool Error() const { return error_; }
347 
348  std::shared_ptr<FastLogAccumulatorData> GetData() const { return data_; }
349 
350  private:
351  static double LogPosExp(double x) {
352  return x == FloatLimits<double>::PosInfinity() ? 0.0
353  : log(1.0F + exp(-x));
354  }
355 
356  static double LogMinusExp(double x) {
357  return x == FloatLimits<double>::PosInfinity() ? 0.0
358  : log(1.0F - exp(-x));
359  }
360 
361  Weight LogPlus(Weight w, Weight v) const {
362  if (w == Weight::Zero()) {
363  return v;
364  }
365  const auto f1 = to_log_weight_(w).Value();
366  const auto f2 = to_log_weight_(v).Value();
367  if (f1 > f2) {
368  return to_weight_(Log64Weight(f2 - LogPosExp(f1 - f2)));
369  } else {
370  return to_weight_(Log64Weight(f1 - LogPosExp(f2 - f1)));
371  }
372  }
373 
374  double LogPlus(double f1, Weight v) const {
375  const auto f2 = to_log_weight_(v).Value();
376  if (f1 == FloatLimits<double>::PosInfinity()) {
377  return f2;
378  } else if (f1 > f2) {
379  return f2 - LogPosExp(f1 - f2);
380  } else {
381  return f1 - LogPosExp(f2 - f1);
382  }
383  }
384 
385  // Assumes f1 < f2.
386  Weight LogMinus(double f1, double f2) const {
387  if (f2 == FloatLimits<double>::PosInfinity()) {
388  return to_weight_(Log64Weight(f1));
389  } else {
390  return to_weight_(Log64Weight(f1 - LogMinusExp(f2 - f1)));
391  }
392  }
393 
394  const WeightConvert<Weight, Log64Weight> to_log_weight_{};
395  const WeightConvert<Log64Weight, Weight> to_weight_{};
396  const ssize_t arc_limit_; // Minimum number of arcs to pre-compute state.
397  const ssize_t arc_period_; // Saves cumulative weights per arc_period_.
398  std::shared_ptr<FastLogAccumulatorData> data_;
399  const double *state_weights_;
400  bool error_;
401 
402  FastLogAccumulator &operator=(const FastLogAccumulator &) = delete;
403 };
404 
405 // Stores shareable data for cache log accumulator copies. All copies share the
406 // same cache.
407 template <class Arc>
409  public:
410  using StateId = typename Arc::StateId;
411  using Weight = typename Arc::Weight;
412 
413  CacheLogAccumulatorData(bool gc, size_t gc_limit)
414  : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {}
415 
417  : cache_gc_(data.cache_gc_),
418  cache_limit_(data.cache_limit_),
419  cache_size_(0) {}
420 
421  bool CacheDisabled() const { return cache_gc_ && cache_limit_ == 0; }
422 
423  std::vector<double> *GetWeights(StateId s) {
424  auto it = cache_.find(s);
425  if (it != cache_.end()) {
426  it->second.recent = true;
427  return it->second.weights.get();
428  } else {
429  return nullptr;
430  }
431  }
432 
433  void AddWeights(StateId s, std::vector<double> *weights) {
434  if (cache_gc_ && cache_size_ >= cache_limit_) GC(false);
435  cache_.emplace(s, CacheState(weights, true));
436  if (cache_gc_) cache_size_ += weights->capacity() * sizeof(double);
437  }
438 
439  private:
440  // Cached information for a given state.
441  struct CacheState {
442  std::unique_ptr<std::vector<double>> weights; // Accumulated weights.
443  bool recent; // Has this state been accessed since last GC?
444 
445  CacheState(std::vector<double> *weights, bool recent)
446  : weights(weights), recent(recent) {}
447  };
448 
449  // Garbage collect: Deletes from cache states that have not been accessed
450  // since the last GC ('free_recent = false') until 'cache_size_' is 2/3 of
451  // 'cache_limit_'. If it does not free enough memory, start deleting
452  // recently accessed states.
453  void GC(bool free_recent) {
454  auto cache_target = (2 * cache_limit_) / 3 + 1;
455  auto it = cache_.begin();
456  while (it != cache_.end() && cache_size_ > cache_target) {
457  auto &cs = it->second;
458  if (free_recent || !cs.recent) {
459  cache_size_ -= cs.weights->capacity() * sizeof(double);
460  cache_.erase(it++);
461  } else {
462  cs.recent = false;
463  ++it;
464  }
465  }
466  if (!free_recent && cache_size_ > cache_target) GC(true);
467  }
468 
469  std::unordered_map<StateId, CacheState> cache_; // Cache, keyed by state ID.
470  bool cache_gc_; // Enables garbage collection.
471  size_t cache_limit_; // # of bytes allowed before GC.
472  size_t cache_size_; // # of bytes cached (only tracked when cache_gc_).
473 
474  CacheLogAccumulatorData &operator=(const CacheLogAccumulatorData &) = delete;
475 };
476 
477 // This class accumulates arc weights using the log semiring Plus() assuming an
478 // arc weight has a WeightConvert specialization to and from log64 weights. It
479 // is similar to FastLogAccumulator; however, here the accumulated weights are
480 // pre-computed and stored only for visited states. Init(fst) has to be called
481 // to set up this accumulator. Space complexity is O(gc_limit).
482 template <class Arc>
484  public:
485  using StateId = typename Arc::StateId;
486  using Weight = typename Arc::Weight;
487 
488  explicit CacheLogAccumulator(ssize_t arc_limit = 10, bool gc = false,
489  size_t gc_limit = 10 * 1024 * 1024)
490  : arc_limit_(arc_limit),
491  data_(std::make_shared<CacheLogAccumulatorData<Arc>>(gc, gc_limit)),
492  s_(kNoStateId),
493  error_(false) {}
494 
495  CacheLogAccumulator(const CacheLogAccumulator &acc, bool safe = false)
496  : arc_limit_(acc.arc_limit_),
497  fst_(acc.fst_ ? acc.fst_->Copy() : nullptr),
498  data_(safe ? std::make_shared<CacheLogAccumulatorData<Arc>>(*acc.data_)
499  : acc.data_),
500  s_(kNoStateId),
501  error_(acc.error_) {}
502 
503  // Argument arc_limit specifies the minimum number of arcs to pre-compute.
504  void Init(const Fst<Arc> &fst, bool copy = false) {
505  if (!copy && fst_) {
506  FSTERROR() << "CacheLogAccumulator: Initialization error";
507  error_ = true;
508  return;
509  }
510  fst_.reset(fst.Copy());
511  }
512 
513  void SetState(StateId s, int depth = 0) {
514  if (s == s_) return;
515  s_ = s;
516  if (data_->CacheDisabled() || error_) {
517  weights_ = nullptr;
518  return;
519  }
520  if (!fst_) {
521  FSTERROR() << "CacheLogAccumulator::SetState: Incorrectly initialized";
522  error_ = true;
523  weights_ = nullptr;
524  return;
525  }
526  weights_ = data_->GetWeights(s);
527  if ((weights_ == nullptr) && (fst_->NumArcs(s) >= arc_limit_)) {
528  weights_ = new std::vector<double>;
529  weights_->reserve(fst_->NumArcs(s) + 1);
530  weights_->push_back(FloatLimits<double>::PosInfinity());
531  data_->AddWeights(s, weights_);
532  }
533  }
534 
535  Weight Sum(Weight w, Weight v) { return LogPlus(w, v); }
536 
537  template <class ArcIter>
538  Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
539  if (weights_ == nullptr) {
540  auto sum = w;
541  aiter->Seek(begin);
542  for (auto pos = begin; pos < end; aiter->Next(), ++pos) {
543  sum = LogPlus(sum, aiter->Value().weight);
544  }
545  return sum;
546  } else {
547  Extend(end, aiter);
548  const auto &f1 = (*weights_)[end];
549  const auto &f2 = (*weights_)[begin];
550  if (f1 < f2) {
551  return LogPlus(w, LogMinus(f1, f2));
552  } else {
553  // Commented out for efficiency; adds Zero().
554  /*
555  auto sum = w;
556  // Explicitly computes if cumulative sum lacks precision.
557  aiter->Seek(begin);
558  for (auto pos = begin; pos < end; aiter->Next(), ++pos) {
559  sum = LogPlus(sum, aiter->Value().weight);
560  }
561  return sum;
562  */
563  return w;
564  }
565  }
566  }
567 
568  // Returns first position from aiter->Position() whose accumulated
569  // value is greater or equal to w (w.r.t. Zero() < One()). The
570  // iterator may be repositioned.
571  template <class ArcIter>
572  size_t LowerBound(Weight w, ArcIter *aiter) {
573  const auto f = to_log_weight_(w).Value();
574  auto pos = aiter->Position();
575  if (weights_) {
// Cached path: make every cumulative weight of s_ available, then binary
// search. The -log values do not increase with position, hence
// std::greater. weights_[k] covers the first k arcs, hence the +1 / -1.
576  Extend(fst_->NumArcs(s_), aiter);
577  return std::lower_bound(weights_->begin() + pos + 1, weights_->end(), f,
578  std::greater<double>()) -
579  weights_->begin() - 1;
580  } else {
// Uncached path: linear scan from the first arc, accumulating as it goes.
581  size_t n = 0;
// NOTE(review): listing line 582 is missing from this copy. `x` is used
// below but not declared in what is visible. Presumably a double initialized to
// FloatLimits<double>::PosInfinity() (Zero()). Confirm upstream.
583  for (aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) {
584  x = LogPlus(x, aiter->Value().weight);
585  if (n >= pos && x <= f) break;
586  }
587  return n;
588  }
589  }
590 
591  bool Error() const { return error_; }
592 
593  private:
594  double LogPosExp(double x) {
595  return x == FloatLimits<double>::PosInfinity() ? 0.0
596  : log(1.0F + exp(-x));
597  }
598 
599  double LogMinusExp(double x) {
600  return x == FloatLimits<double>::PosInfinity() ? 0.0
601  : log(1.0F - exp(-x));
602  }
603 
604  Weight LogPlus(Weight w, Weight v) {
605  if (w == Weight::Zero()) {
606  return v;
607  }
608  const auto f1 = to_log_weight_(w).Value();
609  const auto f2 = to_log_weight_(v).Value();
610  if (f1 > f2) {
611  return to_weight_(Log64Weight(f2 - LogPosExp(f1 - f2)));
612  } else {
613  return to_weight_(Log64Weight(f1 - LogPosExp(f2 - f1)));
614  }
615  }
616 
617  double LogPlus(double f1, Weight v) {
618  const auto f2 = to_log_weight_(v).Value();
619  if (f1 == FloatLimits<double>::PosInfinity()) {
620  return f2;
621  } else if (f1 > f2) {
622  return f2 - LogPosExp(f1 - f2);
623  } else {
624  return f1 - LogPosExp(f2 - f1);
625  }
626  }
627 
628  // Assumes f1 < f2.
629  Weight LogMinus(double f1, double f2) {
630  if (f2 == FloatLimits<double>::PosInfinity()) {
631  return to_weight_(Log64Weight(f1));
632  } else {
633  return to_weight_(Log64Weight(f1 - LogMinusExp(f2 - f1)));
634  }
635  }
636 
637  // Extends weights up to index 'end'.
638  template <class ArcIter>
639  void Extend(ssize_t end, ArcIter *aiter) {
640  if (weights_->size() <= end) {
641  for (aiter->Seek(weights_->size() - 1); weights_->size() <= end;
642  aiter->Next()) {
643  weights_->push_back(LogPlus(weights_->back(), aiter->Value().weight));
644  }
645  }
646  }
647 
648  const WeightConvert<Weight, Log64Weight> to_log_weight_{};
649  const WeightConvert<Log64Weight, Weight> to_weight_{};
650  ssize_t arc_limit_; // Minimum # of arcs to cache a state.
651  std::vector<double> *weights_; // Accumulated weights for cur. state.
652  std::unique_ptr<const Fst<Arc>> fst_; // Input FST.
653  std::shared_ptr<CacheLogAccumulatorData<Arc>> data_; // Cache data.
654  StateId s_; // Current state.
655  bool error_;
656 };
657 
658 // Stores shareable data for replace accumulator copies.
659 template <class Accumulator, class T>
661  public:
662  using Arc = typename Accumulator::Arc;
663  using Label = typename Arc::Label;
664  using StateId = typename Arc::StateId;
665  using StateTable = T;
666  using StateTuple = typename StateTable::StateTuple;
667 
668  ReplaceAccumulatorData() : state_table_(nullptr) {}
669 
671  std::vector<std::unique_ptr<Accumulator>> &&accumulators)
672  : state_table_(nullptr), accumulators_(std::move(accumulators)) {}
673 
674  void Init(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_tuples,
675  const StateTable *state_table) {
676  state_table_ = state_table;
677  accumulators_.resize(fst_tuples.size());
678  for (Label i = 0; i < accumulators_.size(); ++i) {
679  if (!accumulators_[i]) {
680  accumulators_[i] = std::make_unique<Accumulator>();
681  accumulators_[i]->Init(*(fst_tuples[i].second));
682  }
683  fst_array_.emplace_back(fst_tuples[i].second->Copy());
684  }
685  }
686 
687  const StateTuple &GetTuple(StateId s) const { return state_table_->Tuple(s); }
688 
689  Accumulator *GetAccumulator(size_t i) { return accumulators_[i].get(); }
690 
691  const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i].get(); }
692 
693  private:
694  const StateTable *state_table_;
695  std::vector<std::unique_ptr<Accumulator>> accumulators_;
696  std::vector<std::unique_ptr<const Fst<Arc>>> fst_array_;
697 };
698 
699 // This class accumulates weights in a ReplaceFst. The 'Init' method takes as
700 // input the argument used to build the ReplaceFst and the ReplaceFst state
701 // table. It uses accumulators of type 'Accumulator' in the underlying FSTs.
702 template <class Accumulator,
705  public:
706  using Arc = typename Accumulator::Arc;
707  using Label = typename Arc::Label;
708  using StateId = typename Arc::StateId;
709  using StateTable = T;
710  using StateTuple = typename StateTable::StateTuple;
711  using Weight = typename Arc::Weight;
712 
714  : init_(false),
715  data_(std::make_shared<
716  ReplaceAccumulatorData<Accumulator, StateTable>>()),
717  error_(false) {}
718 
720  std::vector<std::unique_ptr<Accumulator>> &&accumulators)
721  : init_(false),
722  data_(std::make_shared<ReplaceAccumulatorData<Accumulator, StateTable>>(
723  std::move(accumulators))),
724  error_(false) {}
725 
727  bool safe = false)
728  : init_(acc.init_), data_(acc.data_), error_(acc.error_) {
729  if (!init_) {
730  FSTERROR() << "ReplaceAccumulator: Can't copy unintialized accumulator";
731  }
732  if (safe) FSTERROR() << "ReplaceAccumulator: Safe copy not supported";
733  }
734 
735  // Does not take ownership of the state table, the state table is owned by
736  // the ReplaceFst.
737  void Init(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_tuples,
738  const StateTable *state_table) {
739  init_ = true;
740  data_->Init(fst_tuples, state_table);
741  }
742 
743  // Method required by LookAheadMatcher. However, ReplaceAccumulator needs to
744  // be initialized by calling the Init method above before being passed to
745  // LookAheadMatcher.
746  //
747  // TODO(allauzen): Revisit this. Consider creating a method
748  // Init(const ReplaceFst<A, T, C>&, bool) and using friendship to get access
749  // to the innards of ReplaceFst.
750  void Init(const Fst<Arc> &fst, bool copy = false) {
751  if (!init_) {
752  FSTERROR() << "ReplaceAccumulator::Init: Accumulator needs to be"
753  << " initialized before being passed to LookAheadMatcher";
754  error_ = true;
755  }
756  }
757 
758  void SetState(StateId s) {
759  if (!init_) {
760  FSTERROR() << "ReplaceAccumulator::SetState: Incorrectly initialized";
761  error_ = true;
762  return;
763  }
764  auto tuple = data_->GetTuple(s);
765  fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based.
766  data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state);
767  if ((tuple.prefix_id != 0) &&
768  (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) {
769  offset_ = 1;
770  offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state);
771  } else {
772  offset_ = 0;
773  offset_weight_ = Weight::Zero();
774  }
775  aiter_ = std::make_unique<ArcIterator<Fst<Arc>>>(*data_->GetFst(fst_id_),
776  tuple.fst_state);
777  }
778 
780  if (error_) return Weight::NoWeight();
781  return data_->GetAccumulator(fst_id_)->Sum(w, v);
782  }
783 
784  template <class ArcIter>
785  Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
786  if (error_) return Weight::NoWeight();
787  auto sum = begin == end ? Weight::Zero()
788  : data_->GetAccumulator(fst_id_)->Sum(
789  w, aiter_.get(), begin ? begin - offset_ : 0,
790  end - offset_);
791  if (begin == 0 && end != 0 && offset_ > 0) sum = Sum(offset_weight_, sum);
792  return sum;
793  }
794 
795  bool Error() const { return error_; }
796 
797  private:
798  bool init_;
799  std::shared_ptr<ReplaceAccumulatorData<Accumulator, StateTable>> data_;
800  Label fst_id_;
801  size_t offset_;
802  Weight offset_weight_;
803  std::unique_ptr<ArcIterator<Fst<Arc>>> aiter_;
804  bool error_;
805 };
806 
807 // SafeReplaceAccumulator accumulates weights in a ReplaceFst and copies of it
808 // are always thread-safe copies.
809 template <class Accumulator, class T>
811  public:
812  using Arc = typename Accumulator::Arc;
813  using StateId = typename Arc::StateId;
814  using Label = typename Arc::Label;
815  using Weight = typename Arc::Weight;
816  using StateTable = T;
817  using StateTuple = typename StateTable::StateTuple;
818 
820 
822  : SafeReplaceAccumulator(copy) {}
823 
825  const std::vector<Accumulator> &accumulators) {
826  for (const auto &accumulator : accumulators) {
827  accumulators_.emplace_back(accumulator, true);
828  }
829  }
830 
831  void Init(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_tuples,
832  const StateTable *state_table) {
833  state_table_ = state_table;
834  for (Label i = 0; i < fst_tuples.size(); ++i) {
835  if (i == accumulators_.size()) {
836  accumulators_.resize(accumulators_.size() + 1);
837  accumulators_[i].Init(*(fst_tuples[i].second));
838  }
839  fst_array_.emplace_back(fst_tuples[i].second->Copy(true));
840  }
841  init_ = true;
842  }
843 
844  void Init(const Fst<Arc> &fst, bool copy = false) {
845  if (!init_) {
846  FSTERROR() << "SafeReplaceAccumulator::Init: Accumulator needs to be"
847  << " initialized before being passed to LookAheadMatcher";
848  error_ = true;
849  }
850  }
851 
852  void SetState(StateId s) {
853  auto tuple = state_table_->Tuple(s);
854  fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based
855  GetAccumulator(fst_id_)->SetState(tuple.fst_state);
856  offset_ = 0;
857  offset_weight_ = Weight::Zero();
858  const auto final_weight = GetFst(fst_id_)->Final(tuple.fst_state);
859  if ((tuple.prefix_id != 0) && (final_weight != Weight::Zero())) {
860  offset_ = 1;
861  offset_weight_ = final_weight;
862  }
863  aiter_.Set(*GetFst(fst_id_), tuple.fst_state);
864  }
865 
867  if (error_) return Weight::NoWeight();
868  return GetAccumulator(fst_id_)->Sum(w, v);
869  }
870 
871  template <class ArcIter>
872  Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
873  if (error_) return Weight::NoWeight();
874  if (begin == end) return Weight::Zero();
875  auto sum = GetAccumulator(fst_id_)->Sum(
876  w, aiter_.get(), begin ? begin - offset_ : 0, end - offset_);
877  if (begin == 0 && end != 0 && offset_ > 0) {
878  sum = Sum(offset_weight_, sum);
879  }
880  return sum;
881  }
882 
883  bool Error() const { return error_; }
884 
885  private:
886  class ArcIteratorPtr {
887  public:
888  ArcIteratorPtr() {}
889 
890  ArcIteratorPtr(const ArcIteratorPtr &copy) {}
891 
892  void Set(const Fst<Arc> &fst, StateId state_id) {
893  ptr_ = std::make_unique<ArcIterator<Fst<Arc>>>(fst, state_id);
894  }
895 
896  ArcIterator<Fst<Arc>> *get() { return ptr_.get(); }
897 
898  private:
899  std::unique_ptr<ArcIterator<Fst<Arc>>> ptr_;
900  };
901 
902  Accumulator *GetAccumulator(size_t i) { return &accumulators_[i]; }
903 
904  const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i].get(); }
905 
906  const StateTable *state_table_;
907  std::vector<Accumulator> accumulators_;
908  std::vector<std::shared_ptr<Fst<Arc>>> fst_array_;
909  ArcIteratorPtr aiter_;
910  bool init_ = false;
911  bool error_ = false;
912  Label fst_id_;
913  size_t offset_;
914  Weight offset_weight_;
915 };
916 
917 } // namespace fst
918 
919 #endif // FST_ACCUMULATOR_H_
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
Definition: accumulator.h:95
CacheLogAccumulator(const CacheLogAccumulator &acc, bool safe=false)
Definition: accumulator.h:495
SafeReplaceAccumulator(const std::vector< Accumulator > &accumulators)
Definition: accumulator.h:824
constexpr bool Error() const
Definition: accumulator.h:68
const Fst< Arc > * GetFst(size_t i) const
Definition: accumulator.h:691
typename Arc::Weight Weight
Definition: accumulator.h:815
constexpr uint8_t kArcNoCache
Definition: fst.h:452
void SetState(StateId s)
Definition: accumulator.h:253
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:60
CacheLogAccumulatorData(const CacheLogAccumulatorData< Arc > &data)
Definition: accumulator.h:416
Weight Sum(Weight w, Weight v)
Definition: accumulator.h:57
std::shared_ptr< FastLogAccumulatorData > GetData() const
Definition: accumulator.h:348
void SetState(StateId s)
Definition: accumulator.h:852
typename Arc::Weight Weight
Definition: accumulator.h:711
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
Definition: accumulator.h:538
constexpr uint8_t kArcFlags
Definition: fst.h:455
MutableFastLogAccumulatorData(int arc_limit, int arc_period)
Definition: accumulator.h:189
typename Arc::Weight Weight
Definition: accumulator.h:47
Weight Sum() const
Definition: weight.h:223
void Init(int num_weights, const double *weights, int num_positions, const int *weight_positions)
Definition: accumulator.h:165
typename Arc::Label Label
Definition: accumulator.h:663
void SetData(std::vector< double > *weights, std::vector< int > *weight_positions) override
Definition: accumulator.h:194
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
Definition: accumulator.h:60
typename Arc::Weight Weight
Definition: accumulator.h:486
FastLogAccumulatorData(int arc_limit, int arc_period)
Definition: accumulator.h:130
const int * WeightPositions() const
Definition: accumulator.h:148
Weight Sum(Weight w, Weight v)
Definition: accumulator.h:866
typename StateTable::StateTuple StateTuple
Definition: accumulator.h:710
void Init(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_tuples, const StateTable *state_table)
Definition: accumulator.h:831
ReplaceAccumulatorData(std::vector< std::unique_ptr< Accumulator >> &&accumulators)
Definition: accumulator.h:670
constexpr int kNoStateId
Definition: fst.h:202
void Init(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_tuples, const StateTable *state_table)
Definition: accumulator.h:674
const Arc & Value() const
Definition: fst.h:537
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
Definition: accumulator.h:785
static constexpr T PosInfinity()
Definition: float-weight.h:56
ReplaceAccumulator(const ReplaceAccumulator< Accumulator, StateTable > &acc, bool safe=false)
Definition: accumulator.h:726
void Init(const Fst< Arc > &fst, bool copy=false)
Definition: accumulator.h:504
#define FSTERROR()
Definition: util.h:53
void Init(const FST &fst, bool copy=false)
Definition: accumulator.h:315
CacheLogAccumulatorData(bool gc, size_t gc_limit)
Definition: accumulator.h:413
void SetFlags(uint8_t flags, uint8_t mask)
Definition: fst.h:571
bool IsMutable() const override
Definition: accumulator.h:192
Weight Add(const Weight &w)
Definition: weight.h:218
LogWeightTpl< double > Log64Weight
Definition: float-weight.h:468
Weight Sum(Weight w, Weight v)
Definition: accumulator.h:779
typename Arc::StateId StateId
Definition: accumulator.h:222
typename Arc::StateId StateId
Definition: accumulator.h:410
SafeReplaceAccumulator(const SafeReplaceAccumulator &copy, bool safe)
Definition: accumulator.h:821
const StateTuple & GetTuple(StateId s) const
Definition: accumulator.h:687
typename Arc::StateId StateId
Definition: accumulator.h:485
typename StateTable::StateTuple StateTuple
Definition: accumulator.h:666
virtual Fst * Copy(bool safe=false) const =0
void Init(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_tuples, const StateTable *state_table)
Definition: accumulator.h:737
typename Arc::Weight Weight
Definition: accumulator.h:411
FastLogAccumulator(std::shared_ptr< FastLogAccumulatorData > data)
Definition: accumulator.h:235
typename Arc::Label Label
Definition: accumulator.h:707
void AddWeights(StateId s, std::vector< double > *weights)
Definition: accumulator.h:433
bool Done() const
Definition: fst.h:533
void Init(const Fst< Arc > &fst, bool copy=false)
Definition: accumulator.h:750
void SetState(StateId s)
Definition: accumulator.h:90
typename Accumulator::Arc Arc
Definition: accumulator.h:812
typename Accumulator::Arc Arc
Definition: accumulator.h:662
LogAccumulator(const LogAccumulator &acc, bool safe=false)
Definition: accumulator.h:86
double LogPosExp(double x)
Definition: float-weight.h:473
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) const
Definition: accumulator.h:266
void SetState(StateId s, int depth=0)
Definition: accumulator.h:513
constexpr uint8_t kArcWeightValue
Definition: fst.h:448
FastLogAccumulator(const FastLogAccumulator &acc, bool safe=false)
Definition: accumulator.h:244
typename StateTable::StateTuple StateTuple
Definition: accumulator.h:817
std::vector< double > * GetWeights(StateId s)
Definition: accumulator.h:423
Weight Sum(Weight w, Weight v)
Definition: accumulator.h:535
Arc::StateId CountStates(const Fst< Arc > &fst)
Definition: expanded-fst.h:169
typename Arc::StateId StateId
Definition: accumulator.h:664
void SetState(StateId state)
Definition: accumulator.h:55
const double * Weights() const
Definition: accumulator.h:142
bool Done() const
Definition: fst.h:415
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
Definition: accumulator.h:872
Accumulator * GetAccumulator(size_t i)
Definition: accumulator.h:689
typename Arc::StateId StateId
Definition: accumulator.h:46
void Init(const Fst< Arc > &fst, bool copy=false)
Definition: accumulator.h:844
typename Accumulator::Arc Arc
Definition: accumulator.h:706
ReplaceAccumulator(std::vector< std::unique_ptr< Accumulator >> &&accumulators)
Definition: accumulator.h:719
Weight Sum(Weight w, Weight v) const
Definition: accumulator.h:263
typename Arc::StateId StateId
Definition: accumulator.h:708
FastLogAccumulator(ssize_t arc_limit=20, ssize_t arc_period=10)
Definition: accumulator.h:225
size_t LowerBound(Weight w, ArcIter *aiter)
Definition: accumulator.h:572
Weight Sum(Weight w, Weight v)
Definition: accumulator.h:92
typename Arc::StateId StateId
Definition: accumulator.h:813
CacheLogAccumulator(ssize_t arc_limit=10, bool gc=false, size_t gc_limit=10 *1024 *1024)
Definition: accumulator.h:488
DefaultAccumulator(const DefaultAccumulator &acc, bool safe=false)
Definition: accumulator.h:51
void SetState(StateId s)
Definition: accumulator.h:758
typename Arc::Weight Weight
Definition: accumulator.h:82
typename Arc::Weight Weight
Definition: accumulator.h:223
void Init(const Fst< Arc > &fst, bool copy=false)
Definition: accumulator.h:53
void Init(const Fst< Arc > &fst, bool copy=false)
Definition: accumulator.h:88
typename Arc::Label Label
Definition: accumulator.h:814
constexpr bool Error() const
Definition: accumulator.h:104
void Next()
Definition: fst.h:541
typename Arc::StateId StateId
Definition: accumulator.h:81