FST  openfst-1.8.3
OpenFst Library
randgen.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Classes and functions to generate random paths through an FST.
19 
20 #ifndef FST_RANDGEN_H_
21 #define FST_RANDGEN_H_
22 
23 #include <algorithm>
24 #include <cmath>
25 #include <cstddef>
26 #include <cstdint>
27 #include <cstring>
28 #include <functional>
29 #include <limits>
30 #include <map>
31 #include <memory>
32 #include <numeric>
33 #include <random>
34 #include <utility>
35 #include <vector>
36 
37 #include <fst/log.h>
38 #include <fst/accumulator.h>
39 #include <fst/arc.h>
40 #include <fst/cache.h>
41 #include <fst/dfs-visit.h>
42 #include <fst/float-weight.h>
43 #include <fst/fst-decl.h>
44 #include <fst/fst.h>
45 #include <fst/impl-to-fst.h>
46 #include <fst/mutable-fst.h>
47 #include <fst/properties.h>
48 #include <fst/util.h>
49 #include <fst/weight.h>
50 #include <vector>
51 
52 namespace fst {
53 
54 // The RandGenFst class is roughly similar to ArcMapFst in that it takes two
55 // template parameters denoting the input and output arc types. However, it also
56 // takes an additional template parameter which specifies a sampler object which
57 // samples (with replacement) arcs from an FST state. The sampler in turn takes
58 // a template parameter for a selector object which actually chooses the arc.
59 //
60 // Arc selector functors are used to select a random transition given an FST
61 // state s, returning a number N such that 0 <= N <= NumArcs(s). If N is
62 // NumArcs(s), then the final weight is selected; otherwise the N-th arc is
63 // selected. It is assumed these are not applied to any state which is neither
64 // final nor has any arcs leaving it.
65 
66 // Randomly selects a transition using the uniform distribution. This class is
67 // not thread-safe.
68 template <class Arc>
70  public:
71  using StateId = typename Arc::StateId;
72  using Weight = typename Arc::Weight;
73 
74  explicit UniformArcSelector(uint64_t seed = std::random_device()())
75  : rand_(seed) {}
76 
77  size_t operator()(const Fst<Arc> &fst, StateId s) const {
78  const auto n = fst.NumArcs(s) + (fst.Final(s) != Weight::Zero());
79  return static_cast<size_t>(
80  std::uniform_int_distribution<>(0, n - 1)(rand_));
81  }
82 
83  private:
84  mutable std::mt19937_64 rand_;
85 };
86 
87 // Randomly selects a transition w.r.t. the weights treated as negative log
88 // probabilities after normalizing for the total weight leaving the state. Zero
89 // transitions are disregarded. It is assumed that Arc::Weight::Value() accesses
90 // the floating point representation of the weight. This class is not
91 // thread-safe.
92 template <class Arc>
94  public:
95  using StateId = typename Arc::StateId;
96  using Weight = typename Arc::Weight;
97 
98  // Constructs a selector with a non-deterministic seed.
99  LogProbArcSelector() : seed_(std::random_device()()), rand_(seed_) {}
100 
101  // Constructs a selector with a given seed.
102  explicit LogProbArcSelector(uint64_t seed) : seed_(seed), rand_(seed) {}
103 
104  size_t operator()(const Fst<Arc> &fst, StateId s) const {
105  // Finds total weight leaving state.
106  auto sum = Log64Weight::Zero();
107  ArcIterator<Fst<Arc>> aiter(fst, s);
108  for (; !aiter.Done(); aiter.Next()) {
109  const auto &arc = aiter.Value();
110  sum = Plus(sum, to_log_weight_(arc.weight));
111  }
112  sum = Plus(sum, to_log_weight_(fst.Final(s)));
113  const double threshold =
114  std::uniform_real_distribution<>(0, exp(-sum.Value()))(rand_);
115  auto p = Log64Weight::Zero();
116  size_t n = 0;
117  for (aiter.Reset(); !aiter.Done(); aiter.Next(), ++n) {
118  p = Plus(p, to_log_weight_(aiter.Value().weight));
119  if (exp(-p.Value()) > threshold) return n;
120  }
121  return n;
122  }
123 
124  uint64_t Seed() const { return seed_; }
125 
126  protected:
127  Log64Weight ToLogWeight(const Weight &weight) const {
128  return to_log_weight_(weight);
129  }
130 
131  std::mt19937_64 &MutableRand() const { return rand_; }
132 
133  private:
134  const uint64_t seed_;
135  mutable std::mt19937_64 rand_;
136  const WeightConvert<Weight, Log64Weight> to_log_weight_{};
137 };
138 
139 // Same as LogProbArcSelector but use CacheLogAccumulator to cache the weight
140 // accumulation computations. This class is not thread-safe.
141 template <class Arc>
143  public:
144  using StateId = typename Arc::StateId;
145  using Weight = typename Arc::Weight;
146 
150 
151  // Constructs a selector with a non-deterministic seed.
153  // Constructs a selector with a given seed.
154  explicit FastLogProbArcSelector(uint64_t seed)
155  : LogProbArcSelector<Arc>(seed) {}
156 
157  size_t operator()(const Fst<Arc> &fst, StateId s,
158  CacheLogAccumulator<Arc> *accumulator) const {
159  accumulator->SetState(s);
160  ArcIterator<Fst<Arc>> aiter(fst, s);
161  // Finds total weight leaving state.
162  const double sum =
163  ToLogWeight(accumulator->Sum(fst.Final(s), &aiter, 0, fst.NumArcs(s)))
164  .Value();
165  const double r =
166  -log(std::uniform_real_distribution<>(0, 1)(MutableRand()));
167  Weight w = from_log_weight_(r + sum);
168  aiter.Reset();
169  return accumulator->LowerBound(w, &aiter);
170  }
171 
172  private:
173  const WeightConvert<Log64Weight, Weight> from_log_weight_{};
174 };
175 
176 // Random path state info maintained by RandGenFst and passed to samplers.
177 template <typename Arc>
178 struct RandState {
179  using StateId = typename Arc::StateId;
180 
181  StateId state_id; // Current input FST state.
182  size_t nsamples; // Number of samples to be sampled at this state.
183  size_t length; // Length of path to this random state.
184  size_t select; // Previous sample arc selection.
185  const RandState<Arc> *parent; // Previous random state on this path.
186 
187  explicit RandState(StateId state_id, size_t nsamples = 0, size_t length = 0,
188  size_t select = 0, const RandState<Arc> *parent = nullptr)
189  : state_id(state_id),
190  nsamples(nsamples),
191  length(length),
192  select(select),
193  parent(parent) {}
194 
196 };
197 
198 // This class, given an arc selector, samples, with replacement, multiple random
199 // transitions from an FST's state. This is a generic version with a
200 // straightforward use of the arc selector. Specializations may be defined for
201 // arc selectors for greater efficiency or special behavior.
202 template <class Arc, class Selector>
203 class ArcSampler {
204  public:
205  using StateId = typename Arc::StateId;
206  using Weight = typename Arc::Weight;
207 
208  // The max_length argument may be interpreted (or ignored) by a selector as
209  // it chooses. This generic version interprets this literally.
210  ArcSampler(const Fst<Arc> &fst, const Selector &selector,
211  int32_t max_length = std::numeric_limits<int32_t>::max())
212  : fst_(fst), selector_(selector), max_length_(max_length) {}
213 
214  // Allow updating FST argument; pass only if changed.
216  const Fst<Arc> *fst = nullptr)
217  : fst_(fst ? *fst : sampler.fst_),
218  selector_(sampler.selector_),
219  max_length_(sampler.max_length_) {
220  Reset();
221  }
222 
223  // Samples a fixed number of samples from the given state. The length argument
224  // specifies the length of the path to the state. Returns true if the samples
225  // were collected. No samples may be collected if either there are no
226  // transitions leaving the state and the state is non-final, or if the path
227  // length has been exceeded. Iterator members are provided to read the samples
228  // in the order in which they were collected.
229  bool Sample(const RandState<Arc> &rstate) {
230  sample_map_.clear();
231  if ((fst_.NumArcs(rstate.state_id) == 0 &&
232  fst_.Final(rstate.state_id) == Weight::Zero()) ||
233  rstate.length == max_length_) {
234  Reset();
235  return false;
236  }
237  for (size_t i = 0; i < rstate.nsamples; ++i) {
238  ++sample_map_[selector_(fst_, rstate.state_id)];
239  }
240  Reset();
241  return true;
242  }
243 
244  // More samples?
245  bool Done() const { return sample_iter_ == sample_map_.end(); }
246 
247  // Gets the next sample.
248  void Next() { ++sample_iter_; }
249 
250  std::pair<size_t, size_t> Value() const { return *sample_iter_; }
251 
252  void Reset() { sample_iter_ = sample_map_.begin(); }
253 
254  bool Error() const { return false; }
255 
256  private:
257  const Fst<Arc> &fst_;
258  const Selector &selector_;
259  const int32_t max_length_;
260 
261  // Stores (N, K) as described for Value().
262  std::map<size_t, size_t> sample_map_;
263  std::map<size_t, size_t>::const_iterator sample_iter_;
264 
265  ArcSampler<Arc, Selector> &operator=(const ArcSampler &) = delete;
266 };
267 
// Draws a single multinomial sample: num_to_sample trials are distributed over
// the outcomes whose relative weights are given by probs (the weights are
// normalized internally by the mass of the not-yet-visited outcomes). For each
// outcome i that receives at least one trial, (*result)[i] is set to its
// count; other entries are left untouched. The result container should
// therefore be pre-initialized, e.g., an empty map or a zeroed vector with
// probs.size() elements.
template <class Result, class RNG>
void OneMultinomialSample(const std::vector<double> &probs,
                          size_t num_to_sample, Result *result, RNG *rng) {
  using Binomial = std::binomial_distribution<size_t>;
  const size_t num_outcomes = probs.size();
  // tail_mass[i] is probs[i] + ... + probs.back(). The suffix sums are
  // precomputed rather than updated by subtraction inside the loop, since
  // accumulated round-off could otherwise leave probs[i] > remaining mass.
  std::vector<double> tail_mass(num_outcomes);
  std::partial_sum(probs.rbegin(), probs.rend(), tail_mass.rbegin());
  // Trials not yet assigned to an outcome.
  size_t remaining = num_to_sample;
  for (size_t outcome = 0; outcome < num_outcomes; ++outcome) {
    if (!(probs[outcome] > 0)) continue;
    // Conditioned on the earlier outcomes, the count for this outcome is
    // binomial over the remaining trials.
    Binomial draw(remaining, probs[outcome] / tail_mass[outcome]);
    const Binomial::result_type taken = draw(*rng);
    if (taken == 0) continue;
    (*result)[outcome] = taken;
    remaining -= std::min(taken, remaining);
  }
}
293 
294 // Specialization for FastLogProbArcSelector.
295 template <class Arc>
297  public:
298  using StateId = typename Arc::StateId;
299  using Weight = typename Arc::Weight;
300 
303 
304  ArcSampler(const Fst<Arc> &fst, const Selector &selector,
305  int32_t max_length = std::numeric_limits<int32_t>::max())
306  : fst_(fst),
307  selector_(selector),
308  max_length_(max_length),
309  accumulator_(new Accumulator()) {
310  accumulator_->Init(fst);
311  rng_.seed(selector_.Seed());
312  }
313 
315  const Fst<Arc> *fst = nullptr)
316  : fst_(fst ? *fst : sampler.fst_),
317  selector_(sampler.selector_),
318  max_length_(sampler.max_length_) {
319  if (fst) {
320  accumulator_ = std::make_unique<Accumulator>();
321  accumulator_->Init(*fst);
322  } else { // Shallow copy.
323  accumulator_ = std::make_unique<Accumulator>(*sampler.accumulator_);
324  }
325  }
326 
327  bool Sample(const RandState<Arc> &rstate) {
328  sample_map_.clear();
329  if ((fst_.NumArcs(rstate.state_id) == 0 &&
330  fst_.Final(rstate.state_id) == Weight::Zero()) ||
331  rstate.length == max_length_) {
332  Reset();
333  return false;
334  }
335  if (fst_.NumArcs(rstate.state_id) + 1 < rstate.nsamples) {
336  MultinomialSample(rstate);
337  Reset();
338  return true;
339  }
340  for (size_t i = 0; i < rstate.nsamples; ++i) {
341  ++sample_map_[selector_(fst_, rstate.state_id, accumulator_.get())];
342  }
343  Reset();
344  return true;
345  }
346 
347  bool Done() const { return sample_iter_ == sample_map_.end(); }
348 
349  void Next() { ++sample_iter_; }
350 
351  std::pair<size_t, size_t> Value() const { return *sample_iter_; }
352 
353  void Reset() { sample_iter_ = sample_map_.begin(); }
354 
355  bool Error() const { return accumulator_->Error(); }
356 
357  private:
358  using RNG = std::mt19937;
359 
360  // Sample according to the multinomial distribution of rstate.nsamples draws
361  // from p_.
362  void MultinomialSample(const RandState<Arc> &rstate) {
363  p_.clear();
364  for (ArcIterator<Fst<Arc>> aiter(fst_, rstate.state_id); !aiter.Done();
365  aiter.Next()) {
366  p_.push_back(exp(-to_log_weight_(aiter.Value().weight).Value()));
367  }
368  if (fst_.Final(rstate.state_id) != Weight::Zero()) {
369  p_.push_back(exp(-to_log_weight_(fst_.Final(rstate.state_id)).Value()));
370  }
371  if (rstate.nsamples < std::numeric_limits<RNG::result_type>::max()) {
372  OneMultinomialSample(p_, rstate.nsamples, &sample_map_, &rng_);
373  } else {
374  for (size_t i = 0; i < p_.size(); ++i) {
375  sample_map_[i] = ceil(p_[i] * rstate.nsamples);
376  }
377  }
378  }
379 
380  const Fst<Arc> &fst_;
381  const Selector &selector_;
382  const int32_t max_length_;
383 
384  // Stores (N, K) for Value().
385  std::map<size_t, size_t> sample_map_;
386  std::map<size_t, size_t>::const_iterator sample_iter_;
387 
388  std::unique_ptr<Accumulator> accumulator_;
389  RNG rng_; // Random number generator.
390  std::vector<double> p_; // Multinomial parameters.
391  const WeightConvert<Weight, Log64Weight> to_log_weight_{};
392 };
393 
394 // Options for random path generation with RandGenFst. The template argument is
395 // a sampler, typically the class ArcSampler. Ownership of the sampler is taken
396 // by RandGenFst.
397 template <class Sampler>
399  Sampler *sampler; // How to sample transitions at a state.
400  int32_t npath; // Number of paths to generate.
401  bool weighted; // Is the output tree weighted by path count, or
402  // is it just an unweighted DAG?
403  bool remove_total_weight; // Remove total weight when output is weighted.
404 
405  RandGenFstOptions(const CacheOptions &opts, Sampler *sampler,
406  int32_t npath = 1, bool weighted = true,
407  bool remove_total_weight = false)
408  : CacheOptions(opts),
409  sampler(sampler),
410  npath(npath),
411  weighted(weighted),
412  remove_total_weight(remove_total_weight) {}
413 };
414 
415 namespace internal {
416 
417 // Implementation of RandGenFst.
418 template <class FromArc, class ToArc, class Sampler>
419 class RandGenFstImpl : public CacheImpl<ToArc> {
420  public:
425 
426  using CacheBaseImpl<CacheState<ToArc>>::EmplaceArc;
427  using CacheBaseImpl<CacheState<ToArc>>::HasArcs;
428  using CacheBaseImpl<CacheState<ToArc>>::HasFinal;
429  using CacheBaseImpl<CacheState<ToArc>>::HasStart;
430  using CacheBaseImpl<CacheState<ToArc>>::SetArcs;
431  using CacheBaseImpl<CacheState<ToArc>>::SetFinal;
432  using CacheBaseImpl<CacheState<ToArc>>::SetStart;
433 
434  using Label = typename FromArc::Label;
435  using StateId = typename FromArc::StateId;
436  using FromWeight = typename FromArc::Weight;
437 
438  using ToWeight = typename ToArc::Weight;
439 
441  const RandGenFstOptions<Sampler> &opts)
442  : CacheImpl<ToArc>(opts),
443  fst_(fst.Copy()),
444  sampler_(opts.sampler),
445  npath_(opts.npath),
446  weighted_(opts.weighted),
447  remove_total_weight_(opts.remove_total_weight),
448  superfinal_(kNoLabel) {
449  SetType("randgen");
450  SetProperties(
451  RandGenProperties(fst.Properties(kFstProperties, false), weighted_),
453  SetInputSymbols(fst.InputSymbols());
454  SetOutputSymbols(fst.OutputSymbols());
455  }
456 
458  : CacheImpl<ToArc>(impl),
459  fst_(impl.fst_->Copy(true)),
460  sampler_(new Sampler(*impl.sampler_, fst_.get())),
461  npath_(impl.npath_),
462  weighted_(impl.weighted_),
463  superfinal_(kNoLabel) {
464  SetType("randgen");
465  SetProperties(impl.Properties(), kCopyProperties);
466  SetInputSymbols(impl.InputSymbols());
467  SetOutputSymbols(impl.OutputSymbols());
468  }
469 
471  if (!HasStart()) {
472  const auto s = fst_->Start();
473  if (s == kNoStateId) return kNoStateId;
474  SetStart(state_table_.size());
475  state_table_.emplace_back(
476  new RandState<FromArc>(s, npath_, 0, 0, nullptr));
477  }
478  return CacheImpl<ToArc>::Start();
479  }
480 
482  if (!HasFinal(s)) Expand(s);
483  return CacheImpl<ToArc>::Final(s);
484  }
485 
486  size_t NumArcs(StateId s) {
487  if (!HasArcs(s)) Expand(s);
488  return CacheImpl<ToArc>::NumArcs(s);
489  }
490 
492  if (!HasArcs(s)) Expand(s);
494  }
495 
497  if (!HasArcs(s)) Expand(s);
499  }
500 
501  uint64_t Properties() const override { return Properties(kFstProperties); }
502 
503  // Sets error if found, and returns other FST impl properties.
504  uint64_t Properties(uint64_t mask) const override {
505  if ((mask & kError) &&
506  (fst_->Properties(kError, false) || sampler_->Error())) {
507  SetProperties(kError, kError);
508  }
509  return FstImpl<ToArc>::Properties(mask);
510  }
511 
513  if (!HasArcs(s)) Expand(s);
515  }
516 
517  // Computes the outgoing transitions from a state, creating new destination
518  // states as needed.
519  void Expand(StateId s) {
520  if (s == superfinal_) {
521  SetFinal(s);
522  SetArcs(s);
523  return;
524  }
525  SetFinal(s, ToWeight::Zero());
526  const auto &rstate = *state_table_[s];
527  sampler_->Sample(rstate);
528  ArcIterator<Fst<FromArc>> aiter(*fst_, rstate.state_id);
529  const auto narcs = fst_->NumArcs(rstate.state_id);
530  for (; !sampler_->Done(); sampler_->Next()) {
531  const auto &sample_pair = sampler_->Value();
532  const auto pos = sample_pair.first;
533  const auto count = sample_pair.second;
534  double prob = static_cast<double>(count) / rstate.nsamples;
535  if (pos < narcs) { // Regular transition.
536  aiter.Seek(sample_pair.first);
537  const auto &aarc = aiter.Value();
538  auto weight =
539  weighted_ ? to_weight_(Log64Weight(-log(prob))) : ToWeight::One();
540  EmplaceArc(s, aarc.ilabel, aarc.olabel, std::move(weight),
541  state_table_.size());
542  auto nrstate = std::make_unique<RandState<FromArc>>(
543  aarc.nextstate, count, rstate.length + 1, pos, &rstate);
544  state_table_.push_back(std::move(nrstate));
545  } else { // Super-final transition.
546  if (weighted_) {
547  const auto weight =
548  remove_total_weight_
549  ? to_weight_(Log64Weight(-log(prob)))
550  : to_weight_(Log64Weight(-log(prob * npath_)));
551  SetFinal(s, weight);
552  } else {
553  if (superfinal_ == kNoLabel) {
554  superfinal_ = state_table_.size();
555  state_table_.emplace_back(
556  new RandState<FromArc>(kNoStateId, 0, 0, 0, nullptr));
557  }
558  for (size_t n = 0; n < count; ++n) EmplaceArc(s, 0, 0, superfinal_);
559  }
560  }
561  }
562  SetArcs(s);
563  }
564 
565  private:
566  const std::unique_ptr<Fst<FromArc>> fst_;
567  std::unique_ptr<Sampler> sampler_;
568  const int32_t npath_;
569  std::vector<std::unique_ptr<RandState<FromArc>>> state_table_;
570  const bool weighted_;
571  bool remove_total_weight_;
572  StateId superfinal_;
573  const WeightConvert<Log64Weight, ToWeight> to_weight_{};
574 };
575 
576 } // namespace internal
577 
578 // FST class to randomly generate paths through an FST, with details controlled
579 // by RandGenFstOptions. Output format is a tree weighted by the path count.
580 template <class FromArc, class ToArc, class Sampler>
581 class RandGenFst
582  : public ImplToFst<internal::RandGenFstImpl<FromArc, ToArc, Sampler>> {
583  public:
584  using Label = typename FromArc::Label;
585  using StateId = typename FromArc::StateId;
586  using Weight = typename FromArc::Weight;
587 
589  using State = typename Store::State;
590 
592 
593  friend class ArcIterator<RandGenFst<FromArc, ToArc, Sampler>>;
594  friend class StateIterator<RandGenFst<FromArc, ToArc, Sampler>>;
595 
597  : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
598 
599  // See Fst<>::Copy() for doc.
600  RandGenFst(const RandGenFst &fst, bool safe = false)
601  : ImplToFst<Impl>(fst, safe) {}
602 
603  // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc.
604  RandGenFst *Copy(bool safe = false) const override {
605  return new RandGenFst(*this, safe);
606  }
607 
608  inline void InitStateIterator(StateIteratorData<ToArc> *data) const override;
609 
610  void InitArcIterator(StateId s, ArcIteratorData<ToArc> *data) const override {
611  GetMutableImpl()->InitArcIterator(s, data);
612  }
613 
614  private:
617 
618  RandGenFst &operator=(const RandGenFst &) = delete;
619 };
620 
621 // Specialization for RandGenFst.
622 template <class FromArc, class ToArc, class Sampler>
623 class StateIterator<RandGenFst<FromArc, ToArc, Sampler>>
624  : public CacheStateIterator<RandGenFst<FromArc, ToArc, Sampler>> {
625  public:
627  : CacheStateIterator<RandGenFst<FromArc, ToArc, Sampler>>(
628  fst, fst.GetMutableImpl()) {}
629 };
630 
631 // Specialization for RandGenFst.
632 template <class FromArc, class ToArc, class Sampler>
633 class ArcIterator<RandGenFst<FromArc, ToArc, Sampler>>
634  : public CacheArcIterator<RandGenFst<FromArc, ToArc, Sampler>> {
635  public:
636  using StateId = typename FromArc::StateId;
637 
639  : CacheArcIterator<RandGenFst<FromArc, ToArc, Sampler>>(
640  fst.GetMutableImpl(), s) {
641  if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
642  }
643 };
644 
645 template <class FromArc, class ToArc, class Sampler>
647  StateIteratorData<ToArc> *data) const {
648  data->base =
649  std::make_unique<StateIterator<RandGenFst<FromArc, ToArc, Sampler>>>(
650  *this);
651 }
652 
653 // Options for random path generation.
654 template <class Selector>
656  const Selector &selector; // How an arc is selected at a state.
657  int32_t max_length; // Maximum path length.
658  int32_t npath; // Number of paths to generate.
659  bool weighted; // Is the output tree weighted by path count, or
660  // is it just an unweighted DAG?
661  bool remove_total_weight; // Remove total weight when output is weighted?
662 
663  explicit RandGenOptions(
664  const Selector &selector,
665  int32_t max_length = std::numeric_limits<int32_t>::max(),
666  int32_t npath = 1, bool weighted = false,
667  bool remove_total_weight = false)
668  : selector(selector),
669  max_length(max_length),
670  npath(npath),
671  weighted(weighted),
672  remove_total_weight(remove_total_weight) {}
673 };
674 
675 namespace internal {
676 
677 template <class FromArc, class ToArc>
679  public:
680  using StateId = typename FromArc::StateId;
681  using Weight = typename FromArc::Weight;
682 
683  explicit RandGenVisitor(MutableFst<ToArc> *ofst) : ofst_(ofst) {}
684 
685  void InitVisit(const Fst<FromArc> &ifst) {
686  ifst_ = &ifst;
687  ofst_->DeleteStates();
688  ofst_->SetInputSymbols(ifst.InputSymbols());
689  ofst_->SetOutputSymbols(ifst.OutputSymbols());
690  if (ifst.Properties(kError, false)) ofst_->SetProperties(kError, kError);
691  path_.clear();
692  }
693 
694  constexpr bool InitState(StateId, StateId) const { return true; }
695 
696  bool TreeArc(StateId, const ToArc &arc) {
697  if (ifst_->Final(arc.nextstate) == Weight::Zero()) {
698  path_.push_back(arc);
699  } else {
700  OutputPath();
701  }
702  return true;
703  }
704 
705  bool BackArc(StateId, const FromArc &) {
706  FSTERROR() << "RandGenVisitor: cyclic input";
707  ofst_->SetProperties(kError, kError);
708  return false;
709  }
710 
711  bool ForwardOrCrossArc(StateId, const FromArc &) {
712  OutputPath();
713  return true;
714  }
715 
716  void FinishState(StateId s, StateId p, const FromArc *) {
717  if (p != kNoStateId && ifst_->Final(s) == Weight::Zero()) path_.pop_back();
718  }
719 
720  void FinishVisit() {}
721 
722  private:
723  void OutputPath() {
724  if (ofst_->Start() == kNoStateId) {
725  const auto start = ofst_->AddState();
726  ofst_->SetStart(start);
727  }
728  auto src = ofst_->Start();
729  for (size_t i = 0; i < path_.size(); ++i) {
730  const auto dest = ofst_->AddState();
731  const ToArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest);
732  ofst_->AddArc(src, arc);
733  src = dest;
734  }
735  ofst_->SetFinal(src);
736  }
737 
738  const Fst<FromArc> *ifst_;
739  MutableFst<ToArc> *ofst_;
740  std::vector<ToArc> path_;
741 
742  RandGenVisitor(const RandGenVisitor &) = delete;
743  RandGenVisitor &operator=(const RandGenVisitor &) = delete;
744 };
745 
746 } // namespace internal
747 
748 // Randomly generate paths through an FST; details controlled by
749 // RandGenOptions.
750 template <class FromArc, class ToArc, class Selector>
751 void RandGen(const Fst<FromArc> &ifst, MutableFst<ToArc> *ofst,
752  const RandGenOptions<Selector> &opts) {
753  using Sampler = ArcSampler<FromArc, Selector>;
754  auto sampler =
755  std::make_unique<Sampler>(ifst, opts.selector, opts.max_length);
756  RandGenFstOptions<Sampler> fopts(CacheOptions(true, 0), sampler.release(),
757  opts.npath, opts.weighted,
758  opts.remove_total_weight);
759  RandGenFst<FromArc, ToArc, Sampler> rfst(ifst, fopts);
760  if (opts.weighted) {
761  *ofst = rfst;
762  } else {
763  internal::RandGenVisitor<FromArc, ToArc> rand_visitor(ofst);
764  DfsVisit(rfst, &rand_visitor);
765  }
766 }
767 
768 // Randomly generate a path through an FST with the uniform distribution
769 // over the transitions.
770 template <class FromArc, class ToArc>
771 void RandGen(const Fst<FromArc> &ifst, MutableFst<ToArc> *ofst,
772  uint64_t seed = std::random_device()()) {
773  const UniformArcSelector<FromArc> uniform_selector(seed);
774  RandGenOptions<UniformArcSelector<ToArc>> opts(uniform_selector);
775  RandGen(ifst, ofst, opts);
776 }
777 
778 } // namespace fst
779 
780 #endif // FST_RANDGEN_H_
void OneMultinomialSample(const std::vector< double > &probs, size_t num_to_sample, Result *result, RNG *rng)
Definition: randgen.h:274
bool Sample(const RandState< Arc > &rstate)
Definition: randgen.h:327
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:124
typename FromArc::Weight Weight
Definition: randgen.h:586
StateIterator(const RandGenFst< FromArc, ToArc, Sampler > &fst)
Definition: randgen.h:626
uint64_t RandGenProperties(uint64_t inprops, bool weighted)
Definition: properties.cc:234
typename FromArc::StateId StateId
Definition: randgen.h:435
constexpr bool InitState(StateId, StateId) const
Definition: randgen.h:694
constexpr int kNoLabel
Definition: fst.h:195
void Next()
Definition: randgen.h:248
virtual uint64_t Properties(uint64_t mask, bool test) const =0
typename ToArc::Weight ToWeight
Definition: randgen.h:438
size_t operator()(const Fst< Arc > &fst, StateId s) const
Definition: randgen.h:104
void Reset()
Definition: fst.h:548
FastLogProbArcSelector(uint64_t seed)
Definition: randgen.h:154
typename Arc::StateId StateId
Definition: randgen.h:71
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:61
virtual size_t NumArcs(StateId) const =0
std::pair< size_t, size_t > Value() const
Definition: randgen.h:351
std::mt19937_64 & MutableRand() const
Definition: randgen.h:131
StateId state_id
Definition: randgen.h:181
void FinishState(StateId s, StateId p, const FromArc *)
Definition: randgen.h:716
RandState(StateId state_id, size_t nsamples=0, size_t length=0, size_t select=0, const RandState< Arc > *parent=nullptr)
Definition: randgen.h:187
typename FromArc::StateId StateId
Definition: randgen.h:585
void Expand(StateId s)
Definition: randgen.h:519
ArcSampler(const ArcSampler< Arc, Selector > &sampler, const Fst< Arc > *fst=nullptr)
Definition: randgen.h:215
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:107
const SymbolTable * OutputSymbols() const
Definition: fst.h:761
size_t NumInputEpsilons(StateId s)
Definition: randgen.h:491
bool Error() const
Definition: randgen.h:254
RandGenFstImpl(const Fst< FromArc > &fst, const RandGenFstOptions< Sampler > &opts)
Definition: randgen.h:440
int32_t max_length
Definition: randgen.h:657
constexpr uint64_t kError
Definition: properties.h:52
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
Definition: dfs-visit.h:112
typename Arc::Weight Weight
Definition: randgen.h:96
size_t NumOutputEpsilons(StateId s)
Definition: randgen.h:496
UniformArcSelector(uint64_t seed=std::random_device()())
Definition: randgen.h:74
typename Arc::StateId StateId
Definition: randgen.h:95
bool BackArc(StateId, const FromArc &)
Definition: randgen.h:705
virtual Weight Final(StateId) const =0
SetType
Definition: set-weight.h:59
static constexpr LogWeightTpl Zero()
Definition: float-weight.h:436
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:113
RandGenVisitor(MutableFst< ToArc > *ofst)
Definition: randgen.h:683
LogProbArcSelector(uint64_t seed)
Definition: randgen.h:102
typename FromArc::Label Label
Definition: randgen.h:584
void InitVisit(const Fst< FromArc > &ifst)
Definition: randgen.h:685
constexpr int kNoStateId
Definition: fst.h:196
typename Arc::Weight Weight
Definition: randgen.h:206
const Arc & Value() const
Definition: fst.h:536
size_t nsamples
Definition: randgen.h:182
#define FSTERROR()
Definition: util.h:56
RandGenFst * Copy(bool safe=false) const override
Definition: randgen.h:604
size_t length
Definition: randgen.h:183
LogWeightTpl< double > Log64Weight
Definition: float-weight.h:472
uint64_t Seed() const
Definition: randgen.h:124
typename Arc::Weight Weight
Definition: randgen.h:72
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:118
ArcSampler(const ArcSampler< Arc, Selector > &sampler, const Fst< Arc > *fst=nullptr)
Definition: randgen.h:314
RandGenFstImpl(const RandGenFstImpl &impl)
Definition: randgen.h:457
void RandGen(const Fst< FromArc > &ifst, MutableFst< ToArc > *ofst, const RandGenOptions< Selector > &opts)
Definition: randgen.h:751
typename Arc::StateId StateId
Definition: randgen.h:179
typename FromArc::Weight Weight
Definition: randgen.h:681
bool Done() const
Definition: randgen.h:245
void Seek(size_t a)
Definition: fst.h:556
RandGenFst(const Fst< FromArc > &fst, const RandGenFstOptions< Sampler > &opts)
Definition: randgen.h:596
constexpr uint64_t kCopyProperties
Definition: properties.h:163
bool ForwardOrCrossArc(StateId, const FromArc &)
Definition: randgen.h:711
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
Definition: cache.h:673
void Reset()
Definition: randgen.h:252
bool TreeArc(StateId, const ToArc &arc)
Definition: randgen.h:696
bool Done() const
Definition: fst.h:532
std::unique_ptr< StateIteratorBase< Arc > > base
Definition: fst.h:382
std::pair< size_t, size_t > Value() const
Definition: randgen.h:250
typename Arc::Weight Weight
Definition: randgen.h:145
size_t operator()(const Fst< Arc > &fst, StateId s, CacheLogAccumulator< Arc > *accumulator) const
Definition: randgen.h:157
ArcIterator(const RandGenFst< FromArc, ToArc, Sampler > &fst, StateId s)
Definition: randgen.h:638
void InitArcIterator(StateId s, ArcIteratorData< ToArc > *data)
Definition: randgen.h:512
void InitStateIterator(StateIteratorData< ToArc > *data) const override
Definition: randgen.h:646
RandGenOptions(const Selector &selector, int32_t max_length=std::numeric_limits< int32_t >::max(), int32_t npath=1, bool weighted=false, bool remove_total_weight=false)
Definition: randgen.h:663
RandGenFstOptions(const CacheOptions &opts, Sampler *sampler, int32_t npath=1, bool weighted=true, bool remove_total_weight=false)
Definition: randgen.h:405
void SetState(StateId s, int depth=0)
Definition: accumulator.h:518
RandGenFst(const RandGenFst &fst, bool safe=false)
Definition: randgen.h:600
constexpr uint64_t kFstProperties
Definition: properties.h:326
typename FromArc::Label Label
Definition: randgen.h:434
virtual const SymbolTable * InputSymbols() const =0
Weight Sum(Weight w, Weight v)
Definition: accumulator.h:543
const SymbolTable * InputSymbols() const
Definition: fst.h:759
const RandState< Arc > * parent
Definition: randgen.h:185
ArcSampler(const Fst< Arc > &fst, const Selector &selector, int32_t max_length=std::numeric_limits< int32_t >::max())
Definition: randgen.h:210
typename FromArc::StateId StateId
Definition: randgen.h:680
ToWeight Final(StateId s)
Definition: randgen.h:481
void InitArcIterator(StateId s, ArcIteratorData< ToArc > *data) const override
Definition: randgen.h:610
Log64Weight ToLogWeight(const Weight &weight) const
Definition: randgen.h:127
const Selector & selector
Definition: randgen.h:656
uint64_t Properties() const override
Definition: randgen.h:501
uint64_t Properties(uint64_t mask) const override
Definition: randgen.h:504
size_t LowerBound(Weight w, ArcIter *aiter)
Definition: accumulator.h:580
typename Arc::StateId StateId
Definition: randgen.h:205
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
Definition: expand.h:323
bool Sample(const RandState< Arc > &rstate)
Definition: randgen.h:229
internal::RandGenFstImpl< FromArc, ToArc, Sampler > * GetMutableImpl() const
Definition: impl-to-fst.h:125
typename FromArc::Weight FromWeight
Definition: randgen.h:436
size_t NumArcs(StateId s)
Definition: randgen.h:486
size_t operator()(const Fst< Arc > &fst, StateId s) const
Definition: randgen.h:77
size_t select
Definition: randgen.h:184
ArcSampler(const Fst< Arc > &fst, const Selector &selector, int32_t max_length=std::numeric_limits< int32_t >::max())
Definition: randgen.h:304
const internal::RandGenFstImpl< FromArc, ToArc, Sampler > * GetImpl() const
Definition: impl-to-fst.h:123
bool remove_total_weight
Definition: randgen.h:661
typename Store::State State
Definition: randgen.h:589
virtual const SymbolTable * OutputSymbols() const =0
void Next()
Definition: fst.h:540