FST  openfst-1.7.1
OpenFst Library
randgen.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Classes and functions to generate random paths through an FST.
5 
6 #ifndef FST_RANDGEN_H_
7 #define FST_RANDGEN_H_
8 
9 #include <math.h>
10 #include <stddef.h>
11 #include <limits>
12 #include <map>
13 #include <memory>
14 #include <random>
15 #include <utility>
16 #include <vector>
17 
18 #include <fst/log.h>
19 
20 #include <fst/accumulator.h>
21 #include <fst/cache.h>
22 #include <fst/dfs-visit.h>
23 #include <fst/float-weight.h>
24 #include <fst/fst-decl.h>
25 #include <fst/fst.h>
26 #include <fst/mutable-fst.h>
27 #include <fst/properties.h>
28 #include <fst/util.h>
29 #include <fst/weight.h>
30 
31 namespace fst {
32 
33 // The RandGenFst class is roughly similar to ArcMapFst in that it takes two
34 // template parameters denoting the input and output arc types. However, it also
35 // takes an additional template parameter which specifies a sampler object which
36 // samples (with replacement) arcs from an FST state. The sampler in turn takes
37 // a template parameter for a selector object which actually chooses the arc.
38 //
39 // Arc selector functors are used to select a random transition given an FST
40 // state s, returning a number N such that 0 <= N <= NumArcs(s). If N is
41 // NumArcs(s), then the final weight is selected; otherwise the N-th arc is
42 // selected. It is assumed these are not applied to any state which is neither
43 // final nor has any arcs leaving it.
44 
45 // Randomly selects a transition using the uniform distribution. This class is
46 // not thread-safe.
47 template <class Arc>
49  public:
50  using StateId = typename Arc::StateId;
51  using Weight = typename Arc::Weight;
52 
53  // Constructs a selector with a non-deterministic seed.
54  UniformArcSelector() : rand_(std::random_device()()) {}
55  // Constructs a selector with a given seed.
56  explicit UniformArcSelector(uint64 seed) : rand_(seed) {}
57 
58  size_t operator()(const Fst<Arc> &fst, StateId s) const {
59  const auto n = fst.NumArcs(s) + (fst.Final(s) != Weight::Zero());
60  return static_cast<size_t>(
61  std::uniform_int_distribution<>(0, n - 1)(rand_));
62  }
63 
64  private:
65  mutable std::mt19937_64 rand_;
66 };
67 
68 // Randomly selects a transition w.r.t. the weights treated as negative log
69 // probabilities after normalizing for the total weight leaving the state. Zero
70 // transitions are disregarded. It is assumed that Arc::Weight::Value() accesses
71 // the floating point representation of the weight. This class is not
72 // thread-safe.
73 template <class Arc>
75  public:
76  using StateId = typename Arc::StateId;
77  using Weight = typename Arc::Weight;
78 
79  // Constructs a selector with a non-deterministic seed.
80  LogProbArcSelector() : rand_(std::random_device()()) {}
81  // Constructs a selector with a given seed.
82  explicit LogProbArcSelector(uint64 seed) : rand_(seed) {}
83 
84  size_t operator()(const Fst<Arc> &fst, StateId s) const {
85  // Finds total weight leaving state.
86  auto sum = Log64Weight::Zero();
87  ArcIterator<Fst<Arc>> aiter(fst, s);
88  for (; !aiter.Done(); aiter.Next()) {
89  const auto &arc = aiter.Value();
90  sum = Plus(sum, to_log_weight_(arc.weight));
91  }
92  sum = Plus(sum, to_log_weight_(fst.Final(s)));
93  const double threshold =
94  std::uniform_real_distribution<>(0, exp(-sum.Value()))(rand_);
95  auto p = Log64Weight::Zero();
96  size_t n = 0;
97  for (aiter.Reset(); !aiter.Done(); aiter.Next(), ++n) {
98  p = Plus(p, to_log_weight_(aiter.Value().weight));
99  if (exp(-p.Value()) > threshold) return n;
100  }
101  return n;
102  }
103 
104  private:
105  mutable std::mt19937_64 rand_;
106  WeightConvert<Weight, Log64Weight> to_log_weight_;
107 };
108 
109 // Useful alias when using StdArc.
111 
112 // Same as LogProbArcSelector but use CacheLogAccumulator to cache the weight
113 // accumulation computations. This class is not thread-safe.
114 template <class Arc>
116  public:
117  using StateId = typename Arc::StateId;
118  using Weight = typename Arc::Weight;
119 
121 
122  // Constructs a selector with a non-deterministic seed.
123  FastLogProbArcSelector() : seed_(std::random_device()()), rand_(seed_) {}
124  // Constructs a selector with a given seed.
125  explicit FastLogProbArcSelector(uint64 seed) : seed_(seed), rand_(seed_) {}
126 
127  size_t operator()(const Fst<Arc> &fst, StateId s,
128  CacheLogAccumulator<Arc> *accumulator) const {
129  accumulator->SetState(s);
130  ArcIterator<Fst<Arc>> aiter(fst, s);
131  // Finds total weight leaving state.
132  const double sum = to_log_weight_(accumulator->Sum(fst.Final(s), &aiter, 0,
133  fst.NumArcs(s)))
134  .Value();
135  const double r = -log(std::uniform_real_distribution<>(0, 1)(rand_));
136  Weight w = from_log_weight_(r + sum);
137  aiter.Reset();
138  return accumulator->LowerBound(w, &aiter);
139  }
140 
141  uint64 Seed() const { return seed_; }
142 
143  private:
144  const uint64 seed_;
145  mutable std::mt19937_64 rand_;
146  WeightConvert<Weight, Log64Weight> to_log_weight_;
147  WeightConvert<Log64Weight, Weight> from_log_weight_;
148 };
149 
// Per-node bookkeeping for a random path; RandGenFst keeps one of these for
// each output state and hands it to the sampler.
template <typename Arc>
struct RandState {
  using StateId = typename Arc::StateId;

  // Current input FST state.
  StateId state_id;
  // Number of samples to be sampled at this state.
  size_t nsamples;
  // Length of path to this random state.
  size_t length;
  // Previous sample arc selection.
  size_t select;
  // Previous random state on this path.
  const RandState<Arc> *parent;

  explicit RandState(StateId state_id, size_t nsamples = 0, size_t length = 0,
                     size_t select = 0, const RandState<Arc> *parent = nullptr)
      : state_id{state_id},
        nsamples{nsamples},
        length{length},
        select{select},
        parent{parent} {}

};
171 
172 // This class, given an arc selector, samples, with replacement, multiple random
173 // transitions from an FST's state. This is a generic version with a
174 // straightforward use of the arc selector. Specializations may be defined for
175 // arc selectors for greater efficiency or special behavior.
176 template <class Arc, class Selector>
177 class ArcSampler {
178  public:
179  using StateId = typename Arc::StateId;
180  using Weight = typename Arc::Weight;
181 
182  // The max_length argument may be interpreted (or ignored) by a selector as
183  // it chooses. This generic version interprets this literally.
184  ArcSampler(const Fst<Arc> &fst, const Selector &selector,
185  int32 max_length = std::numeric_limits<int32>::max())
186  : fst_(fst), selector_(selector), max_length_(max_length) {}
187 
188  // Allow updating FST argument; pass only if changed.
190  const Fst<Arc> *fst = nullptr)
191  : fst_(fst ? *fst : sampler.fst_),
192  selector_(sampler.selector_),
193  max_length_(sampler.max_length_) {
194  Reset();
195  }
196 
197  // Samples a fixed number of samples from the given state. The length argument
198  // specifies the length of the path to the state. Returns true if the samples
199  // were collected. No samples may be collected if either there are no
200  // transitions leaving the state and the state is non-final, or if the path
201  // length has been exceeded. Iterator members are provided to read the samples
202  // in the order in which they were collected.
203  bool Sample(const RandState<Arc> &rstate) {
204  sample_map_.clear();
205  if ((fst_.NumArcs(rstate.state_id) == 0 &&
206  fst_.Final(rstate.state_id) == Weight::Zero()) ||
207  rstate.length == max_length_) {
208  Reset();
209  return false;
210  }
211  for (size_t i = 0; i < rstate.nsamples; ++i) {
212  ++sample_map_[selector_(fst_, rstate.state_id)];
213  }
214  Reset();
215  return true;
216  }
217 
218  // More samples?
219  bool Done() const { return sample_iter_ == sample_map_.end(); }
220 
221  // Gets the next sample.
222  void Next() { ++sample_iter_; }
223 
224  std::pair<size_t, size_t> Value() const { return *sample_iter_; }
225 
226  void Reset() { sample_iter_ = sample_map_.begin(); }
227 
228  bool Error() const { return false; }
229 
230  private:
231  const Fst<Arc> &fst_;
232  const Selector &selector_;
233  const int32 max_length_;
234 
235  // Stores (N, K) as described for Value().
236  std::map<size_t, size_t> sample_map_;
237  std::map<size_t, size_t>::const_iterator sample_iter_;
238 
239  ArcSampler<Arc, Selector> &operator=(const ArcSampler &) = delete;
240 };
241 
// Draws one sample from a multinomial distribution: num_to_sample trials are
// distributed over the outcomes 0..probs.size()-1, where outcome i has
// (unnormalized) probability probs[i]. Entries that are not positive are
// never selected.
//
// probs:         unnormalized outcome probabilities.
// num_to_sample: total number of trials to distribute.
// result:        indexable container receiving the counts; (*result)[i] is
//                assigned only for outcomes drawn at least once, so it should
//                be pre-initialized (e.g., an empty map, or a zeroed vector of
//                size probs.size()).
// rng:           random number generator used for the draws.
template <class Result, class RNG>
void OneMultinomialSample(const std::vector<double> &probs,
                          size_t num_to_sample, Result *result, RNG *rng) {
  // The trial count is a size_t, so the distribution must not narrow it.
  using Distribution = std::binomial_distribution<size_t>;
  // remaining[i] is the probability mass of outcomes i, i + 1, .... The
  // suffix sums are computed up front rather than by repeatedly subtracting
  // from a running total: subtraction accumulates round-off and can leave the
  // total smaller than probs[i], producing a success probability above 1.
  // Adding non-negative terms guarantees remaining[i] >= probs[i].
  std::vector<double> remaining(probs.size());
  double tail = 0;
  for (size_t i = probs.size(); i-- > 0;) {
    if (probs[i] > 0) tail += probs[i];
    remaining[i] = tail;
  }
  for (size_t i = 0; i < probs.size(); ++i) {
    Distribution::result_type num_sampled = 0;
    if (probs[i] > 0) {
      // Conditional probability of outcome i given that none of the earlier
      // outcomes was drawn.
      Distribution d(num_to_sample, probs[i] / remaining[i]);
      num_sampled = d(*rng);
    }
    if (num_sampled != 0) (*result)[i] = num_sampled;
    // Guards against underflow of the unsigned left-over trial count.
    num_to_sample -= (num_sampled < num_to_sample) ? num_sampled
                                                   : num_to_sample;
  }
}
265 
266 // Specialization for FastLogProbArcSelector.
267 template <class Arc>
269  public:
270  using StateId = typename Arc::StateId;
271  using Weight = typename Arc::Weight;
272 
275 
276  ArcSampler(const Fst<Arc> &fst, const Selector &selector,
277  int32 max_length = std::numeric_limits<int32>::max())
278  : fst_(fst),
279  selector_(selector),
280  max_length_(max_length),
281  accumulator_(new Accumulator()) {
282  accumulator_->Init(fst);
283  rng_.seed(selector_.Seed());
284  }
285 
287  const Fst<Arc> *fst = nullptr)
288  : fst_(fst ? *fst : sampler.fst_),
289  selector_(sampler.selector_),
290  max_length_(sampler.max_length_) {
291  if (fst) {
292  accumulator_.reset(new Accumulator());
293  accumulator_->Init(*fst);
294  } else { // Shallow copy.
295  accumulator_.reset(new Accumulator(*sampler.accumulator_));
296  }
297  }
298 
299  bool Sample(const RandState<Arc> &rstate) {
300  sample_map_.clear();
301  if ((fst_.NumArcs(rstate.state_id) == 0 &&
302  fst_.Final(rstate.state_id) == Weight::Zero()) ||
303  rstate.length == max_length_) {
304  Reset();
305  return false;
306  }
307  if (fst_.NumArcs(rstate.state_id) + 1 < rstate.nsamples) {
308  MultinomialSample(rstate);
309  Reset();
310  return true;
311  }
312  for (size_t i = 0; i < rstate.nsamples; ++i) {
313  ++sample_map_[selector_(fst_, rstate.state_id, accumulator_.get())];
314  }
315  Reset();
316  return true;
317  }
318 
319  bool Done() const { return sample_iter_ == sample_map_.end(); }
320 
321  void Next() { ++sample_iter_; }
322 
323  std::pair<size_t, size_t> Value() const { return *sample_iter_; }
324 
325  void Reset() { sample_iter_ = sample_map_.begin(); }
326 
327  bool Error() const { return accumulator_->Error(); }
328 
329  private:
330  using RNG = std::mt19937;
331 
332  // Sample according to the multinomial distribution of rstate.nsamples draws
333  // from p_.
334  void MultinomialSample(const RandState<Arc> &rstate) {
335  p_.clear();
336  for (ArcIterator<Fst<Arc>> aiter(fst_, rstate.state_id); !aiter.Done();
337  aiter.Next()) {
338  p_.push_back(exp(-to_log_weight_(aiter.Value().weight).Value()));
339  }
340  if (fst_.Final(rstate.state_id) != Weight::Zero()) {
341  p_.push_back(exp(-to_log_weight_(fst_.Final(rstate.state_id)).Value()));
342  }
343  if (rstate.nsamples < std::numeric_limits<RNG::result_type>::max()) {
344  OneMultinomialSample(p_, rstate.nsamples, &sample_map_, &rng_);
345  } else {
346  for (size_t i = 0; i < p_.size(); ++i) {
347  sample_map_[i] = ceil(p_[i] * rstate.nsamples);
348  }
349  }
350  }
351 
352  const Fst<Arc> &fst_;
353  const Selector &selector_;
354  const int32 max_length_;
355 
356  // Stores (N, K) for Value().
357  std::map<size_t, size_t> sample_map_;
358  std::map<size_t, size_t>::const_iterator sample_iter_;
359 
360  std::unique_ptr<Accumulator> accumulator_;
361  RNG rng_; // Random number generator.
362  std::vector<double> p_; // Multinomial parameters.
363  WeightConvert<Weight, Log64Weight> to_log_weight_;
364 };
365 
366 // Options for random path generation with RandGenFst. The template argument is
367 // a sampler, typically the class ArcSampler. Ownership of the sampler is taken
368 // by RandGenFst.
369 template <class Sampler>
371  Sampler *sampler; // How to sample transitions at a state.
372  int32 npath; // Number of paths to generate.
373  bool weighted; // Is the output tree weighted by path count, or
374  // is it just an unweighted DAG?
375  bool remove_total_weight; // Remove total weight when output is weighted.
376 
377  RandGenFstOptions(const CacheOptions &opts, Sampler *sampler, int32 npath = 1,
378  bool weighted = true, bool remove_total_weight = false)
379  : CacheOptions(opts),
380  sampler(sampler),
381  npath(npath),
382  weighted(weighted),
383  remove_total_weight(remove_total_weight) {}
384 };
385 
386 namespace internal {
387 
388 // Implementation of RandGenFst.
389 template <class FromArc, class ToArc, class Sampler>
390 class RandGenFstImpl : public CacheImpl<ToArc> {
391  public:
396 
397  using CacheBaseImpl<CacheState<ToArc>>::EmplaceArc;
398  using CacheBaseImpl<CacheState<ToArc>>::HasArcs;
399  using CacheBaseImpl<CacheState<ToArc>>::HasFinal;
400  using CacheBaseImpl<CacheState<ToArc>>::HasStart;
401  using CacheBaseImpl<CacheState<ToArc>>::SetArcs;
402  using CacheBaseImpl<CacheState<ToArc>>::SetFinal;
403  using CacheBaseImpl<CacheState<ToArc>>::SetStart;
404 
405  using Label = typename FromArc::Label;
406  using StateId = typename FromArc::StateId;
407  using FromWeight = typename FromArc::Weight;
408 
409  using ToWeight = typename ToArc::Weight;
410 
412  const RandGenFstOptions<Sampler> &opts)
413  : CacheImpl<ToArc>(opts),
414  fst_(fst.Copy()),
415  sampler_(opts.sampler),
416  npath_(opts.npath),
417  weighted_(opts.weighted),
418  remove_total_weight_(opts.remove_total_weight),
419  superfinal_(kNoLabel) {
420  SetType("randgen");
421  SetProperties(
422  RandGenProperties(fst.Properties(kFstProperties, false), weighted_),
424  SetInputSymbols(fst.InputSymbols());
425  SetOutputSymbols(fst.OutputSymbols());
426  }
427 
429  : CacheImpl<ToArc>(impl),
430  fst_(impl.fst_->Copy(true)),
431  sampler_(new Sampler(*impl.sampler_, fst_.get())),
432  npath_(impl.npath_),
433  weighted_(impl.weighted_),
434  superfinal_(kNoLabel) {
435  SetType("randgen");
436  SetProperties(impl.Properties(), kCopyProperties);
437  SetInputSymbols(impl.InputSymbols());
438  SetOutputSymbols(impl.OutputSymbols());
439  }
440 
442  if (!HasStart()) {
443  const auto s = fst_->Start();
444  if (s == kNoStateId) return kNoStateId;
445  SetStart(state_table_.size());
446  state_table_.emplace_back(
447  new RandState<FromArc>(s, npath_, 0, 0, nullptr));
448  }
449  return CacheImpl<ToArc>::Start();
450  }
451 
453  if (!HasFinal(s)) Expand(s);
454  return CacheImpl<ToArc>::Final(s);
455  }
456 
457  size_t NumArcs(StateId s) {
458  if (!HasArcs(s)) Expand(s);
459  return CacheImpl<ToArc>::NumArcs(s);
460  }
461 
463  if (!HasArcs(s)) Expand(s);
465  }
466 
468  if (!HasArcs(s)) Expand(s);
470  }
471 
472  uint64 Properties() const override { return Properties(kFstProperties); }
473 
474  // Sets error if found, and returns other FST impl properties.
475  uint64 Properties(uint64 mask) const override {
476  if ((mask & kError) &&
477  (fst_->Properties(kError, false) || sampler_->Error())) {
478  SetProperties(kError, kError);
479  }
480  return FstImpl<ToArc>::Properties(mask);
481  }
482 
484  if (!HasArcs(s)) Expand(s);
486  }
487 
488  // Computes the outgoing transitions from a state, creating new destination
489  // states as needed.
490  void Expand(StateId s) {
491  if (s == superfinal_) {
492  SetFinal(s, ToWeight::One());
493  SetArcs(s);
494  return;
495  }
496  SetFinal(s, ToWeight::Zero());
497  const auto &rstate = *state_table_[s];
498  sampler_->Sample(rstate);
499  ArcIterator<Fst<FromArc>> aiter(*fst_, rstate.state_id);
500  const auto narcs = fst_->NumArcs(rstate.state_id);
501  for (; !sampler_->Done(); sampler_->Next()) {
502  const auto &sample_pair = sampler_->Value();
503  const auto pos = sample_pair.first;
504  const auto count = sample_pair.second;
505  double prob = static_cast<double>(count) / rstate.nsamples;
506  if (pos < narcs) { // Regular transition.
507  aiter.Seek(sample_pair.first);
508  const auto &aarc = aiter.Value();
509  auto weight =
510  weighted_ ? to_weight_(Log64Weight(-log(prob))) : ToWeight::One();
511  EmplaceArc(s, aarc.ilabel, aarc.olabel, std::move(weight),
512  state_table_.size());
513  auto *nrstate = new RandState<FromArc>(aarc.nextstate, count,
514  rstate.length + 1, pos, &rstate);
515  state_table_.emplace_back(nrstate);
516  } else { // Super-final transition.
517  if (weighted_) {
518  const auto weight =
519  remove_total_weight_
520  ? to_weight_(Log64Weight(-log(prob)))
521  : to_weight_(Log64Weight(-log(prob * npath_)));
522  SetFinal(s, weight);
523  } else {
524  if (superfinal_ == kNoLabel) {
525  superfinal_ = state_table_.size();
526  state_table_.emplace_back(
527  new RandState<FromArc>(kNoStateId, 0, 0, 0, nullptr));
528  }
529  for (size_t n = 0; n < count; ++n) {
530  EmplaceArc(s, 0, 0, ToWeight::One(), superfinal_);
531  }
532  }
533  }
534  }
535  SetArcs(s);
536  }
537 
538  private:
539  const std::unique_ptr<Fst<FromArc>> fst_;
540  std::unique_ptr<Sampler> sampler_;
541  const int32 npath_;
542  std::vector<std::unique_ptr<RandState<FromArc>>> state_table_;
543  const bool weighted_;
544  bool remove_total_weight_;
545  StateId superfinal_;
547 };
548 
549 } // namespace internal
550 
551 // FST class to randomly generate paths through an FST, with details controlled
552 // by RandGenFstOptions. Output format is a tree weighted by the path count.
553 template <class FromArc, class ToArc, class Sampler>
554 class RandGenFst
555  : public ImplToFst<internal::RandGenFstImpl<FromArc, ToArc, Sampler>> {
556  public:
557  using Label = typename FromArc::Label;
558  using StateId = typename FromArc::StateId;
559  using Weight = typename FromArc::Weight;
560 
562  using State = typename Store::State;
563 
565 
566  friend class ArcIterator<RandGenFst<FromArc, ToArc, Sampler>>;
567  friend class StateIterator<RandGenFst<FromArc, ToArc, Sampler>>;
568 
570  : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
571 
572  // See Fst<>::Copy() for doc.
574  : ImplToFst<Impl>(fst, safe) {}
575 
576  // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc.
577  RandGenFst<FromArc, ToArc, Sampler> *Copy(bool safe = false) const override {
578  return new RandGenFst<FromArc, ToArc, Sampler>(*this, safe);
579  }
580 
581  inline void InitStateIterator(StateIteratorData<ToArc> *data) const override;
582 
583  void InitArcIterator(StateId s, ArcIteratorData<ToArc> *data) const override {
584  GetMutableImpl()->InitArcIterator(s, data);
585  }
586 
587  private:
590 
591  RandGenFst &operator=(const RandGenFst &) = delete;
592 };
593 
594 // Specialization for RandGenFst.
595 template <class FromArc, class ToArc, class Sampler>
596 class StateIterator<RandGenFst<FromArc, ToArc, Sampler>>
597  : public CacheStateIterator<RandGenFst<FromArc, ToArc, Sampler>> {
598  public:
600  : CacheStateIterator<RandGenFst<FromArc, ToArc, Sampler>>(
601  fst, fst.GetMutableImpl()) {}
602 };
603 
604 // Specialization for RandGenFst.
605 template <class FromArc, class ToArc, class Sampler>
606 class ArcIterator<RandGenFst<FromArc, ToArc, Sampler>>
607  : public CacheArcIterator<RandGenFst<FromArc, ToArc, Sampler>> {
608  public:
609  using StateId = typename FromArc::StateId;
610 
612  : CacheArcIterator<RandGenFst<FromArc, ToArc, Sampler>>(
613  fst.GetMutableImpl(), s) {
614  if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
615  }
616 };
617 
618 template <class FromArc, class ToArc, class Sampler>
620  StateIteratorData<ToArc> *data) const {
622 }
623 
624 // Options for random path generation.
625 template <class Selector>
627  const Selector &selector; // How an arc is selected at a state.
628  int32 max_length; // Maximum path length.
629  int32 npath; // Number of paths to generate.
630  bool weighted; // Is the output tree weighted by path count, or
631  // is it just an unweighted DAG?
632  bool remove_total_weight; // Remove total weight when output is weighted?
633 
634  explicit RandGenOptions(const Selector &selector,
635  int32 max_length = std::numeric_limits<int32>::max(),
636  int32 npath = 1, bool weighted = false,
637  bool remove_total_weight = false)
638  : selector(selector),
639  max_length(max_length),
640  npath(npath),
641  weighted(weighted),
642  remove_total_weight(remove_total_weight) {}
643 };
644 
645 namespace internal {
646 
647 template <class FromArc, class ToArc>
649  public:
650  using StateId = typename FromArc::StateId;
651  using Weight = typename FromArc::Weight;
652 
653  explicit RandGenVisitor(MutableFst<ToArc> *ofst) : ofst_(ofst) {}
654 
655  void InitVisit(const Fst<FromArc> &ifst) {
656  ifst_ = &ifst;
657  ofst_->DeleteStates();
658  ofst_->SetInputSymbols(ifst.InputSymbols());
659  ofst_->SetOutputSymbols(ifst.OutputSymbols());
660  if (ifst.Properties(kError, false)) ofst_->SetProperties(kError, kError);
661  path_.clear();
662  }
663 
664  constexpr bool InitState(StateId, StateId) const { return true; }
665 
666  bool TreeArc(StateId, const ToArc &arc) {
667  if (ifst_->Final(arc.nextstate) == Weight::Zero()) {
668  path_.push_back(arc);
669  } else {
670  OutputPath();
671  }
672  return true;
673  }
674 
675  bool BackArc(StateId, const FromArc &) {
676  FSTERROR() << "RandGenVisitor: cyclic input";
677  ofst_->SetProperties(kError, kError);
678  return false;
679  }
680 
681  bool ForwardOrCrossArc(StateId, const FromArc &) {
682  OutputPath();
683  return true;
684  }
685 
686  void FinishState(StateId s, StateId p, const FromArc *) {
687  if (p != kNoStateId && ifst_->Final(s) == Weight::Zero()) path_.pop_back();
688  }
689 
690  void FinishVisit() {}
691 
692  private:
693  void OutputPath() {
694  if (ofst_->Start() == kNoStateId) {
695  const auto start = ofst_->AddState();
696  ofst_->SetStart(start);
697  }
698  auto src = ofst_->Start();
699  for (size_t i = 0; i < path_.size(); ++i) {
700  const auto dest = ofst_->AddState();
701  const ToArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest);
702  ofst_->AddArc(src, arc);
703  src = dest;
704  }
705  ofst_->SetFinal(src, Weight::One());
706  }
707 
708  const Fst<FromArc> *ifst_;
709  MutableFst<ToArc> *ofst_;
710  std::vector<ToArc> path_;
711 
712  RandGenVisitor(const RandGenVisitor &) = delete;
713  RandGenVisitor &operator=(const RandGenVisitor &) = delete;
714 };
715 
716 } // namespace internal
717 
718 // Randomly generate paths through an FST; details controlled by
719 // RandGenOptions.
720 template <class FromArc, class ToArc, class Selector>
721 void RandGen(const Fst<FromArc> &ifst, MutableFst<ToArc> *ofst,
722  const RandGenOptions<Selector> &opts) {
723  using State = typename ToArc::StateId;
724  using Weight = typename ToArc::Weight;
725  using Sampler = ArcSampler<FromArc, Selector>;
726  auto *sampler = new Sampler(ifst, opts.selector, opts.max_length);
727  RandGenFstOptions<Sampler> fopts(CacheOptions(true, 0), sampler, opts.npath,
728  opts.weighted, opts.remove_total_weight);
729  RandGenFst<FromArc, ToArc, Sampler> rfst(ifst, fopts);
730  if (opts.weighted) {
731  *ofst = rfst;
732  } else {
733  internal::RandGenVisitor<FromArc, ToArc> rand_visitor(ofst);
734  DfsVisit(rfst, &rand_visitor);
735  }
736 }
737 
738 // Randomly generate a path through an FST with the uniform distribution
739 // over the transitions.
740 template <class FromArc, class ToArc>
741 void RandGen(const Fst<FromArc> &ifst, MutableFst<ToArc> *ofst) {
742  const UniformArcSelector<FromArc> uniform_selector;
743  RandGenOptions<UniformArcSelector<ToArc>> opts(uniform_selector);
744  RandGen(ifst, ofst, opts);
745 }
746 
747 } // namespace fst
748 
749 #endif // FST_RANDGEN_H_
void OneMultinomialSample(const std::vector< double > &probs, size_t num_to_sample, Result *result, RNG *rng)
Definition: randgen.h:248
bool Sample(const RandState< Arc > &rstate)
Definition: randgen.h:299
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:99
typename FromArc::Weight Weight
Definition: randgen.h:559
StateIterator(const RandGenFst< FromArc, ToArc, Sampler > &fst)
Definition: randgen.h:599
uint64 Properties(uint64 mask) const override
Definition: randgen.h:475
typename FromArc::StateId StateId
Definition: randgen.h:406
constexpr bool InitState(StateId, StateId) const
Definition: randgen.h:664
constexpr int kNoLabel
Definition: fst.h:179
void Next()
Definition: randgen.h:222
typename ToArc::Weight ToWeight
Definition: randgen.h:409
uint64_t uint64
Definition: types.h:32
size_t operator()(const Fst< Arc > &fst, StateId s) const
Definition: randgen.h:84
void Reset()
Definition: fst.h:515
typename Arc::StateId StateId
Definition: randgen.h:50
virtual size_t NumArcs(StateId) const =0
std::pair< size_t, size_t > Value() const
Definition: randgen.h:323
StateId state_id
Definition: randgen.h:155
void FinishState(StateId s, StateId p, const FromArc *)
Definition: randgen.h:686
RandState(StateId state_id, size_t nsamples=0, size_t length=0, size_t select=0, const RandState< Arc > *parent=nullptr)
Definition: randgen.h:161
typename FromArc::StateId StateId
Definition: randgen.h:558
void Expand(StateId s)
Definition: randgen.h:490
ArcSampler(const ArcSampler< Arc, Selector > &sampler, const Fst< Arc > *fst=nullptr)
Definition: randgen.h:189
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:82
const SymbolTable * OutputSymbols() const
Definition: fst.h:690
size_t NumInputEpsilons(StateId s)
Definition: randgen.h:462
bool Error() const
Definition: randgen.h:228
RandGenFstImpl(const Fst< FromArc > &fst, const RandGenFstOptions< Sampler > &opts)
Definition: randgen.h:411
RandGenOptions(const Selector &selector, int32 max_length=std::numeric_limits< int32 >::max(), int32 npath=1, bool weighted=false, bool remove_total_weight=false)
Definition: randgen.h:634
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
Definition: dfs-visit.h:94
typename Arc::Weight Weight
Definition: randgen.h:77
size_t NumOutputEpsilons(StateId s)
Definition: randgen.h:467
typename Arc::StateId StateId
Definition: randgen.h:76
bool BackArc(StateId, const FromArc &)
Definition: randgen.h:675
virtual Weight Final(StateId) const =0
SetType
Definition: set-weight.h:37
static constexpr LogWeightTpl Zero()
Definition: float-weight.h:413
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:88
RandGenFst(const RandGenFst< FromArc, ToArc, Sampler > &fst, bool safe=false)
Definition: randgen.h:573
RandGenVisitor(MutableFst< ToArc > *ofst)
Definition: randgen.h:653
typename FromArc::Label Label
Definition: randgen.h:557
constexpr uint64 kFstProperties
Definition: properties.h:301
void InitVisit(const Fst< FromArc > &ifst)
Definition: randgen.h:655
constexpr uint64 kCopyProperties
Definition: properties.h:138
constexpr int kNoStateId
Definition: fst.h:180
typename Arc::Weight Weight
Definition: randgen.h:180
const Arc & Value() const
Definition: fst.h:503
size_t nsamples
Definition: randgen.h:156
virtual uint64 Properties(uint64 mask, bool test) const =0
uint64 RandGenProperties(uint64 inprops, bool weighted)
Definition: properties.cc:216
#define FSTERROR()
Definition: util.h:35
size_t length
Definition: randgen.h:157
LogWeightTpl< double > Log64Weight
Definition: float-weight.h:449
UniformArcSelector(uint64 seed)
Definition: randgen.h:56
StateIteratorBase< Arc > * base
Definition: fst.h:351
typename Arc::Weight Weight
Definition: randgen.h:51
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:93
ArcSampler(const ArcSampler< Arc, Selector > &sampler, const Fst< Arc > *fst=nullptr)
Definition: randgen.h:286
RandGenFstImpl(const RandGenFstImpl &impl)
Definition: randgen.h:428
void RandGen(const Fst< FromArc > &ifst, MutableFst< ToArc > *ofst, const RandGenOptions< Selector > &opts)
Definition: randgen.h:721
typename Arc::StateId StateId
Definition: randgen.h:153
typename FromArc::Weight Weight
Definition: randgen.h:651
bool Done() const
Definition: randgen.h:219
void Seek(size_t a)
Definition: fst.h:523
RandGenFst(const Fst< FromArc > &fst, const RandGenFstOptions< Sampler > &opts)
Definition: randgen.h:569
bool ForwardOrCrossArc(StateId, const FromArc &)
Definition: randgen.h:681
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
Definition: cache.h:644
void Reset()
Definition: randgen.h:226
bool TreeArc(StateId, const ToArc &arc)
Definition: randgen.h:666
ExpectationWeight< X1, X2 > Plus(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
bool Done() const
Definition: fst.h:499
LogProbArcSelector(uint64 seed)
Definition: randgen.h:82
RandGenFstOptions(const CacheOptions &opts, Sampler *sampler, int32 npath=1, bool weighted=true, bool remove_total_weight=false)
Definition: randgen.h:377
std::pair< size_t, size_t > Value() const
Definition: randgen.h:224
size_t operator()(const Fst< Arc > &fst, StateId s, CacheLogAccumulator< Arc > *accumulator) const
Definition: randgen.h:127
ArcIterator(const RandGenFst< FromArc, ToArc, Sampler > &fst, StateId s)
Definition: randgen.h:611
void InitArcIterator(StateId s, ArcIteratorData< ToArc > *data)
Definition: randgen.h:483
void InitStateIterator(StateIteratorData< ToArc > *data) const override
Definition: randgen.h:619
void SetState(StateId s, int depth=0)
Definition: accumulator.h:492
RandGenFst< FromArc, ToArc, Sampler > * Copy(bool safe=false) const override
Definition: randgen.h:577
typename FromArc::Label Label
Definition: randgen.h:405
FastLogProbArcSelector(uint64 seed)
Definition: randgen.h:125
virtual const SymbolTable * InputSymbols() const =0
Weight Sum(Weight w, Weight v)
Definition: accumulator.h:514
const SymbolTable * InputSymbols() const
Definition: fst.h:688
const RandState< Arc > * parent
Definition: randgen.h:159
typename FromArc::StateId StateId
Definition: randgen.h:650
ToWeight Final(StateId s)
Definition: randgen.h:452
uint64 Seed() const
Definition: randgen.h:141
int32_t int32
Definition: types.h:26
constexpr uint64 kError
Definition: properties.h:33
uint64 Properties() const override
Definition: randgen.h:472
void InitArcIterator(StateId s, ArcIteratorData< ToArc > *data) const override
Definition: randgen.h:583
const Selector & selector
Definition: randgen.h:627
size_t LowerBound(Weight w, ArcIter *aiter)
Definition: accumulator.h:551
typename Arc::StateId StateId
Definition: randgen.h:179
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
Definition: expand.h:302
bool Sample(const RandState< Arc > &rstate)
Definition: randgen.h:203
ArcSampler(const Fst< Arc > &fst, const Selector &selector, int32 max_length=std::numeric_limits< int32 >::max())
Definition: randgen.h:276
internal::RandGenFstImpl< FromArc, ToArc, Sampler > * GetMutableImpl() const
Definition: fst.h:947
typename FromArc::Weight FromWeight
Definition: randgen.h:407
size_t NumArcs(StateId s)
Definition: randgen.h:457
size_t operator()(const Fst< Arc > &fst, StateId s) const
Definition: randgen.h:58
size_t select
Definition: randgen.h:158
ArcSampler(const Fst< Arc > &fst, const Selector &selector, int32 max_length=std::numeric_limits< int32 >::max())
Definition: randgen.h:184
const internal::RandGenFstImpl< FromArc, ToArc, Sampler > * GetImpl() const
Definition: fst.h:945
bool remove_total_weight
Definition: randgen.h:632
typename Store::State State
Definition: randgen.h:562
virtual const SymbolTable * OutputSymbols() const =0
void Next()
Definition: fst.h:507