20 #ifndef FST_RANDGEN_H_ 21 #define FST_RANDGEN_H_ 76 const auto n = fst.
NumArcs(s) + (fst.
Final(s) != Weight::Zero());
77 return static_cast<size_t>(
78 std::uniform_int_distribution<>(0, n - 1)(rand_));
82 mutable std::mt19937_64 rand_;
106 for (; !aiter.
Done(); aiter.
Next()) {
107 const auto &arc = aiter.
Value();
108 sum =
Plus(sum, to_log_weight_(arc.weight));
110 sum =
Plus(sum, to_log_weight_(fst.
Final(s)));
111 const double threshold =
112 std::uniform_real_distribution<>(0, exp(-sum.Value()))(rand_);
116 p =
Plus(p, to_log_weight_(aiter.
Value().weight));
117 if (exp(-p.Value()) > threshold)
return n;
122 uint64_t
Seed()
const {
return seed_; }
126 return to_log_weight_(weight);
132 const uint64_t seed_;
133 mutable std::mt19937_64 rand_;
167 -log(std::uniform_real_distribution<>(0, 1)(MutableRand()));
168 Weight w = from_log_weight_(r + sum);
178 template <
typename Arc>
190 : state_id(state_id),
203 template <
class Arc,
class Selector>
212 int32_t max_length = std::numeric_limits<int32_t>::max())
213 : fst_(fst), selector_(selector), max_length_(max_length) {}
218 : fst_(fst ? *fst : sampler.fst_),
219 selector_(sampler.selector_),
220 max_length_(sampler.max_length_) {
232 if ((fst_.NumArcs(rstate.
state_id) == 0 &&
233 fst_.Final(rstate.
state_id) == Weight::Zero()) ||
234 rstate.
length == max_length_) {
238 for (
size_t i = 0; i < rstate.
nsamples; ++i) {
239 ++sample_map_[selector_(fst_, rstate.
state_id)];
246 bool Done()
const {
return sample_iter_ == sample_map_.end(); }
249 void Next() { ++sample_iter_; }
251 std::pair<size_t, size_t>
Value()
const {
return *sample_iter_; }
253 void Reset() { sample_iter_ = sample_map_.begin(); }
255 bool Error()
const {
return false; }
259 const Selector &selector_;
260 const int32_t max_length_;
263 std::map<size_t, size_t> sample_map_;
264 std::map<size_t, size_t>::const_iterator sample_iter_;
274 template <
class Result,
class RNG>
276 size_t num_to_sample, Result *result, RNG *rng) {
277 using distribution = std::binomial_distribution<size_t>;
281 std::vector<double> norm(probs.size());
282 std::partial_sum(probs.rbegin(), probs.rend(), norm.rbegin());
284 for (
size_t i = 0; i < probs.size(); ++i) {
285 distribution::result_type num_sampled = 0;
287 distribution d(num_to_sample, probs[i] / norm[i]);
288 num_sampled = d(*rng);
290 if (num_sampled != 0) (*result)[i] = num_sampled;
291 num_to_sample -= std::min(num_sampled, num_to_sample);
306 int32_t max_length = std::numeric_limits<int32_t>::max())
309 max_length_(max_length),
311 accumulator_->Init(fst);
312 rng_.seed(selector_.Seed());
317 : fst_(fst ? *fst : sampler.fst_),
318 selector_(sampler.selector_),
319 max_length_(sampler.max_length_) {
321 accumulator_ = std::make_unique<Accumulator>();
322 accumulator_->Init(*fst);
324 accumulator_ = std::make_unique<Accumulator>(*sampler.accumulator_);
330 if ((fst_.NumArcs(rstate.
state_id) == 0 &&
331 fst_.Final(rstate.
state_id) == Weight::Zero()) ||
332 rstate.
length == max_length_) {
337 MultinomialSample(rstate);
341 for (
size_t i = 0; i < rstate.
nsamples; ++i) {
342 ++sample_map_[selector_(fst_, rstate.
state_id, accumulator_.get())];
348 bool Done()
const {
return sample_iter_ == sample_map_.end(); }
350 void Next() { ++sample_iter_; }
352 std::pair<size_t, size_t>
Value()
const {
return *sample_iter_; }
354 void Reset() { sample_iter_ = sample_map_.begin(); }
356 bool Error()
const {
return accumulator_->Error(); }
359 using RNG = std::mt19937;
367 p_.push_back(exp(-to_log_weight_(aiter.Value().weight).Value()));
369 if (fst_.Final(rstate.
state_id) != Weight::Zero()) {
370 p_.push_back(exp(-to_log_weight_(fst_.Final(rstate.
state_id)).Value()));
372 if (rstate.
nsamples < std::numeric_limits<RNG::result_type>::max()) {
375 for (
size_t i = 0; i < p_.size(); ++i) {
376 sample_map_[i] = ceil(p_[i] * rstate.
nsamples);
383 const int32_t max_length_;
386 std::map<size_t, size_t> sample_map_;
387 std::map<size_t, size_t>::const_iterator sample_iter_;
389 std::unique_ptr<Accumulator> accumulator_;
391 std::vector<double> p_;
398 template <
class Sampler>
407 int32_t npath = 1,
bool weighted =
true,
408 bool remove_total_weight =
false)
413 remove_total_weight(remove_total_weight) {}
419 template <
class FromArc,
class ToArc,
class Sampler>
435 using Label =
typename FromArc::Label;
445 sampler_(opts.sampler),
447 weighted_(opts.weighted),
448 remove_total_weight_(opts.remove_total_weight),
460 fst_(impl.fst_->Copy(true)),
461 sampler_(new Sampler(*impl.sampler_, fst_.get())),
463 weighted_(impl.weighted_),
473 const auto s = fst_->Start();
475 SetStart(state_table_.size());
476 state_table_.emplace_back(
483 if (!HasFinal(s))
Expand(s);
488 if (!HasArcs(s))
Expand(s);
493 if (!HasArcs(s))
Expand(s);
498 if (!HasArcs(s))
Expand(s);
507 (fst_->Properties(kError,
false) || sampler_->Error())) {
508 SetProperties(kError, kError);
514 if (!HasArcs(s))
Expand(s);
521 if (s == superfinal_) {
526 SetFinal(s, ToWeight::Zero());
527 const auto &rstate = *state_table_[s];
528 sampler_->Sample(rstate);
530 const auto narcs = fst_->NumArcs(rstate.state_id);
531 for (; !sampler_->Done(); sampler_->Next()) {
532 const auto &sample_pair = sampler_->Value();
533 const auto pos = sample_pair.first;
534 const auto count = sample_pair.second;
535 double prob =
static_cast<double>(count) / rstate.nsamples;
537 aiter.
Seek(sample_pair.first);
538 const auto &aarc = aiter.
Value();
540 weighted_ ? to_weight_(
Log64Weight(-log(prob))) : ToWeight::One();
541 EmplaceArc(s, aarc.ilabel, aarc.olabel, std::move(weight),
542 state_table_.size());
544 rstate.length + 1, pos, &rstate);
545 state_table_.emplace_back(nrstate);
555 superfinal_ = state_table_.size();
556 state_table_.emplace_back(
559 for (
size_t n = 0; n < count; ++n) EmplaceArc(s, 0, 0, superfinal_);
567 const std::unique_ptr<Fst<FromArc>> fst_;
568 std::unique_ptr<Sampler> sampler_;
569 const int32_t npath_;
570 std::vector<std::unique_ptr<RandState<FromArc>>> state_table_;
571 const bool weighted_;
572 bool remove_total_weight_;
581 template <
class FromArc,
class ToArc,
class Sampler>
583 :
public ImplToFst<internal::RandGenFstImpl<FromArc, ToArc, Sampler>> {
585 using Label =
typename FromArc::Label;
612 GetMutableImpl()->InitArcIterator(s, data);
623 template <
class FromArc,
class ToArc,
class Sampler>
629 fst, fst.GetMutableImpl()) {}
633 template <
class FromArc,
class ToArc,
class Sampler>
641 fst.GetMutableImpl(), s) {
646 template <
class FromArc,
class ToArc,
class Sampler>
650 std::make_unique<StateIterator<RandGenFst<FromArc, ToArc, Sampler>>>(
655 template <
class Selector>
665 const Selector &selector,
666 int32_t max_length = std::numeric_limits<int32_t>::max(),
667 int32_t npath = 1,
bool weighted =
false,
668 bool remove_total_weight =
false)
669 : selector(selector),
670 max_length(max_length),
673 remove_total_weight(remove_total_weight) {}
678 template <
class FromArc,
class ToArc>
688 ofst_->DeleteStates();
698 if (ifst_->Final(arc.nextstate) == Weight::Zero()) {
699 path_.push_back(arc);
707 FSTERROR() <<
"RandGenVisitor: cyclic input";
718 if (p !=
kNoStateId && ifst_->Final(s) == Weight::Zero()) path_.pop_back();
726 const auto start = ofst_->AddState();
727 ofst_->SetStart(start);
729 auto src = ofst_->Start();
730 for (
size_t i = 0; i < path_.size(); ++i) {
731 const auto dest = ofst_->AddState();
732 const ToArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest);
733 ofst_->AddArc(src, arc);
736 ofst_->SetFinal(src);
741 std::vector<ToArc> path_;
751 template <
class FromArc,
class ToArc,
class Selector>
769 template <
class FromArc,
class ToArc>
771 uint64_t seed = std::random_device()()) {
779 #endif // FST_RANDGEN_H_ void OneMultinomialSample(const std::vector< double > &probs, size_t num_to_sample, Result *result, RNG *rng)
bool Sample(const RandState< Arc > &rstate)
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
typename FromArc::Weight Weight
StateIterator(const RandGenFst< FromArc, ToArc, Sampler > &fst)
uint64_t RandGenProperties(uint64_t inprops, bool weighted)
typename FromArc::StateId StateId
constexpr bool InitState(StateId, StateId) const
virtual uint64_t Properties(uint64_t mask, bool test) const =0
typename ToArc::Weight ToWeight
size_t operator()(const Fst< Arc > &fst, StateId s) const
FastLogProbArcSelector(uint64_t seed)
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
virtual size_t NumArcs(StateId) const =0
std::pair< size_t, size_t > Value() const
std::mt19937_64 & MutableRand() const
void FinishState(StateId s, StateId p, const FromArc *)
RandState(StateId state_id, size_t nsamples=0, size_t length=0, size_t select=0, const RandState< Arc > *parent=nullptr)
typename FromArc::StateId StateId
typename FromArc::StateId StateId
ArcSampler(const ArcSampler< Arc, Selector > &sampler, const Fst< Arc > *fst=nullptr)
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
const SymbolTable * OutputSymbols() const
size_t NumInputEpsilons(StateId s)
RandGenFstImpl(const Fst< FromArc > &fst, const RandGenFstOptions< Sampler > &opts)
constexpr uint64_t kError
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
typename Arc::Weight Weight
size_t NumOutputEpsilons(StateId s)
typename Arc::StateId StateId
bool BackArc(StateId, const FromArc &)
virtual Weight Final(StateId) const =0
static constexpr LogWeightTpl Zero()
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
RandGenVisitor(MutableFst< ToArc > *ofst)
LogProbArcSelector(uint64_t seed)
typename FromArc::Label Label
void InitVisit(const Fst< FromArc > &ifst)
typename Arc::Weight Weight
const Arc & Value() const
RandGenFst * Copy(bool safe=false) const override
LogWeightTpl< double > Log64Weight
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
ArcSampler(const ArcSampler< Arc, Selector > &sampler, const Fst< Arc > *fst=nullptr)
RandGenFstImpl(const RandGenFstImpl &impl)
void RandGen(const Fst< FromArc > &ifst, MutableFst< ToArc > *ofst, const RandGenOptions< Selector > &opts)
typename Arc::StateId StateId
typename FromArc::Weight Weight
RandGenFst(const Fst< FromArc > &fst, const RandGenFstOptions< Sampler > &opts)
typename Arc::Weight Weight
constexpr uint64_t kCopyProperties
bool ForwardOrCrossArc(StateId, const FromArc &)
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
bool TreeArc(StateId, const ToArc &arc)
std::unique_ptr< StateIteratorBase< Arc > > base
std::pair< size_t, size_t > Value() const
typename Arc::Weight Weight
size_t operator()(const Fst< Arc > &fst, StateId s, CacheLogAccumulator< Arc > *accumulator) const
ArcIterator(const RandGenFst< FromArc, ToArc, Sampler > &fst, StateId s)
void InitArcIterator(StateId s, ArcIteratorData< ToArc > *data)
void InitStateIterator(StateIteratorData< ToArc > *data) const override
RandGenOptions(const Selector &selector, int32_t max_length=std::numeric_limits< int32_t >::max(), int32_t npath=1, bool weighted=false, bool remove_total_weight=false)
RandGenFstOptions(const CacheOptions &opts, Sampler *sampler, int32_t npath=1, bool weighted=true, bool remove_total_weight=false)
void SetState(StateId s, int depth=0)
RandGenFst(const RandGenFst &fst, bool safe=false)
constexpr uint64_t kFstProperties
typename FromArc::Label Label
virtual const SymbolTable * InputSymbols() const =0
Weight Sum(Weight w, Weight v)
const SymbolTable * InputSymbols() const
const RandState< Arc > * parent
ArcSampler(const Fst< Arc > &fst, const Selector &selector, int32_t max_length=std::numeric_limits< int32_t >::max())
typename FromArc::StateId StateId
ToWeight Final(StateId s)
typename Arc::StateId StateId
void InitArcIterator(StateId s, ArcIteratorData< ToArc > *data) const override
Log64Weight ToLogWeight(const Weight &weight) const
const Selector & selector
uint64_t Properties() const override
uint64_t Properties(uint64_t mask) const override
size_t LowerBound(Weight w, ArcIter *aiter)
typename Arc::StateId StateId
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
bool Sample(const RandState< Arc > &rstate)
internal::RandGenFstImpl< FromArc, ToArc, Sampler > * GetMutableImpl() const
typename FromArc::Weight FromWeight
size_t NumArcs(StateId s)
ArcSampler(const Fst< Arc > &fst, const Selector &selector, int32_t max_length=std::numeric_limits< int32_t >::max())
const internal::RandGenFstImpl< FromArc, ToArc, Sampler > * GetImpl() const
typename Store::State State
virtual const SymbolTable * OutputSymbols() const =0