FST  openfst-1.7.2
OpenFst Library
randgen.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Classes and functions to generate random paths through an FST.
5 
6 #ifndef FST_RANDGEN_H_
7 #define FST_RANDGEN_H_
8 
9 #include <math.h>
10 #include <stddef.h>
11 #include <limits>
12 #include <map>
13 #include <memory>
14 #include <random>
15 #include <utility>
16 #include <vector>
17 
18 #include <fst/log.h>
19 
20 #include <fst/accumulator.h>
21 #include <fst/cache.h>
22 #include <fst/dfs-visit.h>
23 #include <fst/float-weight.h>
24 #include <fst/fst-decl.h>
25 #include <fst/fst.h>
26 #include <fst/mutable-fst.h>
27 #include <fst/properties.h>
28 #include <fst/util.h>
29 #include <fst/weight.h>
30 
31 namespace fst {
32 
33 // The RandGenFst class is roughly similar to ArcMapFst in that it takes two
34 // template parameters denoting the input and output arc types. However, it also
35 // takes an additional template parameter which specifies a sampler object which
36 // samples (with replacement) arcs from an FST state. The sampler in turn takes
37 // a template parameter for a selector object which actually chooses the arc.
38 //
39 // Arc selector functors are used to select a random transition given an FST
40 // state s, returning a number N such that 0 <= N <= NumArcs(s). If N is
41 // NumArcs(s), then the final weight is selected; otherwise the N-th arc is
42 // selected. It is assumed these are not applied to any state which is neither
43 // final nor has any arcs leaving it.
44 
45 // Randomly selects a transition using the uniform distribution. This class is
46 // not thread-safe.
47 template <class Arc>
49  public:
50  using StateId = typename Arc::StateId;
51  using Weight = typename Arc::Weight;
52 
53  // Constructs a selector with a non-deterministic seed.
54  UniformArcSelector() : rand_(std::random_device()()) {}
55  // Constructs a selector with a given seed.
56  explicit UniformArcSelector(uint64 seed) : rand_(seed) {}
57 
58  size_t operator()(const Fst<Arc> &fst, StateId s) const {
59  const auto n = fst.NumArcs(s) + (fst.Final(s) != Weight::Zero());
60  return static_cast<size_t>(
61  std::uniform_int_distribution<>(0, n - 1)(rand_));
62  }
63 
64  private:
65  mutable std::mt19937_64 rand_;
66 };
67 
68 // Randomly selects a transition w.r.t. the weights treated as negative log
69 // probabilities after normalizing for the total weight leaving the state. Zero
70 // transitions are disregarded. It assumed that Arc::Weight::Value() accesses
71 // the floating point representation of the weight. This class is not
72 // thread-safe.
73 template <class Arc>
75  public:
76  using StateId = typename Arc::StateId;
77  using Weight = typename Arc::Weight;
78 
79  // Constructs a selector with a non-deterministic seed.
80  LogProbArcSelector() : seed_(std::random_device()()), rand_(seed_) {}
81  // Constructs a selector with a given seed.
82  explicit LogProbArcSelector(uint64 seed) : seed_(seed), rand_(seed) {}
83 
84  size_t operator()(const Fst<Arc> &fst, StateId s) const {
85  // Finds total weight leaving state.
86  auto sum = Log64Weight::Zero();
87  ArcIterator<Fst<Arc>> aiter(fst, s);
88  for (; !aiter.Done(); aiter.Next()) {
89  const auto &arc = aiter.Value();
90  sum = Plus(sum, to_log_weight_(arc.weight));
91  }
92  sum = Plus(sum, to_log_weight_(fst.Final(s)));
93  const double threshold =
94  std::uniform_real_distribution<>(0, exp(-sum.Value()))(rand_);
95  auto p = Log64Weight::Zero();
96  size_t n = 0;
97  for (aiter.Reset(); !aiter.Done(); aiter.Next(), ++n) {
98  p = Plus(p, to_log_weight_(aiter.Value().weight));
99  if (exp(-p.Value()) > threshold) return n;
100  }
101  return n;
102  }
103 
104  uint64 Seed() const { return seed_; }
105 
106  protected:
107  Log64Weight ToLogWeight(const Weight &weight) const {
108  return to_log_weight_(weight);
109  }
110 
111  std::mt19937_64 &MutableRand() const { return rand_; }
112 
113  private:
114  const uint64 seed_;
115  mutable std::mt19937_64 rand_;
116  const WeightConvert<Weight, Log64Weight> to_log_weight_{};
117 };
118 
119 // Useful alias when using StdArc.
121 
122 // Same as LogProbArcSelector but uses CacheLogAccumulator to cache the weight
123 // accumulation computations. This class is not thread-safe.
124 template <class Arc>
126  public:
127  using StateId = typename Arc::StateId;
128  using Weight = typename Arc::Weight;
129 
133 
134  // Constructs a selector with a non-deterministic seed.
136  // Constructs a selector with a given seed.
138  seed) {}
139 
140  size_t operator()(const Fst<Arc> &fst, StateId s,
141  CacheLogAccumulator<Arc> *accumulator) const {
142  accumulator->SetState(s);
143  ArcIterator<Fst<Arc>> aiter(fst, s);
144  // Finds total weight leaving state.
145  const double sum =
146  ToLogWeight(accumulator->Sum(fst.Final(s), &aiter, 0, fst.NumArcs(s)))
147  .Value();
148  const double r = -log(std::uniform_real_distribution<>(0, 1)(
149  MutableRand()));
150  Weight w = from_log_weight_(r + sum);
151  aiter.Reset();
152  return accumulator->LowerBound(w, &aiter);
153  }
154 
155  private:
156  const WeightConvert<Log64Weight, Weight> from_log_weight_{};
157 };
158 
159 // Random path state info maintained by RandGenFst and passed to samplers.
160 template <typename Arc>
161 struct RandState {
162  using StateId = typename Arc::StateId;
163 
164  StateId state_id; // Current input FST state.
165  size_t nsamples; // Number of samples to be sampled at this state.
166  size_t length; // Length of path to this random state.
167  size_t select; // Previous sample arc selection.
168  const RandState<Arc> *parent; // Previous random state on this path.
169 
170  explicit RandState(StateId state_id, size_t nsamples = 0, size_t length = 0,
171  size_t select = 0, const RandState<Arc> *parent = nullptr)
172  : state_id(state_id),
173  nsamples(nsamples),
174  length(length),
175  select(select),
176  parent(parent) {}
177 
179 };
180 
181 // This class, given an arc selector, samples, with replacement, multiple random
182 // transitions from an FST's state. This is a generic version with a
183 // straightforward use of the arc selector. Specializations may be defined for
184 // arc selectors for greater efficiency or special behavior.
185 template <class Arc, class Selector>
186 class ArcSampler {
187  public:
188  using StateId = typename Arc::StateId;
189  using Weight = typename Arc::Weight;
190 
191  // The max_length argument may be interpreted (or ignored) by a selector as
192  // it chooses. This generic version interprets this literally.
193  ArcSampler(const Fst<Arc> &fst, const Selector &selector,
194  int32 max_length = std::numeric_limits<int32>::max())
195  : fst_(fst), selector_(selector), max_length_(max_length) {}
196 
197  // Allow updating FST argument; pass only if changed.
199  const Fst<Arc> *fst = nullptr)
200  : fst_(fst ? *fst : sampler.fst_),
201  selector_(sampler.selector_),
202  max_length_(sampler.max_length_) {
203  Reset();
204  }
205 
206  // Samples a fixed number of samples from the given state. The length argument
207  // specifies the length of the path to the state. Returns true if the samples
208  // were collected. No samples may be collected if either there are no
209  // transitions leaving the state and the state is non-final, or if the path
210  // length has been exceeded. Iterator members are provided to read the samples
211  // in the order in which they were collected.
212  bool Sample(const RandState<Arc> &rstate) {
213  sample_map_.clear();
214  if ((fst_.NumArcs(rstate.state_id) == 0 &&
215  fst_.Final(rstate.state_id) == Weight::Zero()) ||
216  rstate.length == max_length_) {
217  Reset();
218  return false;
219  }
220  for (size_t i = 0; i < rstate.nsamples; ++i) {
221  ++sample_map_[selector_(fst_, rstate.state_id)];
222  }
223  Reset();
224  return true;
225  }
226 
227  // More samples?
228  bool Done() const { return sample_iter_ == sample_map_.end(); }
229 
230  // Gets the next sample.
231  void Next() { ++sample_iter_; }
232 
233  std::pair<size_t, size_t> Value() const { return *sample_iter_; }
234 
235  void Reset() { sample_iter_ = sample_map_.begin(); }
236 
237  bool Error() const { return false; }
238 
239  private:
240  const Fst<Arc> &fst_;
241  const Selector &selector_;
242  const int32 max_length_;
243 
244  // Stores (N, K) as described for Value().
245  std::map<size_t, size_t> sample_map_;
246  std::map<size_t, size_t>::const_iterator sample_iter_;
247 
248  ArcSampler<Arc, Selector> &operator=(const ArcSampler &) = delete;
249 };
250 
// Samples one sample of num_to_sample dimensions from a multinomial
// distribution parameterized by a vector of probabilities. The probabilities
// need not sum to one; they are normalized by their total. Only non-zero
// counts are stored, as (*result)[i] = count, so the result container should
// be pre-initialized (e.g., an empty map, or a zeroed vector of size
// probs.size()).
//
// The count for outcome i is drawn from a binomial over the samples not yet
// assigned, with probability probs[i] divided by the mass of outcomes i and
// later.
template <class Result, class RNG>
void OneMultinomialSample(const std::vector<double> &probs,
                          size_t num_to_sample, Result *result, RNG *rng) {
  // Left-over probability mass.
  double norm = 0;
  for (double p : probs) norm += p;
  // num_to_sample is decremented below and tracks the left-over number of
  // samples needed.
  for (size_t i = 0; i < probs.size(); ++i) {
    size_t num_sampled = 0;
    if (probs[i] > 0) {
      // Rounding in the running subtraction from norm can leave norm slightly
      // below probs[i], or at/below zero, for the last outcomes. The ratio
      // would then fall outside [0, 1], which std::binomial_distribution does
      // not allow. Such an outcome carries (up to rounding) all of the
      // remaining mass, so clamp its probability to 1.
      double prob = norm > 0 ? probs[i] / norm : 1.0;
      if (prob > 1.0) prob = 1.0;
      std::binomial_distribution<> d(num_to_sample, prob);
      num_sampled = d(*rng);
    }
    if (num_sampled != 0) (*result)[i] = num_sampled;
    norm -= probs[i];
    num_to_sample -= num_sampled;
  }
}
274 
275 // Specialization for FastLogProbArcSelector.
276 template <class Arc>
278  public:
279  using StateId = typename Arc::StateId;
280  using Weight = typename Arc::Weight;
281 
284 
285  ArcSampler(const Fst<Arc> &fst, const Selector &selector,
286  int32 max_length = std::numeric_limits<int32>::max())
287  : fst_(fst),
288  selector_(selector),
289  max_length_(max_length),
290  accumulator_(new Accumulator()) {
291  accumulator_->Init(fst);
292  rng_.seed(selector_.Seed());
293  }
294 
296  const Fst<Arc> *fst = nullptr)
297  : fst_(fst ? *fst : sampler.fst_),
298  selector_(sampler.selector_),
299  max_length_(sampler.max_length_) {
300  if (fst) {
301  accumulator_.reset(new Accumulator());
302  accumulator_->Init(*fst);
303  } else { // Shallow copy.
304  accumulator_.reset(new Accumulator(*sampler.accumulator_));
305  }
306  }
307 
308  bool Sample(const RandState<Arc> &rstate) {
309  sample_map_.clear();
310  if ((fst_.NumArcs(rstate.state_id) == 0 &&
311  fst_.Final(rstate.state_id) == Weight::Zero()) ||
312  rstate.length == max_length_) {
313  Reset();
314  return false;
315  }
316  if (fst_.NumArcs(rstate.state_id) + 1 < rstate.nsamples) {
317  MultinomialSample(rstate);
318  Reset();
319  return true;
320  }
321  for (size_t i = 0; i < rstate.nsamples; ++i) {
322  ++sample_map_[selector_(fst_, rstate.state_id, accumulator_.get())];
323  }
324  Reset();
325  return true;
326  }
327 
328  bool Done() const { return sample_iter_ == sample_map_.end(); }
329 
330  void Next() { ++sample_iter_; }
331 
332  std::pair<size_t, size_t> Value() const { return *sample_iter_; }
333 
334  void Reset() { sample_iter_ = sample_map_.begin(); }
335 
336  bool Error() const { return accumulator_->Error(); }
337 
338  private:
339  using RNG = std::mt19937;
340 
341  // Sample according to the multinomial distribution of rstate.nsamples draws
342  // from p_.
343  void MultinomialSample(const RandState<Arc> &rstate) {
344  p_.clear();
345  for (ArcIterator<Fst<Arc>> aiter(fst_, rstate.state_id); !aiter.Done();
346  aiter.Next()) {
347  p_.push_back(exp(-to_log_weight_(aiter.Value().weight).Value()));
348  }
349  if (fst_.Final(rstate.state_id) != Weight::Zero()) {
350  p_.push_back(exp(-to_log_weight_(fst_.Final(rstate.state_id)).Value()));
351  }
352  if (rstate.nsamples < std::numeric_limits<RNG::result_type>::max()) {
353  OneMultinomialSample(p_, rstate.nsamples, &sample_map_, &rng_);
354  } else {
355  for (size_t i = 0; i < p_.size(); ++i) {
356  sample_map_[i] = ceil(p_[i] * rstate.nsamples);
357  }
358  }
359  }
360 
361  const Fst<Arc> &fst_;
362  const Selector &selector_;
363  const int32 max_length_;
364 
365  // Stores (N, K) for Value().
366  std::map<size_t, size_t> sample_map_;
367  std::map<size_t, size_t>::const_iterator sample_iter_;
368 
369  std::unique_ptr<Accumulator> accumulator_;
370  RNG rng_; // Random number generator.
371  std::vector<double> p_; // Multinomial parameters.
372  const WeightConvert<Weight, Log64Weight> to_log_weight_{};
373 };
374 
375 // Options for random path generation with RandGenFst. The template argument is
376 // a sampler, typically the class ArcSampler. Ownership of the sampler is taken
377 // by RandGenFst.
378 template <class Sampler>
380  Sampler *sampler; // How to sample transitions at a state.
381  int32 npath; // Number of paths to generate.
382  bool weighted; // Is the output tree weighted by path count, or
383  // is it just an unweighted DAG?
384  bool remove_total_weight; // Remove total weight when output is weighted.
385 
386  RandGenFstOptions(const CacheOptions &opts, Sampler *sampler, int32 npath = 1,
387  bool weighted = true, bool remove_total_weight = false)
388  : CacheOptions(opts),
389  sampler(sampler),
390  npath(npath),
391  weighted(weighted),
392  remove_total_weight(remove_total_weight) {}
393 };
394 
395 namespace internal {
396 
397 // Implementation of RandGenFst.
398 template <class FromArc, class ToArc, class Sampler>
399 class RandGenFstImpl : public CacheImpl<ToArc> {
400  public:
405 
406  using CacheBaseImpl<CacheState<ToArc>>::EmplaceArc;
407  using CacheBaseImpl<CacheState<ToArc>>::HasArcs;
408  using CacheBaseImpl<CacheState<ToArc>>::HasFinal;
409  using CacheBaseImpl<CacheState<ToArc>>::HasStart;
410  using CacheBaseImpl<CacheState<ToArc>>::SetArcs;
411  using CacheBaseImpl<CacheState<ToArc>>::SetFinal;
412  using CacheBaseImpl<CacheState<ToArc>>::SetStart;
413 
414  using Label = typename FromArc::Label;
415  using StateId = typename FromArc::StateId;
416  using FromWeight = typename FromArc::Weight;
417 
418  using ToWeight = typename ToArc::Weight;
419 
421  const RandGenFstOptions<Sampler> &opts)
422  : CacheImpl<ToArc>(opts),
423  fst_(fst.Copy()),
424  sampler_(opts.sampler),
425  npath_(opts.npath),
426  weighted_(opts.weighted),
427  remove_total_weight_(opts.remove_total_weight),
428  superfinal_(kNoLabel) {
429  SetType("randgen");
430  SetProperties(
431  RandGenProperties(fst.Properties(kFstProperties, false), weighted_),
433  SetInputSymbols(fst.InputSymbols());
434  SetOutputSymbols(fst.OutputSymbols());
435  }
436 
438  : CacheImpl<ToArc>(impl),
439  fst_(impl.fst_->Copy(true)),
440  sampler_(new Sampler(*impl.sampler_, fst_.get())),
441  npath_(impl.npath_),
442  weighted_(impl.weighted_),
443  superfinal_(kNoLabel) {
444  SetType("randgen");
445  SetProperties(impl.Properties(), kCopyProperties);
446  SetInputSymbols(impl.InputSymbols());
447  SetOutputSymbols(impl.OutputSymbols());
448  }
449 
451  if (!HasStart()) {
452  const auto s = fst_->Start();
453  if (s == kNoStateId) return kNoStateId;
454  SetStart(state_table_.size());
455  state_table_.emplace_back(
456  new RandState<FromArc>(s, npath_, 0, 0, nullptr));
457  }
458  return CacheImpl<ToArc>::Start();
459  }
460 
462  if (!HasFinal(s)) Expand(s);
463  return CacheImpl<ToArc>::Final(s);
464  }
465 
466  size_t NumArcs(StateId s) {
467  if (!HasArcs(s)) Expand(s);
468  return CacheImpl<ToArc>::NumArcs(s);
469  }
470 
472  if (!HasArcs(s)) Expand(s);
474  }
475 
477  if (!HasArcs(s)) Expand(s);
479  }
480 
481  uint64 Properties() const override { return Properties(kFstProperties); }
482 
483  // Sets error if found, and returns other FST impl properties.
484  uint64 Properties(uint64 mask) const override {
485  if ((mask & kError) &&
486  (fst_->Properties(kError, false) || sampler_->Error())) {
487  SetProperties(kError, kError);
488  }
489  return FstImpl<ToArc>::Properties(mask);
490  }
491 
493  if (!HasArcs(s)) Expand(s);
495  }
496 
497  // Computes the outgoing transitions from a state, creating new destination
498  // states as needed.
499  void Expand(StateId s) {
500  if (s == superfinal_) {
501  SetFinal(s, ToWeight::One());
502  SetArcs(s);
503  return;
504  }
505  SetFinal(s, ToWeight::Zero());
506  const auto &rstate = *state_table_[s];
507  sampler_->Sample(rstate);
508  ArcIterator<Fst<FromArc>> aiter(*fst_, rstate.state_id);
509  const auto narcs = fst_->NumArcs(rstate.state_id);
510  for (; !sampler_->Done(); sampler_->Next()) {
511  const auto &sample_pair = sampler_->Value();
512  const auto pos = sample_pair.first;
513  const auto count = sample_pair.second;
514  double prob = static_cast<double>(count) / rstate.nsamples;
515  if (pos < narcs) { // Regular transition.
516  aiter.Seek(sample_pair.first);
517  const auto &aarc = aiter.Value();
518  auto weight =
519  weighted_ ? to_weight_(Log64Weight(-log(prob))) : ToWeight::One();
520  EmplaceArc(s, aarc.ilabel, aarc.olabel, std::move(weight),
521  state_table_.size());
522  auto *nrstate = new RandState<FromArc>(aarc.nextstate, count,
523  rstate.length + 1, pos, &rstate);
524  state_table_.emplace_back(nrstate);
525  } else { // Super-final transition.
526  if (weighted_) {
527  const auto weight =
528  remove_total_weight_
529  ? to_weight_(Log64Weight(-log(prob)))
530  : to_weight_(Log64Weight(-log(prob * npath_)));
531  SetFinal(s, weight);
532  } else {
533  if (superfinal_ == kNoLabel) {
534  superfinal_ = state_table_.size();
535  state_table_.emplace_back(
536  new RandState<FromArc>(kNoStateId, 0, 0, 0, nullptr));
537  }
538  for (size_t n = 0; n < count; ++n) {
539  EmplaceArc(s, 0, 0, ToWeight::One(), superfinal_);
540  }
541  }
542  }
543  }
544  SetArcs(s);
545  }
546 
547  private:
548  const std::unique_ptr<Fst<FromArc>> fst_;
549  std::unique_ptr<Sampler> sampler_;
550  const int32 npath_;
551  std::vector<std::unique_ptr<RandState<FromArc>>> state_table_;
552  const bool weighted_;
553  bool remove_total_weight_;
554  StateId superfinal_;
555  const WeightConvert<Log64Weight, ToWeight> to_weight_{};
556 };
557 
558 } // namespace internal
559 
560 // FST class to randomly generate paths through an FST, with details controlled
561 // by RandGenFstOptions. Output format is a tree weighted by the path count.
562 template <class FromArc, class ToArc, class Sampler>
563 class RandGenFst
564  : public ImplToFst<internal::RandGenFstImpl<FromArc, ToArc, Sampler>> {
565  public:
566  using Label = typename FromArc::Label;
567  using StateId = typename FromArc::StateId;
568  using Weight = typename FromArc::Weight;
569 
571  using State = typename Store::State;
572 
574 
575  friend class ArcIterator<RandGenFst<FromArc, ToArc, Sampler>>;
576  friend class StateIterator<RandGenFst<FromArc, ToArc, Sampler>>;
577 
579  : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
580 
581  // See Fst<>::Copy() for doc.
583  : ImplToFst<Impl>(fst, safe) {}
584 
585  // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc.
586  RandGenFst<FromArc, ToArc, Sampler> *Copy(bool safe = false) const override {
587  return new RandGenFst<FromArc, ToArc, Sampler>(*this, safe);
588  }
589 
590  inline void InitStateIterator(StateIteratorData<ToArc> *data) const override;
591 
592  void InitArcIterator(StateId s, ArcIteratorData<ToArc> *data) const override {
593  GetMutableImpl()->InitArcIterator(s, data);
594  }
595 
596  private:
599 
600  RandGenFst &operator=(const RandGenFst &) = delete;
601 };
602 
603 // Specialization for RandGenFst.
604 template <class FromArc, class ToArc, class Sampler>
605 class StateIterator<RandGenFst<FromArc, ToArc, Sampler>>
606  : public CacheStateIterator<RandGenFst<FromArc, ToArc, Sampler>> {
607  public:
609  : CacheStateIterator<RandGenFst<FromArc, ToArc, Sampler>>(
610  fst, fst.GetMutableImpl()) {}
611 };
612 
613 // Specialization for RandGenFst.
614 template <class FromArc, class ToArc, class Sampler>
615 class ArcIterator<RandGenFst<FromArc, ToArc, Sampler>>
616  : public CacheArcIterator<RandGenFst<FromArc, ToArc, Sampler>> {
617  public:
618  using StateId = typename FromArc::StateId;
619 
621  : CacheArcIterator<RandGenFst<FromArc, ToArc, Sampler>>(
622  fst.GetMutableImpl(), s) {
623  if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
624  }
625 };
626 
627 template <class FromArc, class ToArc, class Sampler>
629  StateIteratorData<ToArc> *data) const {
631 }
632 
633 // Options for random path generation.
634 template <class Selector>
636  const Selector &selector; // How an arc is selected at a state.
637  int32 max_length; // Maximum path length.
638  int32 npath; // Number of paths to generate.
639  bool weighted; // Is the output tree weighted by path count, or
640  // is it just an unweighted DAG?
641  bool remove_total_weight; // Remove total weight when output is weighted?
642 
643  explicit RandGenOptions(const Selector &selector,
644  int32 max_length = std::numeric_limits<int32>::max(),
645  int32 npath = 1, bool weighted = false,
646  bool remove_total_weight = false)
647  : selector(selector),
648  max_length(max_length),
649  npath(npath),
650  weighted(weighted),
651  remove_total_weight(remove_total_weight) {}
652 };
653 
654 namespace internal {
655 
656 template <class FromArc, class ToArc>
658  public:
659  using StateId = typename FromArc::StateId;
660  using Weight = typename FromArc::Weight;
661 
662  explicit RandGenVisitor(MutableFst<ToArc> *ofst) : ofst_(ofst) {}
663 
664  void InitVisit(const Fst<FromArc> &ifst) {
665  ifst_ = &ifst;
666  ofst_->DeleteStates();
667  ofst_->SetInputSymbols(ifst.InputSymbols());
668  ofst_->SetOutputSymbols(ifst.OutputSymbols());
669  if (ifst.Properties(kError, false)) ofst_->SetProperties(kError, kError);
670  path_.clear();
671  }
672 
673  constexpr bool InitState(StateId, StateId) const { return true; }
674 
675  bool TreeArc(StateId, const ToArc &arc) {
676  if (ifst_->Final(arc.nextstate) == Weight::Zero()) {
677  path_.push_back(arc);
678  } else {
679  OutputPath();
680  }
681  return true;
682  }
683 
684  bool BackArc(StateId, const FromArc &) {
685  FSTERROR() << "RandGenVisitor: cyclic input";
686  ofst_->SetProperties(kError, kError);
687  return false;
688  }
689 
690  bool ForwardOrCrossArc(StateId, const FromArc &) {
691  OutputPath();
692  return true;
693  }
694 
695  void FinishState(StateId s, StateId p, const FromArc *) {
696  if (p != kNoStateId && ifst_->Final(s) == Weight::Zero()) path_.pop_back();
697  }
698 
699  void FinishVisit() {}
700 
701  private:
702  void OutputPath() {
703  if (ofst_->Start() == kNoStateId) {
704  const auto start = ofst_->AddState();
705  ofst_->SetStart(start);
706  }
707  auto src = ofst_->Start();
708  for (size_t i = 0; i < path_.size(); ++i) {
709  const auto dest = ofst_->AddState();
710  const ToArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest);
711  ofst_->AddArc(src, arc);
712  src = dest;
713  }
714  ofst_->SetFinal(src, Weight::One());
715  }
716 
717  const Fst<FromArc> *ifst_;
718  MutableFst<ToArc> *ofst_;
719  std::vector<ToArc> path_;
720 
721  RandGenVisitor(const RandGenVisitor &) = delete;
722  RandGenVisitor &operator=(const RandGenVisitor &) = delete;
723 };
724 
725 } // namespace internal
726 
727 // Randomly generate paths through an FST; details controlled by
728 // RandGenOptions.
729 template <class FromArc, class ToArc, class Selector>
730 void RandGen(const Fst<FromArc> &ifst, MutableFst<ToArc> *ofst,
731  const RandGenOptions<Selector> &opts) {
732  using Sampler = ArcSampler<FromArc, Selector>;
733  auto *sampler = new Sampler(ifst, opts.selector, opts.max_length);
734  RandGenFstOptions<Sampler> fopts(CacheOptions(true, 0), sampler, opts.npath,
735  opts.weighted, opts.remove_total_weight);
736  RandGenFst<FromArc, ToArc, Sampler> rfst(ifst, fopts);
737  if (opts.weighted) {
738  *ofst = rfst;
739  } else {
740  internal::RandGenVisitor<FromArc, ToArc> rand_visitor(ofst);
741  DfsVisit(rfst, &rand_visitor);
742  }
743 }
744 
745 // Randomly generate a path through an FST with the uniform distribution
746 // over the transitions.
747 template <class FromArc, class ToArc>
748 void RandGen(const Fst<FromArc> &ifst, MutableFst<ToArc> *ofst) {
749  const UniformArcSelector<FromArc> uniform_selector;
750  RandGenOptions<UniformArcSelector<ToArc>> opts(uniform_selector);
751  RandGen(ifst, ofst, opts);
752 }
753 
754 } // namespace fst
755 
756 #endif // FST_RANDGEN_H_
void OneMultinomialSample(const std::vector< double > &probs, size_t num_to_sample, Result *result, RNG *rng)
Definition: randgen.h:257
bool Sample(const RandState< Arc > &rstate)
Definition: randgen.h:308
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:99
typename FromArc::Weight Weight
Definition: randgen.h:568
StateIterator(const RandGenFst< FromArc, ToArc, Sampler > &fst)
Definition: randgen.h:608
uint64 Properties(uint64 mask) const override
Definition: randgen.h:484
uint64 Seed() const
Definition: randgen.h:104
typename FromArc::StateId StateId
Definition: randgen.h:415
constexpr bool InitState(StateId, StateId) const
Definition: randgen.h:673
constexpr int kNoLabel
Definition: fst.h:179
void Next()
Definition: randgen.h:231
typename ToArc::Weight ToWeight
Definition: randgen.h:418
uint64_t uint64
Definition: types.h:32
size_t operator()(const Fst< Arc > &fst, StateId s) const
Definition: randgen.h:84
void Reset()
Definition: fst.h:515
typename Arc::StateId StateId
Definition: randgen.h:50
virtual size_t NumArcs(StateId) const =0
std::pair< size_t, size_t > Value() const
Definition: randgen.h:332
std::mt19937_64 & MutableRand() const
Definition: randgen.h:111
StateId state_id
Definition: randgen.h:164
void FinishState(StateId s, StateId p, const FromArc *)
Definition: randgen.h:695
RandState(StateId state_id, size_t nsamples=0, size_t length=0, size_t select=0, const RandState< Arc > *parent=nullptr)
Definition: randgen.h:170
typename FromArc::StateId StateId
Definition: randgen.h:567
void Expand(StateId s)
Definition: randgen.h:499
ArcSampler(const ArcSampler< Arc, Selector > &sampler, const Fst< Arc > *fst=nullptr)
Definition: randgen.h:198
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:82
const SymbolTable * OutputSymbols() const
Definition: fst.h:690
size_t NumInputEpsilons(StateId s)
Definition: randgen.h:471
bool Error() const
Definition: randgen.h:237
RandGenFstImpl(const Fst< FromArc > &fst, const RandGenFstOptions< Sampler > &opts)
Definition: randgen.h:420
RandGenOptions(const Selector &selector, int32 max_length=std::numeric_limits< int32 >::max(), int32 npath=1, bool weighted=false, bool remove_total_weight=false)
Definition: randgen.h:643
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
Definition: dfs-visit.h:94
typename Arc::Weight Weight
Definition: randgen.h:77
size_t NumOutputEpsilons(StateId s)
Definition: randgen.h:476
typename Arc::StateId StateId
Definition: randgen.h:76
bool BackArc(StateId, const FromArc &)
Definition: randgen.h:684
virtual Weight Final(StateId) const =0
SetType
Definition: set-weight.h:37
static constexpr LogWeightTpl Zero()
Definition: float-weight.h:413
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:88
RandGenFst(const RandGenFst< FromArc, ToArc, Sampler > &fst, bool safe=false)
Definition: randgen.h:582
RandGenVisitor(MutableFst< ToArc > *ofst)
Definition: randgen.h:662
typename FromArc::Label Label
Definition: randgen.h:566
constexpr uint64 kFstProperties
Definition: properties.h:301
void InitVisit(const Fst< FromArc > &ifst)
Definition: randgen.h:664
constexpr uint64 kCopyProperties
Definition: properties.h:138
constexpr int kNoStateId
Definition: fst.h:180
typename Arc::Weight Weight
Definition: randgen.h:189
const Arc & Value() const
Definition: fst.h:503
size_t nsamples
Definition: randgen.h:165
virtual uint64 Properties(uint64 mask, bool test) const =0
uint64 RandGenProperties(uint64 inprops, bool weighted)
Definition: properties.cc:216
#define FSTERROR()
Definition: util.h:35
size_t length
Definition: randgen.h:166
LogWeightTpl< double > Log64Weight
Definition: float-weight.h:449
UniformArcSelector(uint64 seed)
Definition: randgen.h:56
StateIteratorBase< Arc > * base
Definition: fst.h:351
typename Arc::Weight Weight
Definition: randgen.h:51
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:93
ArcSampler(const ArcSampler< Arc, Selector > &sampler, const Fst< Arc > *fst=nullptr)
Definition: randgen.h:295
RandGenFstImpl(const RandGenFstImpl &impl)
Definition: randgen.h:437
void RandGen(const Fst< FromArc > &ifst, MutableFst< ToArc > *ofst, const RandGenOptions< Selector > &opts)
Definition: randgen.h:730
typename Arc::StateId StateId
Definition: randgen.h:162
typename FromArc::Weight Weight
Definition: randgen.h:660
bool Done() const
Definition: randgen.h:228
void Seek(size_t a)
Definition: fst.h:523
RandGenFst(const Fst< FromArc > &fst, const RandGenFstOptions< Sampler > &opts)
Definition: randgen.h:578
bool ForwardOrCrossArc(StateId, const FromArc &)
Definition: randgen.h:690
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
Definition: cache.h:648
void Reset()
Definition: randgen.h:235
bool TreeArc(StateId, const ToArc &arc)
Definition: randgen.h:675
ExpectationWeight< X1, X2 > Plus(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
bool Done() const
Definition: fst.h:499
LogProbArcSelector(uint64 seed)
Definition: randgen.h:82
RandGenFstOptions(const CacheOptions &opts, Sampler *sampler, int32 npath=1, bool weighted=true, bool remove_total_weight=false)
Definition: randgen.h:386
std::pair< size_t, size_t > Value() const
Definition: randgen.h:233
size_t operator()(const Fst< Arc > &fst, StateId s, CacheLogAccumulator< Arc > *accumulator) const
Definition: randgen.h:140
ArcIterator(const RandGenFst< FromArc, ToArc, Sampler > &fst, StateId s)
Definition: randgen.h:620
void InitArcIterator(StateId s, ArcIteratorData< ToArc > *data)
Definition: randgen.h:492
void InitStateIterator(StateIteratorData< ToArc > *data) const override
Definition: randgen.h:628
void SetState(StateId s, int depth=0)
Definition: accumulator.h:492
RandGenFst< FromArc, ToArc, Sampler > * Copy(bool safe=false) const override
Definition: randgen.h:586
typename FromArc::Label Label
Definition: randgen.h:414
FastLogProbArcSelector(uint64 seed)
Definition: randgen.h:137
virtual const SymbolTable * InputSymbols() const =0
Weight Sum(Weight w, Weight v)
Definition: accumulator.h:514
const SymbolTable * InputSymbols() const
Definition: fst.h:688
const RandState< Arc > * parent
Definition: randgen.h:168
typename FromArc::StateId StateId
Definition: randgen.h:659
ToWeight Final(StateId s)
Definition: randgen.h:461
int32_t int32
Definition: types.h:26
constexpr uint64 kError
Definition: properties.h:33
uint64 Properties() const override
Definition: randgen.h:481
void InitArcIterator(StateId s, ArcIteratorData< ToArc > *data) const override
Definition: randgen.h:592
Log64Weight ToLogWeight(const Weight &weight) const
Definition: randgen.h:107
const Selector & selector
Definition: randgen.h:636
size_t LowerBound(Weight w, ArcIter *aiter)
Definition: accumulator.h:551
typename Arc::StateId StateId
Definition: randgen.h:188
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
Definition: expand.h:302
bool Sample(const RandState< Arc > &rstate)
Definition: randgen.h:212
ArcSampler(const Fst< Arc > &fst, const Selector &selector, int32 max_length=std::numeric_limits< int32 >::max())
Definition: randgen.h:285
internal::RandGenFstImpl< FromArc, ToArc, Sampler > * GetMutableImpl() const
Definition: fst.h:947
typename FromArc::Weight FromWeight
Definition: randgen.h:416
size_t NumArcs(StateId s)
Definition: randgen.h:466
size_t operator()(const Fst< Arc > &fst, StateId s) const
Definition: randgen.h:58
size_t select
Definition: randgen.h:167
ArcSampler(const Fst< Arc > &fst, const Selector &selector, int32 max_length=std::numeric_limits< int32 >::max())
Definition: randgen.h:193
const internal::RandGenFstImpl< FromArc, ToArc, Sampler > * GetImpl() const
Definition: fst.h:945
bool remove_total_weight
Definition: randgen.h:641
typename Store::State State
Definition: randgen.h:571
virtual const SymbolTable * OutputSymbols() const =0
void Next()
Definition: fst.h:507