FST  openfst-1.8.3
OpenFst Library
accumulator.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Classes to accumulate arc weights. Useful for weight lookahead.
19 
20 #ifndef FST_ACCUMULATOR_H_
21 #define FST_ACCUMULATOR_H_
22 
23 #include <sys/types.h>
24 
25 #include <algorithm>
26 #include <cstddef>
27 #include <functional>
28 #include <memory>
29 #include <utility>
30 #include <vector>
31 
32 #include <fst/log.h>
33 #include <fst/arcfilter.h>
34 #include <fst/arcsort.h>
35 #include <fst/dfs-visit.h>
36 #include <fst/expanded-fst.h>
37 #include <fst/float-weight.h>
38 #include <fst/fst.h>
39 #include <fst/replace.h>
40 #include <fst/util.h>
41 #include <fst/weight.h>
42 #include <unordered_map>
43 
44 namespace fst {
45 
46 // This class accumulates arc weights using the semiring Plus().
47 // Sum(w, aiter, begin, end) has time complexity O(begin - end).
48 template <class A>
50  public:
51  using Arc = A;
52  using StateId = typename Arc::StateId;
53  using Weight = typename Arc::Weight;
54 
55  DefaultAccumulator() = default;
56 
57  DefaultAccumulator(const DefaultAccumulator &acc, bool safe = false) {}
58 
59  void Init(const Fst<Arc> &fst, bool copy = false) {}
60 
61  void SetState(StateId state) {}
62 
63  Weight Sum(Weight w, Weight v) { return Plus(w, v); }
64 
65  template <class ArcIter>
66  Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
67  Adder<Weight> adder(w); // maintains cumulative sum accurately
68  aiter->Seek(begin);
69  for (auto pos = begin; pos < end; aiter->Next(), ++pos)
70  adder.Add(aiter->Value().weight);
71  return adder.Sum();
72  }
73 
74  constexpr bool Error() const { return false; }
75 
76  private:
77  DefaultAccumulator &operator=(const DefaultAccumulator &) = delete;
78 };
79 
80 // This class accumulates arc weights using the log semiring Plus() assuming an
81 // arc weight has a WeightConvert specialization to and from log64 weights.
82 // Sum(w, aiter, begin, end) has time complexity O(begin - end).
83 template <class A>
85  public:
86  using Arc = A;
87  using StateId = typename Arc::StateId;
88  using Weight = typename Arc::Weight;
89 
90  LogAccumulator() = default;
91 
92  LogAccumulator(const LogAccumulator &acc, bool safe = false) {}
93 
94  void Init(const Fst<Arc> &fst, bool copy = false) {}
95 
96  void SetState(StateId s) {}
97 
98  Weight Sum(Weight w, Weight v) { return LogPlus(w, v); }
99 
100  template <class ArcIter>
101  Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
102  auto sum = w;
103  aiter->Seek(begin);
104  for (auto pos = begin; pos < end; aiter->Next(), ++pos) {
105  sum = LogPlus(sum, aiter->Value().weight);
106  }
107  return sum;
108  }
109 
110  constexpr bool Error() const { return false; }
111 
112  private:
113  Weight LogPlus(Weight w, Weight v) {
114  if (w == Weight::Zero()) {
115  return v;
116  }
117  const auto f1 = to_log_weight_(w).Value();
118  const auto f2 = to_log_weight_(v).Value();
119  if (f1 > f2) {
120  return to_weight_(Log64Weight(f2 - internal::LogPosExp(f1 - f2)));
121  } else {
122  return to_weight_(Log64Weight(f1 - internal::LogPosExp(f2 - f1)));
123  }
124  }
125 
126  const WeightConvert<Weight, Log64Weight> to_log_weight_{};
127  const WeightConvert<Log64Weight, Weight> to_weight_{};
128 
129  LogAccumulator &operator=(const LogAccumulator &) = delete;
130 };
131 
// Interface for shareable data for fast log accumulator copies. Holds pointers
// to data only; storage is provided by derived classes.
class FastLogAccumulatorData {
 public:
  FastLogAccumulatorData(int arc_limit, int arc_period)
      : arc_limit_(arc_limit),
        arc_period_(arc_period),
        weights_ptr_(nullptr),
        num_weights_(0),
        weight_positions_ptr_(nullptr),
        num_positions_(0) {}

  virtual ~FastLogAccumulatorData() = default;

  // Cumulative weights (as doubles) for every state with at least arc_limit_
  // arcs, taken with the arcs in order: one value every arc_period_ arcs. The
  // first element per state is Log64Weight::Zero().
  const double *Weights() const { return weights_ptr_; }

  int NumWeights() const { return num_weights_; }

  // Maps from state to the corresponding beginning position in Weights().
  // Position -1 means no pre-computed weights for that state.
  const int *WeightPositions() const { return weight_positions_ptr_; }

  int NumPositions() const { return num_positions_; }

  int ArcLimit() const { return arc_limit_; }

  int ArcPeriod() const { return arc_period_; }

  // Returns true if the data object is mutable and supports SetData().
  virtual bool IsMutable() const = 0;

  // Does not take ownership but may invalidate the contents of weights and
  // weight_positions.
  virtual void SetData(std::vector<double> *weights,
                       std::vector<int> *weight_positions) = 0;

 protected:
  // Points this object at caller-provided arrays; nothing is copied, so the
  // arrays must outlive their use through the accessors above.
  void Init(int num_weights, const double *weights, int num_positions,
            const int *weight_positions) {
    weights_ptr_ = weights;
    num_weights_ = num_weights;
    weight_positions_ptr_ = weight_positions;
    num_positions_ = num_positions;
  }

 private:
  const int arc_limit_;
  const int arc_period_;
  const double *weights_ptr_;
  int num_weights_;
  const int *weight_positions_ptr_;
  int num_positions_;

  FastLogAccumulatorData(const FastLogAccumulatorData &) = delete;
  FastLogAccumulatorData &operator=(const FastLogAccumulatorData &) = delete;
};
190 
191 // FastLogAccumulatorData with mutable storage; filled by
192 // FastLogAccumulator::Init.
194  public:
195  MutableFastLogAccumulatorData(int arc_limit, int arc_period)
196  : FastLogAccumulatorData(arc_limit, arc_period) {}
197 
198  bool IsMutable() const override { return true; }
199 
200  void SetData(std::vector<double> *weights,
201  std::vector<int> *weight_positions) override {
202  weights_.swap(*weights);
203  weight_positions_.swap(*weight_positions);
204  Init(weights_.size(), weights_.data(), weight_positions_.size(),
205  weight_positions_.data());
206  }
207 
208  private:
209  std::vector<double> weights_;
210  std::vector<int> weight_positions_;
211 
214  const MutableFastLogAccumulatorData &) = delete;
215 };
216 
217 // This class accumulates arc weights using the log semiring Plus() assuming an
218 // arc weight has a WeightConvert specialization to and from log64 weights. The
219 // member function Init(fst) has to be called to setup pre-computed weight
220 // information.
221 // Sum(w, aiter, begin, end) has time complexity O(arc_limit_) or O(arc_period_)
222 // depending on whether the state has more than arc_limit_ arcs
223 // Space complexity is O(CountStates(fst) + CountArcs(fst) / arc_period_).
224 template <class A>
226  public:
227  using Arc = A;
228  using StateId = typename Arc::StateId;
229  using Weight = typename Arc::Weight;
230 
231  explicit FastLogAccumulator(ssize_t arc_limit = 20, ssize_t arc_period = 10)
232  : to_log_weight_(),
233  to_weight_(),
234  arc_limit_(arc_limit),
235  arc_period_(arc_period),
236  data_(std::make_shared<MutableFastLogAccumulatorData>(arc_limit,
237  arc_period)),
238  state_weights_(nullptr),
239  error_(false) {}
240 
241  explicit FastLogAccumulator(std::shared_ptr<FastLogAccumulatorData> data)
242  : to_log_weight_(),
243  to_weight_(),
244  arc_limit_(data->ArcLimit()),
245  arc_period_(data->ArcPeriod()),
246  data_(data),
247  state_weights_(nullptr),
248  error_(false) {}
249 
250  FastLogAccumulator(const FastLogAccumulator &acc, bool safe = false)
251  : to_log_weight_(),
252  to_weight_(),
253  arc_limit_(acc.arc_limit_),
254  arc_period_(acc.arc_period_),
255  data_(acc.data_),
256  state_weights_(nullptr),
257  error_(acc.error_) {}
258 
259  void SetState(StateId s) {
260  const auto *weights = data_->Weights();
261  const auto *weight_positions = data_->WeightPositions();
262  state_weights_ = nullptr;
263  if (s < data_->NumPositions()) {
264  const auto pos = weight_positions[s];
265  if (pos >= 0) state_weights_ = &(weights[pos]);
266  }
267  }
268 
269  Weight Sum(Weight w, Weight v) const { return LogPlus(w, v); }
270 
271  template <class ArcIter>
272  Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) const {
273  if (error_) return Weight::NoWeight();
274  auto sum = w;
275  // Finds begin and end of pre-stored weights.
276  ssize_t index_begin = -1;
277  ssize_t index_end = -1;
278  ssize_t stored_begin = end;
279  ssize_t stored_end = end;
280  if (state_weights_) {
281  index_begin = begin > 0 ? (begin - 1) / arc_period_ + 1 : 0;
282  index_end = end / arc_period_;
283  stored_begin = index_begin * arc_period_;
284  stored_end = index_end * arc_period_;
285  }
286  // Computes sum before pre-stored weights.
287  if (begin < stored_begin) {
288  const auto pos_end = std::min(stored_begin, end);
289  aiter->Seek(begin);
290  for (auto pos = begin; pos < pos_end; aiter->Next(), ++pos) {
291  sum = LogPlus(sum, aiter->Value().weight);
292  }
293  }
294  // Computes sum between pre-stored weights.
295  if (stored_begin < stored_end) {
296  const auto f1 = state_weights_[index_end];
297  const auto f2 = state_weights_[index_begin];
298  if (f1 < f2) sum = LogPlus(sum, LogMinus(f1, f2));
299  // Commented out for efficiency; adds Zero().
300  /*
301  else {
302  // explicitly computes if cumulative sum lacks precision
303  aiter->Seek(stored_begin);
304  for (auto pos = stored_begin; pos < stored_end; aiter->Next(), ++pos)
305  sum = LogPlus(sum, aiter->Value().weight);
306  }
307  */
308  }
309  // Computes sum after pre-stored weights.
310  if (stored_end < end) {
311  const auto pos_start = std::max(stored_begin, stored_end);
312  aiter->Seek(pos_start);
313  for (auto pos = pos_start; pos < end; aiter->Next(), ++pos) {
314  sum = LogPlus(sum, aiter->Value().weight);
315  }
316  }
317  return sum;
318  }
319 
320  template <class FST>
321  void Init(const FST &fst, bool copy = false) {
322  if (copy || !data_->IsMutable()) return;
323  if (data_->NumPositions() != 0 || arc_limit_ < arc_period_) {
324  FSTERROR() << "FastLogAccumulator: Initialization error";
325  error_ = true;
326  return;
327  }
328  std::vector<double> weights;
329  std::vector<int> weight_positions;
330  weight_positions.reserve(CountStates(fst));
331  for (StateIterator<FST> siter(fst); !siter.Done(); siter.Next()) {
332  const auto s = siter.Value();
333  if (fst.NumArcs(s) >= arc_limit_) {
335  if (weight_positions.size() <= s) weight_positions.resize(s + 1, -1);
336  weight_positions[s] = weights.size();
337  weights.push_back(sum);
338  size_t narcs = 0;
339  ArcIterator<FST> aiter(fst, s);
341  for (; !aiter.Done(); aiter.Next()) {
342  const auto &arc = aiter.Value();
343  sum = LogPlus(sum, arc.weight);
344  // Stores cumulative weight distribution per arc_period_.
345  if (++narcs % arc_period_ == 0) weights.push_back(sum);
346  }
347  }
348  }
349  data_->SetData(&weights, &weight_positions);
350  }
351 
352  bool Error() const { return error_; }
353 
354  std::shared_ptr<FastLogAccumulatorData> GetData() const { return data_; }
355 
356  private:
357  static double LogPosExp(double x) {
358  return x == FloatLimits<double>::PosInfinity() ? 0.0
359  : log(1.0F + exp(-x));
360  }
361 
362  static double LogMinusExp(double x) {
363  return x == FloatLimits<double>::PosInfinity() ? 0.0
364  : log(1.0F - exp(-x));
365  }
366 
367  Weight LogPlus(Weight w, Weight v) const {
368  if (w == Weight::Zero()) {
369  return v;
370  }
371  const auto f1 = to_log_weight_(w).Value();
372  const auto f2 = to_log_weight_(v).Value();
373  if (f1 > f2) {
374  return to_weight_(Log64Weight(f2 - LogPosExp(f1 - f2)));
375  } else {
376  return to_weight_(Log64Weight(f1 - LogPosExp(f2 - f1)));
377  }
378  }
379 
380  double LogPlus(double f1, Weight v) const {
381  const auto f2 = to_log_weight_(v).Value();
382  if (f1 == FloatLimits<double>::PosInfinity()) {
383  return f2;
384  } else if (f1 > f2) {
385  return f2 - LogPosExp(f1 - f2);
386  } else {
387  return f1 - LogPosExp(f2 - f1);
388  }
389  }
390 
391  // Assumes f1 < f2.
392  Weight LogMinus(double f1, double f2) const {
393  if (f2 == FloatLimits<double>::PosInfinity()) {
394  return to_weight_(Log64Weight(f1));
395  } else {
396  return to_weight_(Log64Weight(f1 - LogMinusExp(f2 - f1)));
397  }
398  }
399 
400  const WeightConvert<Weight, Log64Weight> to_log_weight_{};
401  const WeightConvert<Log64Weight, Weight> to_weight_{};
402  const ssize_t arc_limit_; // Minimum number of arcs to pre-compute state.
403  const ssize_t arc_period_; // Saves cumulative weights per arc_period_.
404  std::shared_ptr<FastLogAccumulatorData> data_;
405  const double *state_weights_;
406  bool error_;
407 
408  FastLogAccumulator &operator=(const FastLogAccumulator &) = delete;
409 };
410 
// Stores shareable data for cache log accumulator copies. All copies share the
// same cache.
template <class Arc>
class CacheLogAccumulatorData {
 public:
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  // If `gc` is true, adding weights garbage-collects once the cached vectors
  // take at least `gc_limit` bytes; `gc` true with `gc_limit` 0 disables the
  // cache (see CacheDisabled()).
  CacheLogAccumulatorData(bool gc, size_t gc_limit)
      : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {}

  // Copies the garbage-collection settings of `data`; the cache starts empty.
  CacheLogAccumulatorData(const CacheLogAccumulatorData<Arc> &data)
      : cache_gc_(data.cache_gc_),
        cache_limit_(data.cache_limit_),
        cache_size_(0) {}

  bool CacheDisabled() const { return cache_gc_ && cache_limit_ == 0; }

  // Returns the cached weights of state `s` and marks them as recently used,
  // or returns nullptr if `s` is not cached. The pointee is owned by this
  // object and may be freed by a later AddWeights() call.
  std::vector<double> *GetWeights(StateId s) {
    if (auto it = cache_.find(s); it != cache_.end()) {
      it->second.recent = true;
      return it->second.weights.get();
    } else {
      return nullptr;
    }
  }

  // Caches `weights` for state `s`, garbage-collecting first if the cache is
  // at or over its limit. If `s` is already cached, the existing entry is kept,
  // `weights` is discarded, and nothing is added to the byte count.
  void AddWeights(StateId s, std::unique_ptr<std::vector<double>> weights) {
    if (cache_gc_ && cache_size_ >= cache_limit_) GC(false);
    const size_t bytes = weights->capacity() * sizeof(double);
    const bool inserted =
        cache_.emplace(s, CacheState(std::move(weights), true)).second;
    // Only count memory that the cache actually holds on to.
    if (cache_gc_ && inserted) cache_size_ += bytes;
  }

 private:
  // Cached information for a given state.
  struct CacheState {
    std::unique_ptr<std::vector<double>> weights;  // Accumulated weights.
    bool recent;  // Has this state been accessed since last GC?

    CacheState(std::unique_ptr<std::vector<double>> weights, bool recent)
        : weights(std::move(weights)), recent(recent) {}
  };

  // Garbage collect: Deletes from cache states that have not been accessed
  // since the last GC ('free_recent = false') until 'cache_size_' is 2/3 of
  // 'cache_limit_'. If it does not free enough memory, start deleting
  // recently accessed states.
  void GC(bool free_recent) {
    const auto cache_target = (2 * cache_limit_) / 3 + 1;
    for (auto it = cache_.begin();
         it != cache_.end() && cache_size_ > cache_target;) {
      auto &cs = it->second;
      if (free_recent || !cs.recent) {
        cache_size_ -= cs.weights->capacity() * sizeof(double);
        it = cache_.erase(it);
      } else {
        cs.recent = false;
        ++it;
      }
    }
    if (!free_recent && cache_size_ > cache_target) GC(true);
  }

  std::unordered_map<StateId, CacheState> cache_;  // Cache.
  bool cache_gc_;                                  // Enables garbage collection.
  size_t cache_limit_;                             // # of bytes allowed before GC.
  size_t cache_size_;                              // # of bytes cached.

  CacheLogAccumulatorData &operator=(const CacheLogAccumulatorData &) = delete;
};
481 
482 // This class accumulates arc weights using the log semiring Plus() has a
483 // WeightConvert specialization to and from log64 weights. It is similar to the
484 // FastLogAccumator. However here, the accumulated weights are pre-computed and
485 // stored only for the states that are visited. The member function Init(fst)
486 // has to be called to setup this accumulator. Space complexity is O(gc_limit).
487 template <class Arc>
489  public:
490  using StateId = typename Arc::StateId;
491  using Weight = typename Arc::Weight;
492 
493  explicit CacheLogAccumulator(ssize_t arc_limit = 10, bool gc = false,
494  size_t gc_limit = 10 * 1024 * 1024)
495  : arc_limit_(arc_limit),
496  data_(std::make_shared<CacheLogAccumulatorData<Arc>>(gc, gc_limit)),
497  s_(kNoStateId),
498  error_(false) {}
499 
500  CacheLogAccumulator(const CacheLogAccumulator &acc, bool safe = false)
501  : arc_limit_(acc.arc_limit_),
502  fst_(acc.fst_ ? acc.fst_->Copy() : nullptr),
503  data_(safe ? std::make_shared<CacheLogAccumulatorData<Arc>>(*acc.data_)
504  : acc.data_),
505  s_(kNoStateId),
506  error_(acc.error_) {}
507 
508  // Argument arc_limit specifies the minimum number of arcs to pre-compute.
509  void Init(const Fst<Arc> &fst, bool copy = false) {
510  if (!copy && fst_) {
511  FSTERROR() << "CacheLogAccumulator: Initialization error";
512  error_ = true;
513  return;
514  }
515  fst_.reset(fst.Copy());
516  }
517 
518  void SetState(StateId s, int depth = 0) {
519  if (s == s_) return;
520  s_ = s;
521  if (data_->CacheDisabled() || error_) {
522  weights_ = nullptr;
523  return;
524  }
525  if (!fst_) {
526  FSTERROR() << "CacheLogAccumulator::SetState: Incorrectly initialized";
527  error_ = true;
528  weights_ = nullptr;
529  return;
530  }
531  weights_ = data_->GetWeights(s);
532  if ((weights_ == nullptr) && (fst_->NumArcs(s) >= arc_limit_)) {
533  auto weights = std::make_unique<std::vector<double>>();
534  weights->reserve(fst_->NumArcs(s) + 1);
535  weights->push_back(FloatLimits<double>::PosInfinity());
536  // `weights` holds a reference to the weight vector, whose ownership is
537  // transferred to `data_`.
538  weights_ = weights.get();
539  data_->AddWeights(s, std::move(weights));
540  }
541  }
542 
543  Weight Sum(Weight w, Weight v) { return LogPlus(w, v); }
544 
545  template <class ArcIter>
546  Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
547  if (weights_ == nullptr) {
548  auto sum = w;
549  aiter->Seek(begin);
550  for (auto pos = begin; pos < end; aiter->Next(), ++pos) {
551  sum = LogPlus(sum, aiter->Value().weight);
552  }
553  return sum;
554  } else {
555  Extend(end, aiter);
556  const auto &f1 = (*weights_)[end];
557  const auto &f2 = (*weights_)[begin];
558  if (f1 < f2) {
559  return LogPlus(w, LogMinus(f1, f2));
560  } else {
561  // Commented out for efficiency; adds Zero().
562  /*
563  auto sum = w;
564  // Explicitly computes if cumulative sum lacks precision.
565  aiter->Seek(begin);
566  for (auto pos = begin; pos < end; aiter->Next(), ++pos) {
567  sum = LogPlus(sum, aiter->Value().weight);
568  }
569  return sum;
570  */
571  return w;
572  }
573  }
574  }
575 
576  // Returns first position from aiter->Position() whose accumulated
577  // value is greater or equal to w (w.r.t. Zero() < One()). The
578  // iterator may be repositioned.
579  template <class ArcIter>
580  size_t LowerBound(Weight w, ArcIter *aiter) {
581  const auto f = to_log_weight_(w).Value();
582  auto pos = aiter->Position();
583  if (weights_) {
584  Extend(fst_->NumArcs(s_), aiter);
585  return std::lower_bound(weights_->begin() + pos + 1, weights_->end(), f,
586  std::greater<double>()) -
587  weights_->begin() - 1;
588  } else {
589  size_t n = 0;
591  for (aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) {
592  x = LogPlus(x, aiter->Value().weight);
593  if (n >= pos && x <= f) break;
594  }
595  return n;
596  }
597  }
598 
599  bool Error() const { return error_; }
600 
601  private:
602  double LogPosExp(double x) {
603  return x == FloatLimits<double>::PosInfinity() ? 0.0
604  : log(1.0F + exp(-x));
605  }
606 
607  double LogMinusExp(double x) {
608  return x == FloatLimits<double>::PosInfinity() ? 0.0
609  : log(1.0F - exp(-x));
610  }
611 
612  Weight LogPlus(Weight w, Weight v) {
613  if (w == Weight::Zero()) {
614  return v;
615  }
616  const auto f1 = to_log_weight_(w).Value();
617  const auto f2 = to_log_weight_(v).Value();
618  if (f1 > f2) {
619  return to_weight_(Log64Weight(f2 - LogPosExp(f1 - f2)));
620  } else {
621  return to_weight_(Log64Weight(f1 - LogPosExp(f2 - f1)));
622  }
623  }
624 
625  double LogPlus(double f1, Weight v) {
626  const auto f2 = to_log_weight_(v).Value();
627  if (f1 == FloatLimits<double>::PosInfinity()) {
628  return f2;
629  } else if (f1 > f2) {
630  return f2 - LogPosExp(f1 - f2);
631  } else {
632  return f1 - LogPosExp(f2 - f1);
633  }
634  }
635 
636  // Assumes f1 < f2.
637  Weight LogMinus(double f1, double f2) {
638  if (f2 == FloatLimits<double>::PosInfinity()) {
639  return to_weight_(Log64Weight(f1));
640  } else {
641  return to_weight_(Log64Weight(f1 - LogMinusExp(f2 - f1)));
642  }
643  }
644 
645  // Extends weights up to index 'end'.
646  template <class ArcIter>
647  void Extend(ssize_t end, ArcIter *aiter) {
648  if (weights_->size() <= end) {
649  for (aiter->Seek(weights_->size() - 1); weights_->size() <= end;
650  aiter->Next()) {
651  weights_->push_back(LogPlus(weights_->back(), aiter->Value().weight));
652  }
653  }
654  }
655 
656  const WeightConvert<Weight, Log64Weight> to_log_weight_{};
657  const WeightConvert<Log64Weight, Weight> to_weight_{};
658  ssize_t arc_limit_; // Minimum # of arcs to cache a state.
659  std::vector<double> *weights_; // Accumulated weights for cur. state.
660  // Pointee owned by `data_`.
661  std::unique_ptr<const Fst<Arc>> fst_; // Input FST.
662  std::shared_ptr<CacheLogAccumulatorData<Arc>> data_; // Cache data.
663  StateId s_; // Current state.
664  bool error_;
665 };
666 
667 // Stores shareable data for replace accumulator copies.
668 template <class Accumulator, class T>
670  public:
671  using Arc = typename Accumulator::Arc;
672  using Label = typename Arc::Label;
673  using StateId = typename Arc::StateId;
674  using StateTable = T;
675  using StateTuple = typename StateTable::StateTuple;
676 
677  ReplaceAccumulatorData() : state_table_(nullptr) {}
678 
680  std::vector<std::unique_ptr<Accumulator>> &&accumulators)
681  : state_table_(nullptr), accumulators_(std::move(accumulators)) {}
682 
683  void Init(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_tuples,
684  const StateTable *state_table) {
685  state_table_ = state_table;
686  accumulators_.resize(fst_tuples.size());
687  for (Label i = 0; i < accumulators_.size(); ++i) {
688  if (!accumulators_[i]) {
689  accumulators_[i] = std::make_unique<Accumulator>();
690  accumulators_[i]->Init(*(fst_tuples[i].second));
691  }
692  fst_array_.emplace_back(fst_tuples[i].second->Copy());
693  }
694  }
695 
696  const StateTuple &GetTuple(StateId s) const { return state_table_->Tuple(s); }
697 
698  Accumulator *GetAccumulator(size_t i) { return accumulators_[i].get(); }
699 
700  const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i].get(); }
701 
702  private:
703  const StateTable *state_table_;
704  std::vector<std::unique_ptr<Accumulator>> accumulators_;
705  std::vector<std::unique_ptr<const Fst<Arc>>> fst_array_;
706 };
707 
708 // This class accumulates weights in a ReplaceFst. The 'Init' method takes as
709 // input the argument used to build the ReplaceFst and the ReplaceFst state
710 // table. It uses accumulators of type 'Accumulator' in the underlying FSTs.
711 template <class Accumulator,
714  public:
715  using Arc = typename Accumulator::Arc;
716  using Label = typename Arc::Label;
717  using StateId = typename Arc::StateId;
718  using StateTable = T;
719  using StateTuple = typename StateTable::StateTuple;
720  using Weight = typename Arc::Weight;
721 
723  : init_(false),
724  data_(std::make_shared<
725  ReplaceAccumulatorData<Accumulator, StateTable>>()),
726  error_(false) {}
727 
729  std::vector<std::unique_ptr<Accumulator>> &&accumulators)
730  : init_(false),
731  data_(std::make_shared<ReplaceAccumulatorData<Accumulator, StateTable>>(
732  std::move(accumulators))),
733  error_(false) {}
734 
736  bool safe = false)
737  : init_(acc.init_), data_(acc.data_), error_(acc.error_) {
738  if (!init_) {
739  FSTERROR() << "ReplaceAccumulator: Can't copy unintialized accumulator";
740  }
741  if (safe) FSTERROR() << "ReplaceAccumulator: Safe copy not supported";
742  }
743 
744  // Does not take ownership of the state table, the state table is owned by
745  // the ReplaceFst.
746  void Init(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_tuples,
747  const StateTable *state_table) {
748  init_ = true;
749  data_->Init(fst_tuples, state_table);
750  }
751 
752  // Method required by LookAheadMatcher. However, ReplaceAccumulator needs to
753  // be initialized by calling the Init method above before being passed to
754  // LookAheadMatcher.
755  //
756  // TODO(allauzen): Revisit this. Consider creating a method
757  // Init(const ReplaceFst<A, T, C>&, bool) and using friendship to get access
758  // to the innards of ReplaceFst.
759  void Init(const Fst<Arc> &fst, bool copy = false) {
760  if (!init_) {
761  FSTERROR() << "ReplaceAccumulator::Init: Accumulator needs to be"
762  << " initialized before being passed to LookAheadMatcher";
763  error_ = true;
764  }
765  }
766 
767  void SetState(StateId s) {
768  if (!init_) {
769  FSTERROR() << "ReplaceAccumulator::SetState: Incorrectly initialized";
770  error_ = true;
771  return;
772  }
773  auto tuple = data_->GetTuple(s);
774  fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based.
775  data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state);
776  if ((tuple.prefix_id != 0) &&
777  (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) {
778  offset_ = 1;
779  offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state);
780  } else {
781  offset_ = 0;
782  offset_weight_ = Weight::Zero();
783  }
784  aiter_ = std::make_unique<ArcIterator<Fst<Arc>>>(*data_->GetFst(fst_id_),
785  tuple.fst_state);
786  }
787 
789  if (error_) return Weight::NoWeight();
790  return data_->GetAccumulator(fst_id_)->Sum(w, v);
791  }
792 
793  template <class ArcIter>
794  Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
795  if (error_) return Weight::NoWeight();
796  auto sum = begin == end ? Weight::Zero()
797  : data_->GetAccumulator(fst_id_)->Sum(
798  w, aiter_.get(), begin ? begin - offset_ : 0,
799  end - offset_);
800  if (begin == 0 && end != 0 && offset_ > 0) sum = Sum(offset_weight_, sum);
801  return sum;
802  }
803 
804  bool Error() const { return error_; }
805 
806  private:
807  bool init_;
808  std::shared_ptr<ReplaceAccumulatorData<Accumulator, StateTable>> data_;
809  Label fst_id_;
810  size_t offset_;
811  Weight offset_weight_;
812  std::unique_ptr<ArcIterator<Fst<Arc>>> aiter_;
813  bool error_;
814 };
815 
816 // SafeReplaceAccumulator accumulates weights in a ReplaceFst and copies of it
817 // are always thread-safe copies.
818 template <class Accumulator, class T>
820  public:
821  using Arc = typename Accumulator::Arc;
822  using StateId = typename Arc::StateId;
823  using Label = typename Arc::Label;
824  using Weight = typename Arc::Weight;
825  using StateTable = T;
826  using StateTuple = typename StateTable::StateTuple;
827 
828  SafeReplaceAccumulator() = default;
829 
831  : SafeReplaceAccumulator(copy) {}
832 
834  const std::vector<Accumulator> &accumulators) {
835  for (const auto &accumulator : accumulators) {
836  accumulators_.emplace_back(accumulator, true);
837  }
838  }
839 
840  void Init(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_tuples,
841  const StateTable *state_table) {
842  state_table_ = state_table;
843  for (Label i = 0; i < fst_tuples.size(); ++i) {
844  if (i == accumulators_.size()) {
845  accumulators_.resize(accumulators_.size() + 1);
846  accumulators_[i].Init(*(fst_tuples[i].second));
847  }
848  fst_array_.emplace_back(fst_tuples[i].second->Copy(true));
849  }
850  init_ = true;
851  }
852 
853  void Init(const Fst<Arc> &fst, bool copy = false) {
854  if (!init_) {
855  FSTERROR() << "SafeReplaceAccumulator::Init: Accumulator needs to be"
856  << " initialized before being passed to LookAheadMatcher";
857  error_ = true;
858  }
859  }
860 
861  void SetState(StateId s) {
862  auto tuple = state_table_->Tuple(s);
863  fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based
864  GetAccumulator(fst_id_)->SetState(tuple.fst_state);
865  offset_ = 0;
866  offset_weight_ = Weight::Zero();
867  const auto final_weight = GetFst(fst_id_)->Final(tuple.fst_state);
868  if ((tuple.prefix_id != 0) && (final_weight != Weight::Zero())) {
869  offset_ = 1;
870  offset_weight_ = final_weight;
871  }
872  aiter_.Set(*GetFst(fst_id_), tuple.fst_state);
873  }
874 
876  if (error_) return Weight::NoWeight();
877  return GetAccumulator(fst_id_)->Sum(w, v);
878  }
879 
880  template <class ArcIter>
881  Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
882  if (error_) return Weight::NoWeight();
883  if (begin == end) return Weight::Zero();
884  auto sum = GetAccumulator(fst_id_)->Sum(
885  w, aiter_.get(), begin ? begin - offset_ : 0, end - offset_);
886  if (begin == 0 && end != 0 && offset_ > 0) {
887  sum = Sum(offset_weight_, sum);
888  }
889  return sum;
890  }
891 
892  bool Error() const { return error_; }
893 
894  private:
895  class ArcIteratorPtr {
896  public:
897  ArcIteratorPtr() = default;
898 
899  ArcIteratorPtr(const ArcIteratorPtr &copy) {}
900 
901  void Set(const Fst<Arc> &fst, StateId state_id) {
902  ptr_ = std::make_unique<ArcIterator<Fst<Arc>>>(fst, state_id);
903  }
904 
905  ArcIterator<Fst<Arc>> *get() { return ptr_.get(); }
906 
907  private:
908  std::unique_ptr<ArcIterator<Fst<Arc>>> ptr_;
909  };
910 
911  Accumulator *GetAccumulator(size_t i) { return &accumulators_[i]; }
912 
913  const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i].get(); }
914 
915  const StateTable *state_table_;
916  std::vector<Accumulator> accumulators_;
917  std::vector<std::shared_ptr<Fst<Arc>>> fst_array_;
918  ArcIteratorPtr aiter_;
919  bool init_ = false;
920  bool error_ = false;
921  Label fst_id_;
922  size_t offset_;
923  Weight offset_weight_;
924 };
925 
926 } // namespace fst
927 
928 #endif // FST_ACCUMULATOR_H_
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
Definition: accumulator.h:101
CacheLogAccumulator(const CacheLogAccumulator &acc, bool safe=false)
Definition: accumulator.h:500
SafeReplaceAccumulator(const std::vector< Accumulator > &accumulators)
Definition: accumulator.h:833
constexpr bool Error() const
Definition: accumulator.h:74
const Fst< Arc > * GetFst(size_t i) const
Definition: accumulator.h:700
typename Arc::Weight Weight
Definition: accumulator.h:824
constexpr uint8_t kArcNoCache
Definition: fst.h:452
void SetState(StateId s)
Definition: accumulator.h:259
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:61
CacheLogAccumulatorData(const CacheLogAccumulatorData< Arc > &data)
Definition: accumulator.h:422
Weight Sum(Weight w, Weight v)
Definition: accumulator.h:63
std::shared_ptr< FastLogAccumulatorData > GetData() const
Definition: accumulator.h:354
void SetState(StateId s)
Definition: accumulator.h:861
typename Arc::Weight Weight
Definition: accumulator.h:720
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
Definition: accumulator.h:546
constexpr uint8_t kArcFlags
Definition: fst.h:455
MutableFastLogAccumulatorData(int arc_limit, int arc_period)
Definition: accumulator.h:195
typename Arc::Weight Weight
Definition: accumulator.h:53
Weight Sum() const
Definition: weight.h:224
void Init(int num_weights, const double *weights, int num_positions, const int *weight_positions)
Definition: accumulator.h:171
void AddWeights(StateId s, std::unique_ptr< std::vector< double >> weights)
Definition: accumulator.h:438
typename Arc::Label Label
Definition: accumulator.h:672
void SetData(std::vector< double > *weights, std::vector< int > *weight_positions) override
Definition: accumulator.h:200
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
Definition: accumulator.h:66
typename Arc::Weight Weight
Definition: accumulator.h:491
FastLogAccumulatorData(int arc_limit, int arc_period)
Definition: accumulator.h:136
const int * WeightPositions() const
Definition: accumulator.h:154
Weight Sum(Weight w, Weight v)
Definition: accumulator.h:875
typename StateTable::StateTuple StateTuple
Definition: accumulator.h:719
void Init(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_tuples, const StateTable *state_table)
Definition: accumulator.h:840
ReplaceAccumulatorData(std::vector< std::unique_ptr< Accumulator >> &&accumulators)
Definition: accumulator.h:679
constexpr int kNoStateId
Definition: fst.h:196
void Init(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_tuples, const StateTable *state_table)
Definition: accumulator.h:683
const Arc & Value() const
Definition: fst.h:536
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
Definition: accumulator.h:794
static constexpr T PosInfinity()
Definition: float-weight.h:61
ReplaceAccumulator(const ReplaceAccumulator< Accumulator, StateTable > &acc, bool safe=false)
Definition: accumulator.h:735
void Init(const Fst< Arc > &fst, bool copy=false)
Definition: accumulator.h:509
#define FSTERROR()
Definition: util.h:56
void Init(const FST &fst, bool copy=false)
Definition: accumulator.h:321
CacheLogAccumulatorData(bool gc, size_t gc_limit)
Definition: accumulator.h:419
void SetFlags(uint8_t flags, uint8_t mask)
Definition: fst.h:570
bool IsMutable() const override
Definition: accumulator.h:198
Weight Add(const Weight &w)
Definition: weight.h:219
LogWeightTpl< double > Log64Weight
Definition: float-weight.h:472
Weight Sum(Weight w, Weight v)
Definition: accumulator.h:788
typename Arc::StateId StateId
Definition: accumulator.h:228
typename Arc::StateId StateId
Definition: accumulator.h:416
SafeReplaceAccumulator(const SafeReplaceAccumulator &copy, bool safe)
Definition: accumulator.h:830
const StateTuple & GetTuple(StateId s) const
Definition: accumulator.h:696
typename Arc::StateId StateId
Definition: accumulator.h:490
typename StateTable::StateTuple StateTuple
Definition: accumulator.h:675
virtual Fst * Copy(bool safe=false) const =0
void Init(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_tuples, const StateTable *state_table)
Definition: accumulator.h:746
typename Arc::Weight Weight
Definition: accumulator.h:417
FastLogAccumulator(std::shared_ptr< FastLogAccumulatorData > data)
Definition: accumulator.h:241
typename Arc::Label Label
Definition: accumulator.h:716
bool Done() const
Definition: fst.h:532
void Init(const Fst< Arc > &fst, bool copy=false)
Definition: accumulator.h:759
void SetState(StateId s)
Definition: accumulator.h:96
typename Accumulator::Arc Arc
Definition: accumulator.h:821
typename Accumulator::Arc Arc
Definition: accumulator.h:671
LogAccumulator(const LogAccumulator &acc, bool safe=false)
Definition: accumulator.h:92
double LogPosExp(double x)
Definition: float-weight.h:477
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) const
Definition: accumulator.h:272
void SetState(StateId s, int depth=0)
Definition: accumulator.h:518
constexpr uint8_t kArcWeightValue
Definition: fst.h:448
FastLogAccumulator(const FastLogAccumulator &acc, bool safe=false)
Definition: accumulator.h:250
typename StateTable::StateTuple StateTuple
Definition: accumulator.h:826
std::vector< double > * GetWeights(StateId s)
Definition: accumulator.h:429
Weight Sum(Weight w, Weight v)
Definition: accumulator.h:543
Arc::StateId CountStates(const Fst< Arc > &fst)
Definition: expanded-fst.h:179
typename Arc::StateId StateId
Definition: accumulator.h:673
void SetState(StateId state)
Definition: accumulator.h:61
const double * Weights() const
Definition: accumulator.h:148
bool Done() const
Definition: fst.h:415
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
Definition: accumulator.h:881
Accumulator * GetAccumulator(size_t i)
Definition: accumulator.h:698
typename Arc::StateId StateId
Definition: accumulator.h:52
void Init(const Fst< Arc > &fst, bool copy=false)
Definition: accumulator.h:853
typename Accumulator::Arc Arc
Definition: accumulator.h:715
ReplaceAccumulator(std::vector< std::unique_ptr< Accumulator >> &&accumulators)
Definition: accumulator.h:728
Weight Sum(Weight w, Weight v) const
Definition: accumulator.h:269
typename Arc::StateId StateId
Definition: accumulator.h:717
FastLogAccumulator(ssize_t arc_limit=20, ssize_t arc_period=10)
Definition: accumulator.h:231
size_t LowerBound(Weight w, ArcIter *aiter)
Definition: accumulator.h:580
Weight Sum(Weight w, Weight v)
Definition: accumulator.h:98
typename Arc::StateId StateId
Definition: accumulator.h:822
CacheLogAccumulator(ssize_t arc_limit=10, bool gc=false, size_t gc_limit=10 *1024 *1024)
Definition: accumulator.h:493
DefaultAccumulator(const DefaultAccumulator &acc, bool safe=false)
Definition: accumulator.h:57
void SetState(StateId s)
Definition: accumulator.h:767
typename Arc::Weight Weight
Definition: accumulator.h:88
typename Arc::Weight Weight
Definition: accumulator.h:229
void Init(const Fst< Arc > &fst, bool copy=false)
Definition: accumulator.h:59
void Init(const Fst< Arc > &fst, bool copy=false)
Definition: accumulator.h:94
typename Arc::Label Label
Definition: accumulator.h:823
constexpr bool Error() const
Definition: accumulator.h:110
void Next()
Definition: fst.h:540
typename Arc::StateId StateId
Definition: accumulator.h:87