20 #ifndef FST_RANDGEN_H_ 21 #define FST_RANDGEN_H_ 78 const auto n = fst.
NumArcs(s) + (fst.
Final(s) != Weight::Zero());
79 return static_cast<size_t>(
80 std::uniform_int_distribution<>(0, n - 1)(rand_));
84 mutable std::mt19937_64 rand_;
108 for (; !aiter.
Done(); aiter.
Next()) {
109 const auto &arc = aiter.
Value();
110 sum =
Plus(sum, to_log_weight_(arc.weight));
112 sum =
Plus(sum, to_log_weight_(fst.
Final(s)));
113 const double threshold =
114 std::uniform_real_distribution<>(0, exp(-sum.Value()))(rand_);
118 p =
Plus(p, to_log_weight_(aiter.
Value().weight));
119 if (exp(-p.Value()) > threshold)
return n;
124 uint64_t
Seed()
const {
return seed_; }
128 return to_log_weight_(weight);
134 const uint64_t seed_;
135 mutable std::mt19937_64 rand_;
166 -log(std::uniform_real_distribution<>(0, 1)(MutableRand()));
167 Weight w = from_log_weight_(r + sum);
177 template <
typename Arc>
189 : state_id(state_id),
202 template <
class Arc,
class Selector>
211 int32_t max_length = std::numeric_limits<int32_t>::max())
212 : fst_(fst), selector_(selector), max_length_(max_length) {}
217 : fst_(fst ? *fst : sampler.fst_),
218 selector_(sampler.selector_),
219 max_length_(sampler.max_length_) {
231 if ((fst_.NumArcs(rstate.
state_id) == 0 &&
232 fst_.Final(rstate.
state_id) == Weight::Zero()) ||
233 rstate.
length == max_length_) {
237 for (
size_t i = 0; i < rstate.
nsamples; ++i) {
238 ++sample_map_[selector_(fst_, rstate.
state_id)];
245 bool Done()
const {
return sample_iter_ == sample_map_.end(); }
248 void Next() { ++sample_iter_; }
250 std::pair<size_t, size_t>
Value()
const {
return *sample_iter_; }
252 void Reset() { sample_iter_ = sample_map_.begin(); }
254 bool Error()
const {
return false; }
258 const Selector &selector_;
259 const int32_t max_length_;
262 std::map<size_t, size_t> sample_map_;
263 std::map<size_t, size_t>::const_iterator sample_iter_;
273 template <
class Result,
class RNG>
275 size_t num_to_sample, Result *result, RNG *rng) {
276 using distribution = std::binomial_distribution<size_t>;
280 std::vector<double> norm(probs.size());
281 std::partial_sum(probs.rbegin(), probs.rend(), norm.rbegin());
283 for (
size_t i = 0; i < probs.size(); ++i) {
284 distribution::result_type num_sampled = 0;
286 distribution d(num_to_sample, probs[i] / norm[i]);
287 num_sampled = d(*rng);
289 if (num_sampled != 0) (*result)[i] = num_sampled;
290 num_to_sample -= std::min(num_sampled, num_to_sample);
305 int32_t max_length = std::numeric_limits<int32_t>::max())
308 max_length_(max_length),
310 accumulator_->Init(fst);
311 rng_.seed(selector_.Seed());
316 : fst_(fst ? *fst : sampler.fst_),
317 selector_(sampler.selector_),
318 max_length_(sampler.max_length_) {
320 accumulator_ = std::make_unique<Accumulator>();
321 accumulator_->Init(*fst);
323 accumulator_ = std::make_unique<Accumulator>(*sampler.accumulator_);
329 if ((fst_.NumArcs(rstate.
state_id) == 0 &&
330 fst_.Final(rstate.
state_id) == Weight::Zero()) ||
331 rstate.
length == max_length_) {
336 MultinomialSample(rstate);
340 for (
size_t i = 0; i < rstate.
nsamples; ++i) {
341 ++sample_map_[selector_(fst_, rstate.
state_id, accumulator_.get())];
347 bool Done()
const {
return sample_iter_ == sample_map_.end(); }
349 void Next() { ++sample_iter_; }
351 std::pair<size_t, size_t>
Value()
const {
return *sample_iter_; }
353 void Reset() { sample_iter_ = sample_map_.begin(); }
355 bool Error()
const {
return accumulator_->Error(); }
358 using RNG = std::mt19937;
366 p_.push_back(exp(-to_log_weight_(aiter.Value().weight).Value()));
368 if (fst_.Final(rstate.
state_id) != Weight::Zero()) {
369 p_.push_back(exp(-to_log_weight_(fst_.Final(rstate.
state_id)).Value()));
371 if (rstate.
nsamples < std::numeric_limits<RNG::result_type>::max()) {
374 for (
size_t i = 0; i < p_.size(); ++i) {
375 sample_map_[i] = ceil(p_[i] * rstate.
nsamples);
382 const int32_t max_length_;
385 std::map<size_t, size_t> sample_map_;
386 std::map<size_t, size_t>::const_iterator sample_iter_;
388 std::unique_ptr<Accumulator> accumulator_;
390 std::vector<double> p_;
397 template <
class Sampler>
406 int32_t npath = 1,
bool weighted =
true,
407 bool remove_total_weight =
false)
412 remove_total_weight(remove_total_weight) {}
418 template <
class FromArc,
class ToArc,
class Sampler>
434 using Label =
typename FromArc::Label;
444 sampler_(opts.sampler),
446 weighted_(opts.weighted),
447 remove_total_weight_(opts.remove_total_weight),
459 fst_(impl.fst_->Copy(true)),
460 sampler_(new Sampler(*impl.sampler_, fst_.get())),
462 weighted_(impl.weighted_),
472 const auto s = fst_->Start();
474 SetStart(state_table_.size());
475 state_table_.emplace_back(
482 if (!HasFinal(s))
Expand(s);
487 if (!HasArcs(s))
Expand(s);
492 if (!HasArcs(s))
Expand(s);
497 if (!HasArcs(s))
Expand(s);
506 (fst_->Properties(kError,
false) || sampler_->Error())) {
507 SetProperties(kError, kError);
513 if (!HasArcs(s))
Expand(s);
520 if (s == superfinal_) {
525 SetFinal(s, ToWeight::Zero());
526 const auto &rstate = *state_table_[s];
527 sampler_->Sample(rstate);
529 const auto narcs = fst_->NumArcs(rstate.state_id);
530 for (; !sampler_->Done(); sampler_->Next()) {
531 const auto &sample_pair = sampler_->Value();
532 const auto pos = sample_pair.first;
533 const auto count = sample_pair.second;
534 double prob =
static_cast<double>(count) / rstate.nsamples;
536 aiter.
Seek(sample_pair.first);
537 const auto &aarc = aiter.
Value();
539 weighted_ ? to_weight_(
Log64Weight(-log(prob))) : ToWeight::One();
540 EmplaceArc(s, aarc.ilabel, aarc.olabel, std::move(weight),
541 state_table_.size());
542 auto nrstate = std::make_unique<RandState<FromArc>>(
543 aarc.nextstate, count, rstate.length + 1, pos, &rstate);
544 state_table_.push_back(std::move(nrstate));
554 superfinal_ = state_table_.size();
555 state_table_.emplace_back(
558 for (
size_t n = 0; n < count; ++n) EmplaceArc(s, 0, 0, superfinal_);
566 const std::unique_ptr<Fst<FromArc>> fst_;
567 std::unique_ptr<Sampler> sampler_;
568 const int32_t npath_;
569 std::vector<std::unique_ptr<RandState<FromArc>>> state_table_;
570 const bool weighted_;
571 bool remove_total_weight_;
580 template <
class FromArc,
class ToArc,
class Sampler>
582 :
public ImplToFst<internal::RandGenFstImpl<FromArc, ToArc, Sampler>> {
584 using Label =
typename FromArc::Label;
611 GetMutableImpl()->InitArcIterator(s, data);
622 template <
class FromArc,
class ToArc,
class Sampler>
628 fst, fst.GetMutableImpl()) {}
632 template <
class FromArc,
class ToArc,
class Sampler>
640 fst.GetMutableImpl(), s) {
645 template <
class FromArc,
class ToArc,
class Sampler>
649 std::make_unique<StateIterator<RandGenFst<FromArc, ToArc, Sampler>>>(
654 template <
class Selector>
664 const Selector &selector,
665 int32_t max_length = std::numeric_limits<int32_t>::max(),
666 int32_t npath = 1,
bool weighted =
false,
667 bool remove_total_weight =
false)
668 : selector(selector),
669 max_length(max_length),
672 remove_total_weight(remove_total_weight) {}
677 template <
class FromArc,
class ToArc>
687 ofst_->DeleteStates();
697 if (ifst_->Final(arc.nextstate) == Weight::Zero()) {
698 path_.push_back(arc);
706 FSTERROR() <<
"RandGenVisitor: cyclic input";
717 if (p !=
kNoStateId && ifst_->Final(s) == Weight::Zero()) path_.pop_back();
725 const auto start = ofst_->AddState();
726 ofst_->SetStart(start);
728 auto src = ofst_->Start();
729 for (
size_t i = 0; i < path_.size(); ++i) {
730 const auto dest = ofst_->AddState();
731 const ToArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest);
732 ofst_->AddArc(src, arc);
735 ofst_->SetFinal(src);
740 std::vector<ToArc> path_;
750 template <
class FromArc,
class ToArc,
class Selector>
770 template <
class FromArc,
class ToArc>
772 uint64_t seed = std::random_device()()) {
780 #endif // FST_RANDGEN_H_ void OneMultinomialSample(const std::vector< double > &probs, size_t num_to_sample, Result *result, RNG *rng)
bool Sample(const RandState< Arc > &rstate)
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
typename FromArc::Weight Weight
StateIterator(const RandGenFst< FromArc, ToArc, Sampler > &fst)
uint64_t RandGenProperties(uint64_t inprops, bool weighted)
typename FromArc::StateId StateId
constexpr bool InitState(StateId, StateId) const
virtual uint64_t Properties(uint64_t mask, bool test) const =0
typename ToArc::Weight ToWeight
size_t operator()(const Fst< Arc > &fst, StateId s) const
FastLogProbArcSelector(uint64_t seed)
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
virtual size_t NumArcs(StateId) const =0
std::pair< size_t, size_t > Value() const
std::mt19937_64 & MutableRand() const
void FinishState(StateId s, StateId p, const FromArc *)
RandState(StateId state_id, size_t nsamples=0, size_t length=0, size_t select=0, const RandState< Arc > *parent=nullptr)
typename FromArc::StateId StateId
typename FromArc::StateId StateId
ArcSampler(const ArcSampler< Arc, Selector > &sampler, const Fst< Arc > *fst=nullptr)
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
const SymbolTable * OutputSymbols() const
size_t NumInputEpsilons(StateId s)
RandGenFstImpl(const Fst< FromArc > &fst, const RandGenFstOptions< Sampler > &opts)
constexpr uint64_t kError
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
typename Arc::Weight Weight
size_t NumOutputEpsilons(StateId s)
typename Arc::StateId StateId
bool BackArc(StateId, const FromArc &)
virtual Weight Final(StateId) const =0
static constexpr LogWeightTpl Zero()
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
RandGenVisitor(MutableFst< ToArc > *ofst)
LogProbArcSelector(uint64_t seed)
typename FromArc::Label Label
void InitVisit(const Fst< FromArc > &ifst)
typename Arc::Weight Weight
const Arc & Value() const
RandGenFst * Copy(bool safe=false) const override
LogWeightTpl< double > Log64Weight
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
ArcSampler(const ArcSampler< Arc, Selector > &sampler, const Fst< Arc > *fst=nullptr)
RandGenFstImpl(const RandGenFstImpl &impl)
void RandGen(const Fst< FromArc > &ifst, MutableFst< ToArc > *ofst, const RandGenOptions< Selector > &opts)
typename Arc::StateId StateId
typename FromArc::Weight Weight
RandGenFst(const Fst< FromArc > &fst, const RandGenFstOptions< Sampler > &opts)
typename Arc::Weight Weight
constexpr uint64_t kCopyProperties
bool ForwardOrCrossArc(StateId, const FromArc &)
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
bool TreeArc(StateId, const ToArc &arc)
std::unique_ptr< StateIteratorBase< Arc > > base
std::pair< size_t, size_t > Value() const
typename Arc::Weight Weight
size_t operator()(const Fst< Arc > &fst, StateId s, CacheLogAccumulator< Arc > *accumulator) const
ArcIterator(const RandGenFst< FromArc, ToArc, Sampler > &fst, StateId s)
void InitArcIterator(StateId s, ArcIteratorData< ToArc > *data)
void InitStateIterator(StateIteratorData< ToArc > *data) const override
RandGenOptions(const Selector &selector, int32_t max_length=std::numeric_limits< int32_t >::max(), int32_t npath=1, bool weighted=false, bool remove_total_weight=false)
RandGenFstOptions(const CacheOptions &opts, Sampler *sampler, int32_t npath=1, bool weighted=true, bool remove_total_weight=false)
void SetState(StateId s, int depth=0)
RandGenFst(const RandGenFst &fst, bool safe=false)
constexpr uint64_t kFstProperties
typename FromArc::Label Label
virtual const SymbolTable * InputSymbols() const =0
Weight Sum(Weight w, Weight v)
const SymbolTable * InputSymbols() const
const RandState< Arc > * parent
ArcSampler(const Fst< Arc > &fst, const Selector &selector, int32_t max_length=std::numeric_limits< int32_t >::max())
typename FromArc::StateId StateId
ToWeight Final(StateId s)
typename Arc::StateId StateId
void InitArcIterator(StateId s, ArcIteratorData< ToArc > *data) const override
Log64Weight ToLogWeight(const Weight &weight) const
const Selector & selector
uint64_t Properties() const override
uint64_t Properties(uint64_t mask) const override
size_t LowerBound(Weight w, ArcIter *aiter)
typename Arc::StateId StateId
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
bool Sample(const RandState< Arc > &rstate)
internal::RandGenFstImpl< FromArc, ToArc, Sampler > * GetMutableImpl() const
typename FromArc::Weight FromWeight
size_t NumArcs(StateId s)
ArcSampler(const Fst< Arc > &fst, const Selector &selector, int32_t max_length=std::numeric_limits< int32_t >::max())
const internal::RandGenFstImpl< FromArc, ToArc, Sampler > * GetImpl() const
typename Store::State State
virtual const SymbolTable * OutputSymbols() const =0