20 #ifndef FST_ACCUMULATOR_H_ 21 #define FST_ACCUMULATOR_H_ 23 #include <sys/types.h> 42 #include <unordered_map> 65 template <
class ArcIter>
69 for (
auto pos = begin; pos < end; aiter->Next(), ++pos)
70 adder.
Add(aiter->Value().weight);
// Reports whether the accumulator is in an error state. This accumulator
// has no failure mode, so the answer is always false; being constexpr, it
// can be evaluated at compile time.
74 constexpr
bool Error()
const {
return false; }
100 template <
class ArcIter>
104 for (
auto pos = begin; pos < end; aiter->Next(), ++pos) {
105 sum = LogPlus(sum, aiter->Value().weight);
// Reports whether the accumulator is in an error state. This accumulator
// never enters one, so the answer is always false (usable in constant
// expressions).
110 constexpr
bool Error()
const {
return false; }
114 if (w == Weight::Zero()) {
117 const auto f1 = to_log_weight_(w).Value();
118 const auto f2 = to_log_weight_(v).Value();
137 : arc_limit_(arc_limit),
138 arc_period_(arc_period),
139 weights_ptr_(nullptr),
141 weight_positions_ptr_(nullptr),
// Returns the flat array of stored weights registered via Init(), or
// nullptr if Init() has not been called yet (the constructor initializes
// weights_ptr_ to nullptr). The pointer is non-owning: the storage belongs
// to whoever called Init() (e.g. a derived class passing its own vector's
// data()).
148 const double *
Weights()
const {
return weights_ptr_; }
// Reports whether this data object accepts new contents via SetData().
// The accumulator's Init() returns early, without computing any weights,
// when this is false.
163 virtual bool IsMutable()
const = 0;
// Installs precomputed weights and the per-state positions into that
// weight array. The visible override swaps the vectors' contents into its
// own storage, so after the call the caller's vectors hold the object's
// previous contents rather than the data passed in.
167 virtual void SetData(std::vector<double> *weights,
168 std::vector<int> *weight_positions) = 0;
171 void Init(
int num_weights,
const double *weights,
int num_positions,
172 const int *weight_positions) {
173 weights_ptr_ = weights;
174 num_weights_ = num_weights;
175 weight_positions_ptr_ = weight_positions;
176 num_positions_ = num_positions;
180 const int arc_limit_;
181 const int arc_period_;
182 const double *weights_ptr_;
184 const int *weight_positions_ptr_;
201 std::vector<int> *weight_positions)
override {
202 weights_.swap(*weights);
203 weight_positions_.swap(*weight_positions);
204 Init(weights_.size(), weights_.data(), weight_positions_.size(),
205 weight_positions_.data());
209 std::vector<double> weights_;
210 std::vector<int> weight_positions_;
234 arc_limit_(arc_limit),
235 arc_period_(arc_period),
238 state_weights_(nullptr),
244 arc_limit_(data->ArcLimit()),
245 arc_period_(data->ArcPeriod()),
247 state_weights_(nullptr),
253 arc_limit_(acc.arc_limit_),
254 arc_period_(acc.arc_period_),
256 state_weights_(nullptr),
257 error_(acc.error_) {}
260 const auto *weights = data_->Weights();
261 const auto *weight_positions = data_->WeightPositions();
262 state_weights_ =
nullptr;
263 if (s < data_->NumPositions()) {
264 const auto pos = weight_positions[s];
265 if (pos >= 0) state_weights_ = &(weights[pos]);
271 template <
class ArcIter>
273 if (error_)
return Weight::NoWeight();
276 ssize_t index_begin = -1;
277 ssize_t index_end = -1;
278 ssize_t stored_begin = end;
279 ssize_t stored_end = end;
280 if (state_weights_) {
281 index_begin = begin > 0 ? (begin - 1) / arc_period_ + 1 : 0;
282 index_end = end / arc_period_;
283 stored_begin = index_begin * arc_period_;
284 stored_end = index_end * arc_period_;
287 if (begin < stored_begin) {
288 const auto pos_end = std::min(stored_begin, end);
290 for (
auto pos = begin; pos < pos_end; aiter->Next(), ++pos) {
291 sum = LogPlus(sum, aiter->Value().weight);
295 if (stored_begin < stored_end) {
296 const auto f1 = state_weights_[index_end];
297 const auto f2 = state_weights_[index_begin];
298 if (f1 < f2) sum = LogPlus(sum, LogMinus(f1, f2));
310 if (stored_end < end) {
311 const auto pos_start = std::max(stored_begin, stored_end);
312 aiter->Seek(pos_start);
313 for (
auto pos = pos_start; pos < end; aiter->Next(), ++pos) {
314 sum = LogPlus(sum, aiter->Value().weight);
322 if (copy || !data_->IsMutable())
return;
323 if (data_->NumPositions() != 0 || arc_limit_ < arc_period_) {
324 FSTERROR() <<
"FastLogAccumulator: Initialization error";
328 std::vector<double> weights;
329 std::vector<int> weight_positions;
332 const auto s = siter.Value();
333 if (fst.NumArcs(s) >= arc_limit_) {
335 if (weight_positions.size() <= s) weight_positions.resize(s + 1, -1);
336 weight_positions[s] = weights.size();
337 weights.push_back(sum);
341 for (; !aiter.
Done(); aiter.
Next()) {
342 const auto &arc = aiter.
Value();
343 sum = LogPlus(sum, arc.weight);
345 if (++narcs % arc_period_ == 0) weights.push_back(sum);
349 data_->SetData(&weights, &weight_positions);
// Returns the accumulator's error flag. While it is set, Sum() returns
// Weight::NoWeight(). The copy constructor copies the flag from the source.
352 bool Error()
const {
return error_; }
// Returns a shared-ownership handle to this accumulator's precomputed
// weight data.
354 std::shared_ptr<FastLogAccumulatorData>
GetData()
const {
return data_; }
359 : log(1.0F + exp(-x));
362 static double LogMinusExp(
double x) {
364 : log(1.0F - exp(-x));
368 if (w == Weight::Zero()) {
371 const auto f1 = to_log_weight_(w).Value();
372 const auto f2 = to_log_weight_(v).Value();
380 double LogPlus(
double f1,
Weight v)
const {
381 const auto f2 = to_log_weight_(v).Value();
384 }
else if (f1 > f2) {
392 Weight LogMinus(
double f1,
double f2)
const {
396 return to_weight_(
Log64Weight(f1 - LogMinusExp(f2 - f1)));
402 const ssize_t arc_limit_;
403 const ssize_t arc_period_;
404 std::shared_ptr<FastLogAccumulatorData> data_;
405 const double *state_weights_;
420 : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {}
423 : cache_gc_(data.cache_gc_),
424 cache_limit_(data.cache_limit_),
430 if (
auto it = cache_.find(s); it != cache_.end()) {
431 it->second.recent =
true;
432 return it->second.weights.get();
439 if (cache_gc_ && cache_size_ >= cache_limit_) GC(
false);
440 if (cache_gc_) cache_size_ += weights->capacity() *
sizeof(double);
441 cache_.emplace(s, CacheState(std::move(weights),
true));
447 std::unique_ptr<std::vector<double>> weights;
// Cache entry that takes ownership of one state's weight vector. `recent`
// marks the entry as recently used: lookups set it to true, and GC()
// frees only non-recent entries unless it is asked to free recent ones too.
450 CacheState(std::unique_ptr<std::vector<double>> weights,
bool recent)
451 : weights(std::move(weights)), recent(recent) {}
458 void GC(
bool free_recent) {
459 auto cache_target = (2 * cache_limit_) / 3 + 1;
460 auto it = cache_.begin();
461 while (it != cache_.end() && cache_size_ > cache_target) {
462 auto &cs = it->second;
463 if (free_recent || !cs.recent) {
464 cache_size_ -= cs.weights->capacity() *
sizeof(double);
471 if (!free_recent && cache_size_ > cache_target) GC(
true);
474 std::unordered_map<StateId, CacheState> cache_;
494 size_t gc_limit = 10 * 1024 * 1024)
495 : arc_limit_(arc_limit),
501 : arc_limit_(acc.arc_limit_),
502 fst_(acc.fst_ ? acc.fst_->Copy() : nullptr),
506 error_(acc.error_) {}
511 FSTERROR() <<
"CacheLogAccumulator: Initialization error";
515 fst_.reset(fst.
Copy());
521 if (data_->CacheDisabled() || error_) {
526 FSTERROR() <<
"CacheLogAccumulator::SetState: Incorrectly initialized";
531 weights_ = data_->GetWeights(s);
532 if ((weights_ ==
nullptr) && (fst_->NumArcs(s) >= arc_limit_)) {
533 auto weights = std::make_unique<std::vector<double>>();
534 weights->reserve(fst_->NumArcs(s) + 1);
538 weights_ = weights.get();
539 data_->AddWeights(s, std::move(weights));
545 template <
class ArcIter>
547 if (weights_ ==
nullptr) {
550 for (
auto pos = begin; pos < end; aiter->Next(), ++pos) {
551 sum = LogPlus(sum, aiter->Value().weight);
556 const auto &f1 = (*weights_)[end];
557 const auto &f2 = (*weights_)[begin];
559 return LogPlus(w, LogMinus(f1, f2));
579 template <
class ArcIter>
581 const auto f = to_log_weight_(w).Value();
582 auto pos = aiter->Position();
584 Extend(fst_->NumArcs(s_), aiter);
585 return std::lower_bound(weights_->begin() + pos + 1, weights_->end(), f,
586 std::greater<double>()) -
587 weights_->begin() - 1;
591 for (aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) {
592 x = LogPlus(x, aiter->Value().weight);
593 if (n >= pos && x <= f)
break;
// Returns the accumulator's error flag. The copy constructor copies it from
// the source, and SetState() also checks it.
599 bool Error()
const {
return error_; }
604 : log(1.0F + exp(-x));
607 double LogMinusExp(
double x) {
609 : log(1.0F - exp(-x));
613 if (w == Weight::Zero()) {
616 const auto f1 = to_log_weight_(w).Value();
617 const auto f2 = to_log_weight_(v).Value();
625 double LogPlus(
double f1,
Weight v) {
626 const auto f2 = to_log_weight_(v).Value();
629 }
else if (f1 > f2) {
637 Weight LogMinus(
double f1,
double f2) {
641 return to_weight_(
Log64Weight(f1 - LogMinusExp(f2 - f1)));
646 template <
class ArcIter>
647 void Extend(ssize_t end, ArcIter *aiter) {
648 if (weights_->size() <= end) {
649 for (aiter->Seek(weights_->size() - 1); weights_->size() <= end;
651 weights_->push_back(LogPlus(weights_->back(), aiter->Value().weight));
659 std::vector<double> *weights_;
661 std::unique_ptr<const Fst<Arc>> fst_;
662 std::shared_ptr<CacheLogAccumulatorData<Arc>> data_;
668 template <
class Accumulator,
class T>
671 using Arc =
typename Accumulator::Arc;
680 std::vector<std::unique_ptr<Accumulator>> &&accumulators)
681 : state_table_(nullptr), accumulators_(std::move(accumulators)) {}
685 state_table_ = state_table;
686 accumulators_.resize(fst_tuples.size());
687 for (
Label i = 0; i < accumulators_.size(); ++i) {
688 if (!accumulators_[i]) {
689 accumulators_[i] = std::make_unique<Accumulator>();
690 accumulators_[i]->Init(*(fst_tuples[i].second));
692 fst_array_.emplace_back(fst_tuples[i].second->Copy());
704 std::vector<std::unique_ptr<Accumulator>> accumulators_;
705 std::vector<std::unique_ptr<const Fst<Arc>>> fst_array_;
711 template <
class Accumulator,
715 using Arc =
typename Accumulator::Arc;
724 data_(std::make_shared<
729 std::vector<std::unique_ptr<Accumulator>> &&accumulators)
732 std::move(accumulators))),
737 : init_(acc.init_), data_(acc.data_), error_(acc.error_) {
739 FSTERROR() <<
"ReplaceAccumulator: Can't copy unintialized accumulator";
741 if (safe)
FSTERROR() <<
"ReplaceAccumulator: Safe copy not supported";
749 data_->Init(fst_tuples, state_table);
761 FSTERROR() <<
"ReplaceAccumulator::Init: Accumulator needs to be" 762 <<
" initialized before being passed to LookAheadMatcher";
769 FSTERROR() <<
"ReplaceAccumulator::SetState: Incorrectly initialized";
773 auto tuple = data_->GetTuple(s);
774 fst_id_ = tuple.fst_id - 1;
775 data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state);
776 if ((tuple.prefix_id != 0) &&
777 (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) {
779 offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state);
782 offset_weight_ = Weight::Zero();
784 aiter_ = std::make_unique<ArcIterator<Fst<Arc>>>(*data_->GetFst(fst_id_),
789 if (error_)
return Weight::NoWeight();
790 return data_->GetAccumulator(fst_id_)->Sum(w, v);
793 template <
class ArcIter>
795 if (error_)
return Weight::NoWeight();
796 auto sum = begin == end ? Weight::Zero()
797 : data_->GetAccumulator(fst_id_)->Sum(
798 w, aiter_.get(), begin ? begin - offset_ : 0,
800 if (begin == 0 && end != 0 && offset_ > 0) sum =
Sum(offset_weight_, sum);
// Returns the accumulator's error flag. While it is set, both Sum()
// overloads return Weight::NoWeight(). Copies inherit the flag.
804 bool Error()
const {
return error_; }
808 std::shared_ptr<ReplaceAccumulatorData<Accumulator, StateTable>> data_;
812 std::unique_ptr<ArcIterator<Fst<Arc>>> aiter_;
818 template <
class Accumulator,
class T>
821 using Arc =
typename Accumulator::Arc;
834 const std::vector<Accumulator> &accumulators) {
835 for (
const auto &accumulator : accumulators) {
836 accumulators_.emplace_back(accumulator,
true);
842 state_table_ = state_table;
843 for (
Label i = 0; i < fst_tuples.size(); ++i) {
844 if (i == accumulators_.size()) {
845 accumulators_.resize(accumulators_.size() + 1);
846 accumulators_[i].Init(*(fst_tuples[i].second));
848 fst_array_.emplace_back(fst_tuples[i].second->Copy(
true));
855 FSTERROR() <<
"SafeReplaceAccumulator::Init: Accumulator needs to be" 856 <<
" initialized before being passed to LookAheadMatcher";
862 auto tuple = state_table_->Tuple(s);
863 fst_id_ = tuple.fst_id - 1;
864 GetAccumulator(fst_id_)->SetState(tuple.fst_state);
866 offset_weight_ = Weight::Zero();
867 const auto final_weight = GetFst(fst_id_)->Final(tuple.fst_state);
868 if ((tuple.prefix_id != 0) && (final_weight != Weight::Zero())) {
870 offset_weight_ = final_weight;
872 aiter_.Set(*GetFst(fst_id_), tuple.fst_state);
876 if (error_)
return Weight::NoWeight();
877 return GetAccumulator(fst_id_)->Sum(w, v);
880 template <
class ArcIter>
882 if (error_)
return Weight::NoWeight();
883 if (begin == end)
return Weight::Zero();
884 auto sum = GetAccumulator(fst_id_)->Sum(
885 w, aiter_.get(), begin ? begin - offset_ : 0, end - offset_);
886 if (begin == 0 && end != 0 && offset_ > 0) {
887 sum =
Sum(offset_weight_, sum);
// Returns the accumulator's error flag. While it is set, both Sum()
// overloads return Weight::NoWeight().
892 bool Error()
const {
return error_; }
895 class ArcIteratorPtr {
// Starts empty: ptr_ holds no iterator until one is created for a specific
// FST and state.
897 ArcIteratorPtr() =
default;
899 ArcIteratorPtr(
const ArcIteratorPtr ©) {}
902 ptr_ = std::make_unique<ArcIterator<Fst<Arc>>>(fst, state_id);
908 std::unique_ptr<ArcIterator<Fst<Arc>>> ptr_;
// Returns a non-owning pointer to the i-th per-FST accumulator, which is
// stored by value in accumulators_. The index is not bounds-checked
// (vector operator[]). The pointer is invalidated if accumulators_
// reallocates.
911 Accumulator *GetAccumulator(
size_t i) {
return &accumulators_[i]; }
// Returns a non-owning, read-only pointer to the i-th FST, whose storage is
// shared-owned by fst_array_. The index is not bounds-checked (vector
// operator[]).
913 const Fst<Arc> *GetFst(
size_t i)
const {
return fst_array_[i].get(); }
916 std::vector<Accumulator> accumulators_;
917 std::vector<std::shared_ptr<Fst<Arc>>> fst_array_;
918 ArcIteratorPtr aiter_;
928 #endif // FST_ACCUMULATOR_H_
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
CacheLogAccumulator(const CacheLogAccumulator &acc, bool safe=false)
SafeReplaceAccumulator(const std::vector< Accumulator > &accumulators)
constexpr bool Error() const
const Fst< Arc > * GetFst(size_t i) const
typename Arc::Weight Weight
constexpr uint8_t kArcNoCache
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
CacheLogAccumulatorData(const CacheLogAccumulatorData< Arc > &data)
Weight Sum(Weight w, Weight v)
std::shared_ptr< FastLogAccumulatorData > GetData() const
typename Arc::Weight Weight
bool CacheDisabled() const
DefaultAccumulator()=default
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
constexpr uint8_t kArcFlags
MutableFastLogAccumulatorData(int arc_limit, int arc_period)
typename Arc::Weight Weight
void Init(int num_weights, const double *weights, int num_positions, const int *weight_positions)
void AddWeights(StateId s, std::unique_ptr< std::vector< double >> weights)
typename Arc::Label Label
void SetData(std::vector< double > *weights, std::vector< int > *weight_positions) override
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
typename Arc::Weight Weight
FastLogAccumulatorData(int arc_limit, int arc_period)
const int * WeightPositions() const
Weight Sum(Weight w, Weight v)
typename StateTable::StateTuple StateTuple
void Init(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_tuples, const StateTable *state_table)
ReplaceAccumulatorData(std::vector< std::unique_ptr< Accumulator >> &&accumulators)
void Init(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_tuples, const StateTable *state_table)
const Arc & Value() const
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
static constexpr T PosInfinity()
ReplaceAccumulator(const ReplaceAccumulator< Accumulator, StateTable > &acc, bool safe=false)
void Init(const Fst< Arc > &fst, bool copy=false)
void Init(const FST &fst, bool copy=false)
CacheLogAccumulatorData(bool gc, size_t gc_limit)
void SetFlags(uint8_t flags, uint8_t mask)
bool IsMutable() const override
Weight Add(const Weight &w)
LogWeightTpl< double > Log64Weight
Weight Sum(Weight w, Weight v)
typename Arc::StateId StateId
typename Arc::StateId StateId
SafeReplaceAccumulator(const SafeReplaceAccumulator &copy, bool safe)
const StateTuple & GetTuple(StateId s) const
typename Arc::StateId StateId
typename StateTable::StateTuple StateTuple
virtual Fst * Copy(bool safe=false) const =0
void Init(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_tuples, const StateTable *state_table)
typename Arc::Weight Weight
FastLogAccumulator(std::shared_ptr< FastLogAccumulatorData > data)
typename Arc::Label Label
void Init(const Fst< Arc > &fst, bool copy=false)
typename Accumulator::Arc Arc
typename Accumulator::Arc Arc
LogAccumulator(const LogAccumulator &acc, bool safe=false)
double LogPosExp(double x)
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) const
void SetState(StateId s, int depth=0)
constexpr uint8_t kArcWeightValue
FastLogAccumulator(const FastLogAccumulator &acc, bool safe=false)
typename StateTable::StateTuple StateTuple
std::vector< double > * GetWeights(StateId s)
Weight Sum(Weight w, Weight v)
Arc::StateId CountStates(const Fst< Arc > &fst)
typename Arc::StateId StateId
void SetState(StateId state)
const double * Weights() const
Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end)
Accumulator * GetAccumulator(size_t i)
typename Arc::StateId StateId
void Init(const Fst< Arc > &fst, bool copy=false)
typename Accumulator::Arc Arc
ReplaceAccumulator(std::vector< std::unique_ptr< Accumulator >> &&accumulators)
Weight Sum(Weight w, Weight v) const
typename Arc::StateId StateId
FastLogAccumulator(ssize_t arc_limit=20, ssize_t arc_period=10)
size_t LowerBound(Weight w, ArcIter *aiter)
Weight Sum(Weight w, Weight v)
typename Arc::StateId StateId
CacheLogAccumulator(ssize_t arc_limit=10, bool gc=false, size_t gc_limit=10 *1024 *1024)
DefaultAccumulator(const DefaultAccumulator &acc, bool safe=false)
typename Arc::Weight Weight
typename Arc::Weight Weight
void Init(const Fst< Arc > &fst, bool copy=false)
void Init(const Fst< Arc > &fst, bool copy=false)
typename Arc::Label Label
constexpr bool Error() const
typename Arc::StateId StateId