FST  openfst-1.8.2
OpenFst Library
randgen.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Classes and functions to generate random paths through an FST.
19 
20 #ifndef FST_RANDGEN_H_
21 #define FST_RANDGEN_H_
22 
23 #include <algorithm>
24 #include <cmath>
25 #include <cstddef>
26 #include <cstdint>
27 #include <limits>
28 #include <map>
29 #include <memory>
30 #include <numeric>
31 #include <random>
32 #include <utility>
33 #include <vector>
34 
35 #include <fst/log.h>
36 
37 #include <fst/accumulator.h>
38 #include <fst/cache.h>
39 #include <fst/dfs-visit.h>
40 #include <fst/float-weight.h>
41 #include <fst/fst-decl.h>
42 #include <fst/fst.h>
43 #include <fst/mutable-fst.h>
44 #include <fst/properties.h>
45 #include <fst/util.h>
46 #include <fst/weight.h>
47 
48 #include <vector>
49 
50 namespace fst {
51 
52 // The RandGenFst class is roughly similar to ArcMapFst in that it takes two
53 // template parameters denoting the input and output arc types. However, it also
54 // takes an additional template parameter which specifies a sampler object which
55 // samples (with replacement) arcs from an FST state. The sampler in turn takes
56 // a template parameter for a selector object which actually chooses the arc.
57 //
58 // Arc selector functors are used to select a random transition given an FST
59 // state s, returning a number N such that 0 <= N <= NumArcs(s). If N is
60 // NumArcs(s), then the final weight is selected; otherwise the N-th arc is
61 // selected. It is assumed these are not applied to any state which is neither
62 // final nor has any arcs leaving it.
63 
64 // Randomly selects a transition using the uniform distribution. This class is
65 // not thread-safe.
66 template <class Arc>
68  public:
69  using StateId = typename Arc::StateId;
70  using Weight = typename Arc::Weight;
71 
72  explicit UniformArcSelector(uint64_t seed = std::random_device()())
73  : rand_(seed) {}
74 
75  size_t operator()(const Fst<Arc> &fst, StateId s) const {
76  const auto n = fst.NumArcs(s) + (fst.Final(s) != Weight::Zero());
77  return static_cast<size_t>(
78  std::uniform_int_distribution<>(0, n - 1)(rand_));
79  }
80 
81  private:
82  mutable std::mt19937_64 rand_;
83 };
84 
85 // Randomly selects a transition w.r.t. the weights treated as negative log
86 // probabilities after normalizing for the total weight leaving the state. Zero
87 // transitions are disregarded. It assumed that Arc::Weight::Value() accesses
88 // the floating point representation of the weight. This class is not
89 // thread-safe.
90 template <class Arc>
92  public:
93  using StateId = typename Arc::StateId;
94  using Weight = typename Arc::Weight;
95 
96  // Constructs a selector with a non-deterministic seed.
97  LogProbArcSelector() : seed_(std::random_device()()), rand_(seed_) {}
98 
99  // Constructs a selector with a given seed.
100  explicit LogProbArcSelector(uint64_t seed) : seed_(seed), rand_(seed) {}
101 
102  size_t operator()(const Fst<Arc> &fst, StateId s) const {
103  // Finds total weight leaving state.
104  auto sum = Log64Weight::Zero();
105  ArcIterator<Fst<Arc>> aiter(fst, s);
106  for (; !aiter.Done(); aiter.Next()) {
107  const auto &arc = aiter.Value();
108  sum = Plus(sum, to_log_weight_(arc.weight));
109  }
110  sum = Plus(sum, to_log_weight_(fst.Final(s)));
111  const double threshold =
112  std::uniform_real_distribution<>(0, exp(-sum.Value()))(rand_);
113  auto p = Log64Weight::Zero();
114  size_t n = 0;
115  for (aiter.Reset(); !aiter.Done(); aiter.Next(), ++n) {
116  p = Plus(p, to_log_weight_(aiter.Value().weight));
117  if (exp(-p.Value()) > threshold) return n;
118  }
119  return n;
120  }
121 
122  uint64_t Seed() const { return seed_; }
123 
124  protected:
125  Log64Weight ToLogWeight(const Weight &weight) const {
126  return to_log_weight_(weight);
127  }
128 
129  std::mt19937_64 &MutableRand() const { return rand_; }
130 
131  private:
132  const uint64_t seed_;
133  mutable std::mt19937_64 rand_;
134  const WeightConvert<Weight, Log64Weight> to_log_weight_{};
135 };
136 
137 // Useful alias when using StdArc.
139 
140 // Same as LogProbArcSelector but use CacheLogAccumulator to cache the weight
141 // accumulation computations. This class is not thread-safe.
142 template <class Arc>
144  public:
145  using StateId = typename Arc::StateId;
146  using Weight = typename Arc::Weight;
147 
151 
152  // Constructs a selector with a non-deterministic seed.
154  // Constructs a selector with a given seed.
155  explicit FastLogProbArcSelector(uint64_t seed)
156  : LogProbArcSelector<Arc>(seed) {}
157 
158  size_t operator()(const Fst<Arc> &fst, StateId s,
159  CacheLogAccumulator<Arc> *accumulator) const {
160  accumulator->SetState(s);
161  ArcIterator<Fst<Arc>> aiter(fst, s);
162  // Finds total weight leaving state.
163  const double sum =
164  ToLogWeight(accumulator->Sum(fst.Final(s), &aiter, 0, fst.NumArcs(s)))
165  .Value();
166  const double r =
167  -log(std::uniform_real_distribution<>(0, 1)(MutableRand()));
168  Weight w = from_log_weight_(r + sum);
169  aiter.Reset();
170  return accumulator->LowerBound(w, &aiter);
171  }
172 
173  private:
174  const WeightConvert<Log64Weight, Weight> from_log_weight_{};
175 };
176 
// Per-state bookkeeping for a random path; RandGenFst keeps one of these for
// each state it creates and hands it to the sampler.
template <typename Arc>
struct RandState {
  using StateId = typename Arc::StateId;

  // Current input FST state.
  StateId state_id;
  // Number of samples to be sampled at this state.
  size_t nsamples;
  // Length of path to this random state.
  size_t length;
  // Previous sample arc selection.
  size_t select;
  // Previous random state on this path.
  const RandState<Arc> *parent;

  // Every field except the state ID defaults to zero / null.
  explicit RandState(StateId state_id, size_t nsamples = 0, size_t length = 0,
                     size_t select = 0, const RandState<Arc> *parent = nullptr)
      : state_id{state_id},
        nsamples{nsamples},
        length{length},
        select{select},
        parent{parent} {}

};
198 
199 // This class, given an arc selector, samples, with replacement, multiple random
200 // transitions from an FST's state. This is a generic version with a
201 // straightforward use of the arc selector. Specializations may be defined for
202 // arc selectors for greater efficiency or special behavior.
203 template <class Arc, class Selector>
204 class ArcSampler {
205  public:
206  using StateId = typename Arc::StateId;
207  using Weight = typename Arc::Weight;
208 
209  // The max_length argument may be interpreted (or ignored) by a selector as
210  // it chooses. This generic version interprets this literally.
211  ArcSampler(const Fst<Arc> &fst, const Selector &selector,
212  int32_t max_length = std::numeric_limits<int32_t>::max())
213  : fst_(fst), selector_(selector), max_length_(max_length) {}
214 
215  // Allow updating FST argument; pass only if changed.
217  const Fst<Arc> *fst = nullptr)
218  : fst_(fst ? *fst : sampler.fst_),
219  selector_(sampler.selector_),
220  max_length_(sampler.max_length_) {
221  Reset();
222  }
223 
224  // Samples a fixed number of samples from the given state. The length argument
225  // specifies the length of the path to the state. Returns true if the samples
226  // were collected. No samples may be collected if either there are no
227  // transitions leaving the state and the state is non-final, or if the path
228  // length has been exceeded. Iterator members are provided to read the samples
229  // in the order in which they were collected.
230  bool Sample(const RandState<Arc> &rstate) {
231  sample_map_.clear();
232  if ((fst_.NumArcs(rstate.state_id) == 0 &&
233  fst_.Final(rstate.state_id) == Weight::Zero()) ||
234  rstate.length == max_length_) {
235  Reset();
236  return false;
237  }
238  for (size_t i = 0; i < rstate.nsamples; ++i) {
239  ++sample_map_[selector_(fst_, rstate.state_id)];
240  }
241  Reset();
242  return true;
243  }
244 
245  // More samples?
246  bool Done() const { return sample_iter_ == sample_map_.end(); }
247 
248  // Gets the next sample.
249  void Next() { ++sample_iter_; }
250 
251  std::pair<size_t, size_t> Value() const { return *sample_iter_; }
252 
253  void Reset() { sample_iter_ = sample_map_.begin(); }
254 
255  bool Error() const { return false; }
256 
257  private:
258  const Fst<Arc> &fst_;
259  const Selector &selector_;
260  const int32_t max_length_;
261 
262  // Stores (N, K) as described for Value().
263  std::map<size_t, size_t> sample_map_;
264  std::map<size_t, size_t>::const_iterator sample_iter_;
265 
266  ArcSampler<Arc, Selector> &operator=(const ArcSampler &) = delete;
267 };
268 
// Draws one sample from a multinomial distribution with num_to_sample trials
// whose (possibly unnormalized) category weights are given by probs. For each
// category i that receives a non-zero count, (*result)[i] is set to that
// count; other entries of result are left untouched. The result container
// should therefore be pre-initialized (e.g., an empty map, or a zeroed vector
// of size probs.size()).
template <class Result, class RNG>
void OneMultinomialSample(const std::vector<double> &probs,
                          size_t num_to_sample, Result *result, RNG *rng) {
  using Binomial = std::binomial_distribution<size_t>;
  // tail_mass[i] is the sum of probs[i..end], i.e., the probability mass still
  // left when category i is reached. The suffix sums are precomputed because
  // subtracting probs[i] from a running scalar accumulates round-off error and
  // can leave probs[i] larger than the remaining mass.
  std::vector<double> tail_mass(probs.size());
  std::partial_sum(probs.rbegin(), probs.rend(), tail_mass.rbegin());
  for (size_t i = 0; i < probs.size(); ++i) {
    // Categories without positive weight never receive samples.
    if (!(probs[i] > 0)) continue;
    // Conditioned on the earlier categories, the count for category i is
    // binomial over the samples still needed.
    Binomial draw(num_to_sample, probs[i] / tail_mass[i]);
    const Binomial::result_type drawn = draw(*rng);
    if (drawn == 0) continue;
    (*result)[i] = drawn;
    // Clamped so the number of samples still needed cannot wrap below zero.
    num_to_sample -= std::min(drawn, num_to_sample);
  }
}
294 
295 // Specialization for FastLogProbArcSelector.
296 template <class Arc>
298  public:
299  using StateId = typename Arc::StateId;
300  using Weight = typename Arc::Weight;
301 
304 
305  ArcSampler(const Fst<Arc> &fst, const Selector &selector,
306  int32_t max_length = std::numeric_limits<int32_t>::max())
307  : fst_(fst),
308  selector_(selector),
309  max_length_(max_length),
310  accumulator_(new Accumulator()) {
311  accumulator_->Init(fst);
312  rng_.seed(selector_.Seed());
313  }
314 
316  const Fst<Arc> *fst = nullptr)
317  : fst_(fst ? *fst : sampler.fst_),
318  selector_(sampler.selector_),
319  max_length_(sampler.max_length_) {
320  if (fst) {
321  accumulator_ = std::make_unique<Accumulator>();
322  accumulator_->Init(*fst);
323  } else { // Shallow copy.
324  accumulator_ = std::make_unique<Accumulator>(*sampler.accumulator_);
325  }
326  }
327 
328  bool Sample(const RandState<Arc> &rstate) {
329  sample_map_.clear();
330  if ((fst_.NumArcs(rstate.state_id) == 0 &&
331  fst_.Final(rstate.state_id) == Weight::Zero()) ||
332  rstate.length == max_length_) {
333  Reset();
334  return false;
335  }
336  if (fst_.NumArcs(rstate.state_id) + 1 < rstate.nsamples) {
337  MultinomialSample(rstate);
338  Reset();
339  return true;
340  }
341  for (size_t i = 0; i < rstate.nsamples; ++i) {
342  ++sample_map_[selector_(fst_, rstate.state_id, accumulator_.get())];
343  }
344  Reset();
345  return true;
346  }
347 
348  bool Done() const { return sample_iter_ == sample_map_.end(); }
349 
350  void Next() { ++sample_iter_; }
351 
352  std::pair<size_t, size_t> Value() const { return *sample_iter_; }
353 
354  void Reset() { sample_iter_ = sample_map_.begin(); }
355 
356  bool Error() const { return accumulator_->Error(); }
357 
358  private:
359  using RNG = std::mt19937;
360 
361  // Sample according to the multinomial distribution of rstate.nsamples draws
362  // from p_.
363  void MultinomialSample(const RandState<Arc> &rstate) {
364  p_.clear();
365  for (ArcIterator<Fst<Arc>> aiter(fst_, rstate.state_id); !aiter.Done();
366  aiter.Next()) {
367  p_.push_back(exp(-to_log_weight_(aiter.Value().weight).Value()));
368  }
369  if (fst_.Final(rstate.state_id) != Weight::Zero()) {
370  p_.push_back(exp(-to_log_weight_(fst_.Final(rstate.state_id)).Value()));
371  }
372  if (rstate.nsamples < std::numeric_limits<RNG::result_type>::max()) {
373  OneMultinomialSample(p_, rstate.nsamples, &sample_map_, &rng_);
374  } else {
375  for (size_t i = 0; i < p_.size(); ++i) {
376  sample_map_[i] = ceil(p_[i] * rstate.nsamples);
377  }
378  }
379  }
380 
381  const Fst<Arc> &fst_;
382  const Selector &selector_;
383  const int32_t max_length_;
384 
385  // Stores (N, K) for Value().
386  std::map<size_t, size_t> sample_map_;
387  std::map<size_t, size_t>::const_iterator sample_iter_;
388 
389  std::unique_ptr<Accumulator> accumulator_;
390  RNG rng_; // Random number generator.
391  std::vector<double> p_; // Multinomial parameters.
392  const WeightConvert<Weight, Log64Weight> to_log_weight_{};
393 };
394 
395 // Options for random path generation with RandGenFst. The template argument is
396 // a sampler, typically the class ArcSampler. Ownership of the sampler is taken
397 // by RandGenFst.
398 template <class Sampler>
400  Sampler *sampler; // How to sample transitions at a state.
401  int32_t npath; // Number of paths to generate.
402  bool weighted; // Is the output tree weighted by path count, or
403  // is it just an unweighted DAG?
404  bool remove_total_weight; // Remove total weight when output is weighted.
405 
406  RandGenFstOptions(const CacheOptions &opts, Sampler *sampler,
407  int32_t npath = 1, bool weighted = true,
408  bool remove_total_weight = false)
409  : CacheOptions(opts),
410  sampler(sampler),
411  npath(npath),
412  weighted(weighted),
413  remove_total_weight(remove_total_weight) {}
414 };
415 
416 namespace internal {
417 
418 // Implementation of RandGenFst.
419 template <class FromArc, class ToArc, class Sampler>
420 class RandGenFstImpl : public CacheImpl<ToArc> {
421  public:
426 
427  using CacheBaseImpl<CacheState<ToArc>>::EmplaceArc;
428  using CacheBaseImpl<CacheState<ToArc>>::HasArcs;
429  using CacheBaseImpl<CacheState<ToArc>>::HasFinal;
430  using CacheBaseImpl<CacheState<ToArc>>::HasStart;
431  using CacheBaseImpl<CacheState<ToArc>>::SetArcs;
432  using CacheBaseImpl<CacheState<ToArc>>::SetFinal;
433  using CacheBaseImpl<CacheState<ToArc>>::SetStart;
434 
435  using Label = typename FromArc::Label;
436  using StateId = typename FromArc::StateId;
437  using FromWeight = typename FromArc::Weight;
438 
439  using ToWeight = typename ToArc::Weight;
440 
442  const RandGenFstOptions<Sampler> &opts)
443  : CacheImpl<ToArc>(opts),
444  fst_(fst.Copy()),
445  sampler_(opts.sampler),
446  npath_(opts.npath),
447  weighted_(opts.weighted),
448  remove_total_weight_(opts.remove_total_weight),
449  superfinal_(kNoLabel) {
450  SetType("randgen");
451  SetProperties(
452  RandGenProperties(fst.Properties(kFstProperties, false), weighted_),
454  SetInputSymbols(fst.InputSymbols());
455  SetOutputSymbols(fst.OutputSymbols());
456  }
457 
459  : CacheImpl<ToArc>(impl),
460  fst_(impl.fst_->Copy(true)),
461  sampler_(new Sampler(*impl.sampler_, fst_.get())),
462  npath_(impl.npath_),
463  weighted_(impl.weighted_),
464  superfinal_(kNoLabel) {
465  SetType("randgen");
466  SetProperties(impl.Properties(), kCopyProperties);
467  SetInputSymbols(impl.InputSymbols());
468  SetOutputSymbols(impl.OutputSymbols());
469  }
470 
472  if (!HasStart()) {
473  const auto s = fst_->Start();
474  if (s == kNoStateId) return kNoStateId;
475  SetStart(state_table_.size());
476  state_table_.emplace_back(
477  new RandState<FromArc>(s, npath_, 0, 0, nullptr));
478  }
479  return CacheImpl<ToArc>::Start();
480  }
481 
483  if (!HasFinal(s)) Expand(s);
484  return CacheImpl<ToArc>::Final(s);
485  }
486 
487  size_t NumArcs(StateId s) {
488  if (!HasArcs(s)) Expand(s);
489  return CacheImpl<ToArc>::NumArcs(s);
490  }
491 
493  if (!HasArcs(s)) Expand(s);
495  }
496 
498  if (!HasArcs(s)) Expand(s);
500  }
501 
502  uint64_t Properties() const override { return Properties(kFstProperties); }
503 
504  // Sets error if found, and returns other FST impl properties.
505  uint64_t Properties(uint64_t mask) const override {
506  if ((mask & kError) &&
507  (fst_->Properties(kError, false) || sampler_->Error())) {
508  SetProperties(kError, kError);
509  }
510  return FstImpl<ToArc>::Properties(mask);
511  }
512 
514  if (!HasArcs(s)) Expand(s);
516  }
517 
518  // Computes the outgoing transitions from a state, creating new destination
519  // states as needed.
520  void Expand(StateId s) {
521  if (s == superfinal_) {
522  SetFinal(s);
523  SetArcs(s);
524  return;
525  }
526  SetFinal(s, ToWeight::Zero());
527  const auto &rstate = *state_table_[s];
528  sampler_->Sample(rstate);
529  ArcIterator<Fst<FromArc>> aiter(*fst_, rstate.state_id);
530  const auto narcs = fst_->NumArcs(rstate.state_id);
531  for (; !sampler_->Done(); sampler_->Next()) {
532  const auto &sample_pair = sampler_->Value();
533  const auto pos = sample_pair.first;
534  const auto count = sample_pair.second;
535  double prob = static_cast<double>(count) / rstate.nsamples;
536  if (pos < narcs) { // Regular transition.
537  aiter.Seek(sample_pair.first);
538  const auto &aarc = aiter.Value();
539  auto weight =
540  weighted_ ? to_weight_(Log64Weight(-log(prob))) : ToWeight::One();
541  EmplaceArc(s, aarc.ilabel, aarc.olabel, std::move(weight),
542  state_table_.size());
543  auto *nrstate = new RandState<FromArc>(aarc.nextstate, count,
544  rstate.length + 1, pos, &rstate);
545  state_table_.emplace_back(nrstate);
546  } else { // Super-final transition.
547  if (weighted_) {
548  const auto weight =
549  remove_total_weight_
550  ? to_weight_(Log64Weight(-log(prob)))
551  : to_weight_(Log64Weight(-log(prob * npath_)));
552  SetFinal(s, weight);
553  } else {
554  if (superfinal_ == kNoLabel) {
555  superfinal_ = state_table_.size();
556  state_table_.emplace_back(
557  new RandState<FromArc>(kNoStateId, 0, 0, 0, nullptr));
558  }
559  for (size_t n = 0; n < count; ++n) EmplaceArc(s, 0, 0, superfinal_);
560  }
561  }
562  }
563  SetArcs(s);
564  }
565 
566  private:
567  const std::unique_ptr<Fst<FromArc>> fst_;
568  std::unique_ptr<Sampler> sampler_;
569  const int32_t npath_;
570  std::vector<std::unique_ptr<RandState<FromArc>>> state_table_;
571  const bool weighted_;
572  bool remove_total_weight_;
573  StateId superfinal_;
574  const WeightConvert<Log64Weight, ToWeight> to_weight_{};
575 };
576 
577 } // namespace internal
578 
579 // FST class to randomly generate paths through an FST, with details controlled
580 // by RandGenFstOptions. Output format is a tree weighted by the path count.
581 template <class FromArc, class ToArc, class Sampler>
582 class RandGenFst
583  : public ImplToFst<internal::RandGenFstImpl<FromArc, ToArc, Sampler>> {
584  public:
585  using Label = typename FromArc::Label;
586  using StateId = typename FromArc::StateId;
587  using Weight = typename FromArc::Weight;
588 
590  using State = typename Store::State;
591 
593 
594  friend class ArcIterator<RandGenFst<FromArc, ToArc, Sampler>>;
595  friend class StateIterator<RandGenFst<FromArc, ToArc, Sampler>>;
596 
598  : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
599 
600  // See Fst<>::Copy() for doc.
601  RandGenFst(const RandGenFst &fst, bool safe = false)
602  : ImplToFst<Impl>(fst, safe) {}
603 
604  // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc.
605  RandGenFst *Copy(bool safe = false) const override {
606  return new RandGenFst(*this, safe);
607  }
608 
609  inline void InitStateIterator(StateIteratorData<ToArc> *data) const override;
610 
611  void InitArcIterator(StateId s, ArcIteratorData<ToArc> *data) const override {
612  GetMutableImpl()->InitArcIterator(s, data);
613  }
614 
615  private:
618 
619  RandGenFst &operator=(const RandGenFst &) = delete;
620 };
621 
622 // Specialization for RandGenFst.
623 template <class FromArc, class ToArc, class Sampler>
624 class StateIterator<RandGenFst<FromArc, ToArc, Sampler>>
625  : public CacheStateIterator<RandGenFst<FromArc, ToArc, Sampler>> {
626  public:
628  : CacheStateIterator<RandGenFst<FromArc, ToArc, Sampler>>(
629  fst, fst.GetMutableImpl()) {}
630 };
631 
632 // Specialization for RandGenFst.
633 template <class FromArc, class ToArc, class Sampler>
634 class ArcIterator<RandGenFst<FromArc, ToArc, Sampler>>
635  : public CacheArcIterator<RandGenFst<FromArc, ToArc, Sampler>> {
636  public:
637  using StateId = typename FromArc::StateId;
638 
640  : CacheArcIterator<RandGenFst<FromArc, ToArc, Sampler>>(
641  fst.GetMutableImpl(), s) {
642  if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
643  }
644 };
645 
646 template <class FromArc, class ToArc, class Sampler>
648  StateIteratorData<ToArc> *data) const {
649  data->base =
650  std::make_unique<StateIterator<RandGenFst<FromArc, ToArc, Sampler>>>(
651  *this);
652 }
653 
// Options for random path generation. The selector is held by reference and
// must outlive the options object.
template <class Selector>
struct RandGenOptions {
  const Selector &selector;  // How an arc is selected at a state.
  int32_t max_length;        // Maximum path length.
  int32_t npath;             // Number of paths to generate.
  bool weighted;             // Is the output tree weighted by path count, or
                             // is it just an unweighted DAG?
  bool remove_total_weight;  // Remove total weight when output is weighted?

  // By default: no effective length limit, a single path, unweighted output.
  explicit RandGenOptions(
      const Selector &selector,
      int32_t max_length = std::numeric_limits<int32_t>::max(),
      int32_t npath = 1, bool weighted = false,
      bool remove_total_weight = false)
      : selector(selector),
        max_length(max_length),
        npath(npath),
        weighted(weighted),
        remove_total_weight(remove_total_weight) {}
};
675 
676 namespace internal {
677 
678 template <class FromArc, class ToArc>
680  public:
681  using StateId = typename FromArc::StateId;
682  using Weight = typename FromArc::Weight;
683 
684  explicit RandGenVisitor(MutableFst<ToArc> *ofst) : ofst_(ofst) {}
685 
686  void InitVisit(const Fst<FromArc> &ifst) {
687  ifst_ = &ifst;
688  ofst_->DeleteStates();
689  ofst_->SetInputSymbols(ifst.InputSymbols());
690  ofst_->SetOutputSymbols(ifst.OutputSymbols());
691  if (ifst.Properties(kError, false)) ofst_->SetProperties(kError, kError);
692  path_.clear();
693  }
694 
695  constexpr bool InitState(StateId, StateId) const { return true; }
696 
697  bool TreeArc(StateId, const ToArc &arc) {
698  if (ifst_->Final(arc.nextstate) == Weight::Zero()) {
699  path_.push_back(arc);
700  } else {
701  OutputPath();
702  }
703  return true;
704  }
705 
706  bool BackArc(StateId, const FromArc &) {
707  FSTERROR() << "RandGenVisitor: cyclic input";
708  ofst_->SetProperties(kError, kError);
709  return false;
710  }
711 
712  bool ForwardOrCrossArc(StateId, const FromArc &) {
713  OutputPath();
714  return true;
715  }
716 
717  void FinishState(StateId s, StateId p, const FromArc *) {
718  if (p != kNoStateId && ifst_->Final(s) == Weight::Zero()) path_.pop_back();
719  }
720 
721  void FinishVisit() {}
722 
723  private:
724  void OutputPath() {
725  if (ofst_->Start() == kNoStateId) {
726  const auto start = ofst_->AddState();
727  ofst_->SetStart(start);
728  }
729  auto src = ofst_->Start();
730  for (size_t i = 0; i < path_.size(); ++i) {
731  const auto dest = ofst_->AddState();
732  const ToArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest);
733  ofst_->AddArc(src, arc);
734  src = dest;
735  }
736  ofst_->SetFinal(src);
737  }
738 
739  const Fst<FromArc> *ifst_;
740  MutableFst<ToArc> *ofst_;
741  std::vector<ToArc> path_;
742 
743  RandGenVisitor(const RandGenVisitor &) = delete;
744  RandGenVisitor &operator=(const RandGenVisitor &) = delete;
745 };
746 
747 } // namespace internal
748 
749 // Randomly generate paths through an FST; details controlled by
750 // RandGenOptions.
751 template <class FromArc, class ToArc, class Selector>
752 void RandGen(const Fst<FromArc> &ifst, MutableFst<ToArc> *ofst,
753  const RandGenOptions<Selector> &opts) {
754  using Sampler = ArcSampler<FromArc, Selector>;
755  auto *sampler = new Sampler(ifst, opts.selector, opts.max_length);
756  RandGenFstOptions<Sampler> fopts(CacheOptions(true, 0), sampler, opts.npath,
757  opts.weighted, opts.remove_total_weight);
758  RandGenFst<FromArc, ToArc, Sampler> rfst(ifst, fopts);
759  if (opts.weighted) {
760  *ofst = rfst;
761  } else {
762  internal::RandGenVisitor<FromArc, ToArc> rand_visitor(ofst);
763  DfsVisit(rfst, &rand_visitor);
764  }
765 }
766 
767 // Randomly generate a path through an FST with the uniform distribution
768 // over the transitions.
769 template <class FromArc, class ToArc>
770 void RandGen(const Fst<FromArc> &ifst, MutableFst<ToArc> *ofst,
771  uint64_t seed = std::random_device()()) {
772  const UniformArcSelector<FromArc> uniform_selector(seed);
773  RandGenOptions<UniformArcSelector<ToArc>> opts(uniform_selector);
774  RandGen(ifst, ofst, opts);
775 }
776 
777 } // namespace fst
778 
779 #endif // FST_RANDGEN_H_
void OneMultinomialSample(const std::vector< double > &probs, size_t num_to_sample, Result *result, RNG *rng)
Definition: randgen.h:275
bool Sample(const RandState< Arc > &rstate)
Definition: randgen.h:328
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:114
typename FromArc::Weight Weight
Definition: randgen.h:587
StateIterator(const RandGenFst< FromArc, ToArc, Sampler > &fst)
Definition: randgen.h:627
uint64_t RandGenProperties(uint64_t inprops, bool weighted)
Definition: properties.cc:235
typename FromArc::StateId StateId
Definition: randgen.h:436
constexpr bool InitState(StateId, StateId) const
Definition: randgen.h:695
constexpr int kNoLabel
Definition: fst.h:201
void Next()
Definition: randgen.h:249
virtual uint64_t Properties(uint64_t mask, bool test) const =0
typename ToArc::Weight ToWeight
Definition: randgen.h:439
size_t operator()(const Fst< Arc > &fst, StateId s) const
Definition: randgen.h:102
void Reset()
Definition: fst.h:549
FastLogProbArcSelector(uint64_t seed)
Definition: randgen.h:155
typename Arc::StateId StateId
Definition: randgen.h:69
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:60
virtual size_t NumArcs(StateId) const =0
std::pair< size_t, size_t > Value() const
Definition: randgen.h:352
std::mt19937_64 & MutableRand() const
Definition: randgen.h:129
StateId state_id
Definition: randgen.h:182
void FinishState(StateId s, StateId p, const FromArc *)
Definition: randgen.h:717
RandState(StateId state_id, size_t nsamples=0, size_t length=0, size_t select=0, const RandState< Arc > *parent=nullptr)
Definition: randgen.h:188
typename FromArc::StateId StateId
Definition: randgen.h:586
void Expand(StateId s)
Definition: randgen.h:520
ArcSampler(const ArcSampler< Arc, Selector > &sampler, const Fst< Arc > *fst=nullptr)
Definition: randgen.h:216
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:97
const SymbolTable * OutputSymbols() const
Definition: fst.h:762
size_t NumInputEpsilons(StateId s)
Definition: randgen.h:492
bool Error() const
Definition: randgen.h:255
RandGenFstImpl(const Fst< FromArc > &fst, const RandGenFstOptions< Sampler > &opts)
Definition: randgen.h:441
int32_t max_length
Definition: randgen.h:658
constexpr uint64_t kError
Definition: properties.h:51
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
Definition: dfs-visit.h:109
typename Arc::Weight Weight
Definition: randgen.h:94
size_t NumOutputEpsilons(StateId s)
Definition: randgen.h:497
UniformArcSelector(uint64_t seed=std::random_device()())
Definition: randgen.h:72
typename Arc::StateId StateId
Definition: randgen.h:93
bool BackArc(StateId, const FromArc &)
Definition: randgen.h:706
virtual Weight Final(StateId) const =0
SetType
Definition: set-weight.h:52
static constexpr LogWeightTpl Zero()
Definition: float-weight.h:432
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:103
RandGenVisitor(MutableFst< ToArc > *ofst)
Definition: randgen.h:684
LogProbArcSelector(uint64_t seed)
Definition: randgen.h:100
typename FromArc::Label Label
Definition: randgen.h:585
void InitVisit(const Fst< FromArc > &ifst)
Definition: randgen.h:686
constexpr int kNoStateId
Definition: fst.h:202
typename Arc::Weight Weight
Definition: randgen.h:207
const Arc & Value() const
Definition: fst.h:537
size_t nsamples
Definition: randgen.h:183
#define FSTERROR()
Definition: util.h:53
RandGenFst * Copy(bool safe=false) const override
Definition: randgen.h:605
size_t length
Definition: randgen.h:184
LogWeightTpl< double > Log64Weight
Definition: float-weight.h:468
uint64_t Seed() const
Definition: randgen.h:122
typename Arc::Weight Weight
Definition: randgen.h:70
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:108
ArcSampler(const ArcSampler< Arc, Selector > &sampler, const Fst< Arc > *fst=nullptr)
Definition: randgen.h:315
RandGenFstImpl(const RandGenFstImpl &impl)
Definition: randgen.h:458
void RandGen(const Fst< FromArc > &ifst, MutableFst< ToArc > *ofst, const RandGenOptions< Selector > &opts)
Definition: randgen.h:752
typename Arc::StateId StateId
Definition: randgen.h:180
typename FromArc::Weight Weight
Definition: randgen.h:682
bool Done() const
Definition: randgen.h:246
void Seek(size_t a)
Definition: fst.h:557
RandGenFst(const Fst< FromArc > &fst, const RandGenFstOptions< Sampler > &opts)
Definition: randgen.h:597
constexpr uint64_t kCopyProperties
Definition: properties.h:162
bool ForwardOrCrossArc(StateId, const FromArc &)
Definition: randgen.h:712
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
Definition: cache.h:666
void Reset()
Definition: randgen.h:253
bool TreeArc(StateId, const ToArc &arc)
Definition: randgen.h:697
bool Done() const
Definition: fst.h:533
std::unique_ptr< StateIteratorBase< Arc > > base
Definition: fst.h:382
std::pair< size_t, size_t > Value() const
Definition: randgen.h:251
typename Arc::Weight Weight
Definition: randgen.h:146
size_t operator()(const Fst< Arc > &fst, StateId s, CacheLogAccumulator< Arc > *accumulator) const
Definition: randgen.h:158
ArcIterator(const RandGenFst< FromArc, ToArc, Sampler > &fst, StateId s)
Definition: randgen.h:639
void InitArcIterator(StateId s, ArcIteratorData< ToArc > *data)
Definition: randgen.h:513
void InitStateIterator(StateIteratorData< ToArc > *data) const override
Definition: randgen.h:647
RandGenOptions(const Selector &selector, int32_t max_length=std::numeric_limits< int32_t >::max(), int32_t npath=1, bool weighted=false, bool remove_total_weight=false)
Definition: randgen.h:664
RandGenFstOptions(const CacheOptions &opts, Sampler *sampler, int32_t npath=1, bool weighted=true, bool remove_total_weight=false)
Definition: randgen.h:406
void SetState(StateId s, int depth=0)
Definition: accumulator.h:513
RandGenFst(const RandGenFst &fst, bool safe=false)
Definition: randgen.h:601
constexpr uint64_t kFstProperties
Definition: properties.h:325
typename FromArc::Label Label
Definition: randgen.h:435
virtual const SymbolTable * InputSymbols() const =0
Weight Sum(Weight w, Weight v)
Definition: accumulator.h:535
const SymbolTable * InputSymbols() const
Definition: fst.h:760
const RandState< Arc > * parent
Definition: randgen.h:186
ArcSampler(const Fst< Arc > &fst, const Selector &selector, int32_t max_length=std::numeric_limits< int32_t >::max())
Definition: randgen.h:211
typename FromArc::StateId StateId
Definition: randgen.h:681
ToWeight Final(StateId s)
Definition: randgen.h:482
void InitArcIterator(StateId s, ArcIteratorData< ToArc > *data) const override
Definition: randgen.h:611
Log64Weight ToLogWeight(const Weight &weight) const
Definition: randgen.h:125
const Selector & selector
Definition: randgen.h:657
uint64_t Properties() const override
Definition: randgen.h:502
uint64_t Properties(uint64_t mask) const override
Definition: randgen.h:505
size_t LowerBound(Weight w, ArcIter *aiter)
Definition: accumulator.h:572
typename Arc::StateId StateId
Definition: randgen.h:206
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
Definition: expand.h:316
bool Sample(const RandState< Arc > &rstate)
Definition: randgen.h:230
internal::RandGenFstImpl< FromArc, ToArc, Sampler > * GetMutableImpl() const
Definition: fst.h:1028
typename FromArc::Weight FromWeight
Definition: randgen.h:437
size_t NumArcs(StateId s)
Definition: randgen.h:487
size_t operator()(const Fst< Arc > &fst, StateId s) const
Definition: randgen.h:75
size_t select
Definition: randgen.h:185
ArcSampler(const Fst< Arc > &fst, const Selector &selector, int32_t max_length=std::numeric_limits< int32_t >::max())
Definition: randgen.h:305
const internal::RandGenFstImpl< FromArc, ToArc, Sampler > * GetImpl() const
Definition: fst.h:1026
bool remove_total_weight
Definition: randgen.h:662
typename Store::State State
Definition: randgen.h:590
virtual const SymbolTable * OutputSymbols() const =0
void Next()
Definition: fst.h:541