20 #ifndef FST_ACCUMULATOR_H_ 21 #define FST_ACCUMULATOR_H_ 36 #include <unordered_map> 59 template <
class ArcIter>
63 for (
auto pos = begin; pos < end; aiter->Next(), ++pos)
64 adder.
Add(aiter->Value().weight);
68 constexpr
bool Error()
const {
return false; }
94 template <
class ArcIter>
98 for (
auto pos = begin; pos < end; aiter->Next(), ++pos) {
99 sum = LogPlus(sum, aiter->Value().weight);
104 constexpr
bool Error()
const {
return false; }
108 if (w == Weight::Zero()) {
111 const auto f1 = to_log_weight_(w).Value();
112 const auto f2 = to_log_weight_(v).Value();
131 : arc_limit_(arc_limit),
132 arc_period_(arc_period),
133 weights_ptr_(nullptr),
135 weight_positions_ptr_(nullptr),
142 const double *
Weights()
const {
return weights_ptr_; }
157 virtual bool IsMutable()
const = 0;
161 virtual void SetData(std::vector<double> *weights,
162 std::vector<int> *weight_positions) = 0;
165 void Init(
int num_weights,
const double *weights,
int num_positions,
166 const int *weight_positions) {
167 weights_ptr_ = weights;
168 num_weights_ = num_weights;
169 weight_positions_ptr_ = weight_positions;
170 num_positions_ = num_positions;
174 const int arc_limit_;
175 const int arc_period_;
176 const double *weights_ptr_;
178 const int *weight_positions_ptr_;
195 std::vector<int> *weight_positions)
override {
196 weights_.swap(*weights);
197 weight_positions_.swap(*weight_positions);
198 Init(weights_.size(), weights_.data(), weight_positions_.size(),
199 weight_positions_.data());
203 std::vector<double> weights_;
204 std::vector<int> weight_positions_;
228 arc_limit_(arc_limit),
229 arc_period_(arc_period),
232 state_weights_(nullptr),
238 arc_limit_(data->ArcLimit()),
239 arc_period_(data->ArcPeriod()),
241 state_weights_(nullptr),
247 arc_limit_(acc.arc_limit_),
248 arc_period_(acc.arc_period_),
250 state_weights_(nullptr),
251 error_(acc.error_) {}
254 const auto *weights = data_->Weights();
255 const auto *weight_positions = data_->WeightPositions();
256 state_weights_ =
nullptr;
257 if (s < data_->NumPositions()) {
258 const auto pos = weight_positions[s];
259 if (pos >= 0) state_weights_ = &(weights[pos]);
265 template <
class ArcIter>
267 if (error_)
return Weight::NoWeight();
270 ssize_t index_begin = -1;
271 ssize_t index_end = -1;
272 ssize_t stored_begin = end;
273 ssize_t stored_end = end;
274 if (state_weights_) {
275 index_begin = begin > 0 ? (begin - 1) / arc_period_ + 1 : 0;
276 index_end = end / arc_period_;
277 stored_begin = index_begin * arc_period_;
278 stored_end = index_end * arc_period_;
281 if (begin < stored_begin) {
282 const auto pos_end = std::min(stored_begin, end);
284 for (
auto pos = begin; pos < pos_end; aiter->Next(), ++pos) {
285 sum = LogPlus(sum, aiter->Value().weight);
289 if (stored_begin < stored_end) {
290 const auto f1 = state_weights_[index_end];
291 const auto f2 = state_weights_[index_begin];
292 if (f1 < f2) sum = LogPlus(sum, LogMinus(f1, f2));
304 if (stored_end < end) {
305 const auto pos_start = std::max(stored_begin, stored_end);
306 aiter->Seek(pos_start);
307 for (
auto pos = pos_start; pos < end; aiter->Next(), ++pos) {
308 sum = LogPlus(sum, aiter->Value().weight);
316 if (copy || !data_->IsMutable())
return;
317 if (data_->NumPositions() != 0 || arc_limit_ < arc_period_) {
318 FSTERROR() <<
"FastLogAccumulator: Initialization error";
322 std::vector<double> weights;
323 std::vector<int> weight_positions;
326 const auto s = siter.Value();
327 if (fst.NumArcs(s) >= arc_limit_) {
329 if (weight_positions.size() <= s) weight_positions.resize(s + 1, -1);
330 weight_positions[s] = weights.size();
331 weights.push_back(sum);
335 for (; !aiter.
Done(); aiter.
Next()) {
336 const auto &arc = aiter.
Value();
337 sum = LogPlus(sum, arc.weight);
339 if (++narcs % arc_period_ == 0) weights.push_back(sum);
343 data_->SetData(&weights, &weight_positions);
346 bool Error()
const {
return error_; }
348 std::shared_ptr<FastLogAccumulatorData>
GetData()
const {
return data_; }
353 : log(1.0F + exp(-x));
356 static double LogMinusExp(
double x) {
358 : log(1.0F - exp(-x));
362 if (w == Weight::Zero()) {
365 const auto f1 = to_log_weight_(w).Value();
366 const auto f2 = to_log_weight_(v).Value();
374 double LogPlus(
double f1,
Weight v)
const {
375 const auto f2 = to_log_weight_(v).Value();
378 }
else if (f1 > f2) {
386 Weight LogMinus(
double f1,
double f2)
const {
390 return to_weight_(
Log64Weight(f1 - LogMinusExp(f2 - f1)));
396 const ssize_t arc_limit_;
397 const ssize_t arc_period_;
398 std::shared_ptr<FastLogAccumulatorData> data_;
399 const double *state_weights_;
414 : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {}
417 : cache_gc_(data.cache_gc_),
418 cache_limit_(data.cache_limit_),
424 auto it = cache_.find(s);
425 if (it != cache_.end()) {
426 it->second.recent =
true;
427 return it->second.weights.get();
434 if (cache_gc_ && cache_size_ >= cache_limit_) GC(
false);
435 cache_.emplace(s, CacheState(weights,
true));
436 if (cache_gc_) cache_size_ += weights->capacity() *
sizeof(double);
442 std::unique_ptr<std::vector<double>> weights;
445 CacheState(std::vector<double> *weights,
bool recent)
446 : weights(weights), recent(recent) {}
453 void GC(
bool free_recent) {
454 auto cache_target = (2 * cache_limit_) / 3 + 1;
455 auto it = cache_.begin();
456 while (it != cache_.end() && cache_size_ > cache_target) {
457 auto &cs = it->second;
458 if (free_recent || !cs.recent) {
459 cache_size_ -= cs.weights->capacity() *
sizeof(double);
466 if (!free_recent && cache_size_ > cache_target) GC(
true);
469 std::unordered_map<StateId, CacheState> cache_;
489 size_t gc_limit = 10 * 1024 * 1024)
490 : arc_limit_(arc_limit),
496 : arc_limit_(acc.arc_limit_),
497 fst_(acc.fst_ ? acc.fst_->Copy() : nullptr),
501 error_(acc.error_) {}
506 FSTERROR() <<
"CacheLogAccumulator: Initialization error";
510 fst_.reset(fst.
Copy());
516 if (data_->CacheDisabled() || error_) {
521 FSTERROR() <<
"CacheLogAccumulator::SetState: Incorrectly initialized";
526 weights_ = data_->GetWeights(s);
527 if ((weights_ ==
nullptr) && (fst_->NumArcs(s) >= arc_limit_)) {
528 weights_ =
new std::vector<double>;
529 weights_->reserve(fst_->NumArcs(s) + 1);
531 data_->AddWeights(s, weights_);
537 template <
class ArcIter>
539 if (weights_ ==
nullptr) {
542 for (
auto pos = begin; pos < end; aiter->Next(), ++pos) {
543 sum = LogPlus(sum, aiter->Value().weight);
548 const auto &f1 = (*weights_)[end];
549 const auto &f2 = (*weights_)[begin];
551 return LogPlus(w, LogMinus(f1, f2));
571 template <
class ArcIter>
573 const auto f = to_log_weight_(w).Value();
574 auto pos = aiter->Position();
576 Extend(fst_->NumArcs(s_), aiter);
577 return std::lower_bound(weights_->begin() + pos + 1, weights_->end(), f,
578 std::greater<double>()) -
579 weights_->begin() - 1;
583 for (aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) {
584 x = LogPlus(x, aiter->Value().weight);
585 if (n >= pos && x <= f)
break;
591 bool Error()
const {
return error_; }
596 : log(1.0F + exp(-x));
599 double LogMinusExp(
double x) {
601 : log(1.0F - exp(-x));
605 if (w == Weight::Zero()) {
608 const auto f1 = to_log_weight_(w).Value();
609 const auto f2 = to_log_weight_(v).Value();
617 double LogPlus(
double f1,
Weight v) {
618 const auto f2 = to_log_weight_(v).Value();
621 }
else if (f1 > f2) {
629 Weight LogMinus(
double f1,
double f2) {
633 return to_weight_(
Log64Weight(f1 - LogMinusExp(f2 - f1)));
638 template <
class ArcIter>
639 void Extend(ssize_t end, ArcIter *aiter) {
640 if (weights_->size() <= end) {
641 for (aiter->Seek(weights_->size() - 1); weights_->size() <= end;
643 weights_->push_back(LogPlus(weights_->back(), aiter->Value().weight));
651 std::vector<double> *weights_;
652 std::unique_ptr<const Fst<Arc>> fst_;
653 std::shared_ptr<CacheLogAccumulatorData<Arc>> data_;
659 template <
class Accumulator,
class T>
662 using Arc =
typename Accumulator::Arc;
671 std::vector<std::unique_ptr<Accumulator>> &&accumulators)
672 : state_table_(nullptr), accumulators_(std::move(accumulators)) {}
676 state_table_ = state_table;
677 accumulators_.resize(fst_tuples.size());
678 for (
Label i = 0; i < accumulators_.size(); ++i) {
679 if (!accumulators_[i]) {
680 accumulators_[i] = std::make_unique<Accumulator>();
681 accumulators_[i]->Init(*(fst_tuples[i].second));
683 fst_array_.emplace_back(fst_tuples[i].second->Copy());
695 std::vector<std::unique_ptr<Accumulator>> accumulators_;
696 std::vector<std::unique_ptr<const Fst<Arc>>> fst_array_;
702 template <
class Accumulator,
706 using Arc =
typename Accumulator::Arc;
715 data_(std::make_shared<
720 std::vector<std::unique_ptr<Accumulator>> &&accumulators)
723 std::move(accumulators))),
728 : init_(acc.init_), data_(acc.data_), error_(acc.error_) {
730 FSTERROR() <<
"ReplaceAccumulator: Can't copy unintialized accumulator";
732 if (safe)
FSTERROR() <<
"ReplaceAccumulator: Safe copy not supported";
740 data_->Init(fst_tuples, state_table);
752 FSTERROR() <<
"ReplaceAccumulator::Init: Accumulator needs to be" 753 <<
" initialized before being passed to LookAheadMatcher";
760 FSTERROR() <<
"ReplaceAccumulator::SetState: Incorrectly initialized";
764 auto tuple = data_->GetTuple(s);
765 fst_id_ = tuple.fst_id - 1;
766 data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state);
767 if ((tuple.prefix_id != 0) &&
768 (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) {
770 offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state);
773 offset_weight_ = Weight::Zero();
775 aiter_ = std::make_unique<ArcIterator<Fst<Arc>>>(*data_->GetFst(fst_id_),
780 if (error_)
return Weight::NoWeight();
781 return data_->GetAccumulator(fst_id_)->Sum(w, v);
784 template <
class ArcIter>
786 if (error_)
return Weight::NoWeight();
787 auto sum = begin == end ? Weight::Zero()
788 : data_->GetAccumulator(fst_id_)->Sum(
789 w, aiter_.get(), begin ? begin - offset_ : 0,
791 if (begin == 0 && end != 0 && offset_ > 0) sum =
Sum(offset_weight_, sum);
795 bool Error()
const {
return error_; }
799 std::shared_ptr<ReplaceAccumulatorData<Accumulator, StateTable>> data_;
803 std::unique_ptr<ArcIterator<Fst<Arc>>> aiter_;
809 template <
class Accumulator,
class T>
812 using Arc =
typename Accumulator::Arc;
825 const std::vector<Accumulator> &accumulators) {
826 for (
const auto &accumulator : accumulators) {
827 accumulators_.emplace_back(accumulator,
true);
833 state_table_ = state_table;
834 for (
Label i = 0; i < fst_tuples.size(); ++i) {
835 if (i == accumulators_.size()) {
836 accumulators_.resize(accumulators_.size() + 1);
837 accumulators_[i].Init(*(fst_tuples[i].second));
839 fst_array_.emplace_back(fst_tuples[i].second->Copy(
true));
846 FSTERROR() <<
"SafeReplaceAccumulator::Init: Accumulator needs to be" 847 <<
" initialized before being passed to LookAheadMatcher";
853 auto tuple = state_table_->Tuple(s);
854 fst_id_ = tuple.fst_id - 1;
855 GetAccumulator(fst_id_)->SetState(tuple.fst_state);
857 offset_weight_ = Weight::Zero();
858 const auto final_weight = GetFst(fst_id_)->Final(tuple.fst_state);
859 if ((tuple.prefix_id != 0) && (final_weight != Weight::Zero())) {
861 offset_weight_ = final_weight;
863 aiter_.Set(*GetFst(fst_id_), tuple.fst_state);
867 if (error_)
return Weight::NoWeight();
868 return GetAccumulator(fst_id_)->Sum(w, v);
871 template <
class ArcIter>
873 if (error_)
return Weight::NoWeight();
874 if (begin == end)
return Weight::Zero();
875 auto sum = GetAccumulator(fst_id_)->Sum(
876 w, aiter_.get(), begin ? begin - offset_ : 0, end - offset_);
877 if (begin == 0 && end != 0 && offset_ > 0) {
878 sum =
Sum(offset_weight_, sum);
883 bool Error()
const {
return error_; }
886 class ArcIteratorPtr {
890 ArcIteratorPtr(
const ArcIteratorPtr ©) {}
893 ptr_ = std::make_unique<ArcIterator<Fst<Arc>>>(fst, state_id);
899 std::unique_ptr<ArcIterator<Fst<Arc>>> ptr_;
902 Accumulator *GetAccumulator(
size_t i) {
return &accumulators_[i]; }
904 const Fst<Arc> *GetFst(
size_t i)
const {
return fst_array_[i].get(); }
907 std::vector<Accumulator> accumulators_;
908 std::vector<std::shared_ptr<Fst<Arc>>> fst_array_;
909 ArcIteratorPtr aiter_;
919 #endif // FST_ACCUMULATOR_H_
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
CacheLogAccumulator(const CacheLogAccumulator &acc, bool safe=false)
SafeReplaceAccumulator(const std::vector< Accumulator > &accumulators)
constexpr bool Error() const
const Fst< Arc > * GetFst(size_t i) const
typename Arc::Weight Weight
constexpr uint8_t kArcNoCache
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
CacheLogAccumulatorData(const CacheLogAccumulatorData< Arc > &data)
Weight Sum(Weight w, Weight v)
std::shared_ptr< FastLogAccumulatorData > GetData() const
typename Arc::Weight Weight
bool CacheDisabled() const
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
constexpr uint8_t kArcFlags
MutableFastLogAccumulatorData(int arc_limit, int arc_period)
typename Arc::Weight Weight
virtual ~FastLogAccumulatorData()
void Init(int num_weights, const double *weights, int num_positions, const int *weight_positions)
typename Arc::Label Label
void SetData(std::vector< double > *weights, std::vector< int > *weight_positions) override
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
typename Arc::Weight Weight
FastLogAccumulatorData(int arc_limit, int arc_period)
const int * WeightPositions() const
Weight Sum(Weight w, Weight v)
typename StateTable::StateTuple StateTuple
void Init(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_tuples, const StateTable *state_table)
ReplaceAccumulatorData(std::vector< std::unique_ptr< Accumulator >> &&accumulators)
void Init(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_tuples, const StateTable *state_table)
const Arc & Value() const
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
static constexpr T PosInfinity()
ReplaceAccumulator(const ReplaceAccumulator< Accumulator, StateTable > &acc, bool safe=false)
void Init(const Fst< Arc > &fst, bool copy=false)
void Init(const FST &fst, bool copy=false)
CacheLogAccumulatorData(bool gc, size_t gc_limit)
void SetFlags(uint8_t flags, uint8_t mask)
bool IsMutable() const override
Weight Add(const Weight &w)
LogWeightTpl< double > Log64Weight
Weight Sum(Weight w, Weight v)
typename Arc::StateId StateId
typename Arc::StateId StateId
SafeReplaceAccumulator(const SafeReplaceAccumulator &copy, bool safe)
const StateTuple & GetTuple(StateId s) const
typename Arc::StateId StateId
typename StateTable::StateTuple StateTuple
virtual Fst * Copy(bool safe=false) const =0
void Init(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_tuples, const StateTable *state_table)
typename Arc::Weight Weight
FastLogAccumulator(std::shared_ptr< FastLogAccumulatorData > data)
typename Arc::Label Label
void AddWeights(StateId s, std::vector< double > *weights)
void Init(const Fst< Arc > &fst, bool copy=false)
typename Accumulator::Arc Arc
typename Accumulator::Arc Arc
LogAccumulator(const LogAccumulator &acc, bool safe=false)
double LogPosExp(double x)
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) const
void SetState(StateId s, int depth=0)
constexpr uint8_t kArcWeightValue
FastLogAccumulator(const FastLogAccumulator &acc, bool safe=false)
typename StateTable::StateTuple StateTuple
std::vector< double > * GetWeights(StateId s)
Weight Sum(Weight w, Weight v)
Arc::StateId CountStates(const Fst< Arc > &fst)
typename Arc::StateId StateId
void SetState(StateId state)
const double * Weights() const
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
Accumulator * GetAccumulator(size_t i)
typename Arc::StateId StateId
void Init(const Fst< Arc > &fst, bool copy=false)
typename Accumulator::Arc Arc
ReplaceAccumulator(std::vector< std::unique_ptr< Accumulator >> &&accumulators)
Weight Sum(Weight w, Weight v) const
typename Arc::StateId StateId
FastLogAccumulator(ssize_t arc_limit=20, ssize_t arc_period=10)
size_t LowerBound(Weight w, ArcIter *aiter)
Weight Sum(Weight w, Weight v)
typename Arc::StateId StateId
CacheLogAccumulator(ssize_t arc_limit=10, bool gc=false, size_t gc_limit=10 *1024 *1024)
DefaultAccumulator(const DefaultAccumulator &acc, bool safe=false)
typename Arc::Weight Weight
typename Arc::Weight Weight
void Init(const Fst< Arc > &fst, bool copy=false)
void Init(const Fst< Arc > &fst, bool copy=false)
typename Arc::Label Label
constexpr bool Error() const
typename Arc::StateId StateId