FST  openfst-1.7.2
OpenFst Library
randequivalent.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Tests if two FSTs are equivalent by checking if random strings from one FST
5 // are transduced the same by both FSTs.
6 
7 #ifndef FST_RANDEQUIVALENT_H_
8 #define FST_RANDEQUIVALENT_H_
9 
10 #include <fst/log.h>
11 
12 #include <fst/arcsort.h>
13 #include <fst/compose.h>
14 #include <fst/project.h>
15 #include <fst/randgen.h>
16 #include <fst/shortest-distance.h>
17 #include <fst/vector-fst.h>
18 
19 
20 namespace fst {
21 
22 // Test if two FSTs are stochastically equivalent by randomly generating
23 // random paths through the FSTs.
24 //
25 // For each randomly generated path, the algorithm computes for each
26 // of the two FSTs the sum of the weights of all the successful paths
27 // sharing the same input and output labels as the considered randomly
28 // generated path and checks that these two values are within a user-specified
29 // delta. Returns optional error value (when FLAGS_error_fatal = false).
30 template <class Arc, class ArcSelector>
31 bool RandEquivalent(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
32  int32 num_paths, float delta,
33  const RandGenOptions<ArcSelector> &opts,
34  bool *error = nullptr) {
35  using Weight = typename Arc::Weight;
36  if (error) *error = false;
37  // Checks that the symbol table are compatible.
38  if (!CompatSymbols(fst1.InputSymbols(), fst2.InputSymbols()) ||
39  !CompatSymbols(fst1.OutputSymbols(), fst2.OutputSymbols())) {
40  FSTERROR() << "RandEquivalent: Input/output symbol tables of 1st "
41  << "argument do not match input/output symbol tables of 2nd "
42  << "argument";
43  if (error) *error = true;
44  return false;
45  }
46  static const ILabelCompare<Arc> icomp;
47  static const OLabelCompare<Arc> ocomp;
48  VectorFst<Arc> sfst1(fst1);
49  VectorFst<Arc> sfst2(fst2);
50  Connect(&sfst1);
51  Connect(&sfst2);
52  ArcSort(&sfst1, icomp);
53  ArcSort(&sfst2, icomp);
54  bool result = true;
55  for (int32 n = 0; n < num_paths; ++n) {
56  VectorFst<Arc> path;
57  const auto &fst = rand() % 2 ? sfst1 : sfst2; // NOLINT
58  RandGen(fst, &path, opts);
59  VectorFst<Arc> ipath(path);
60  VectorFst<Arc> opath(path);
61  Project(&ipath, PROJECT_INPUT);
62  Project(&opath, PROJECT_OUTPUT);
63  VectorFst<Arc> cfst1, pfst1;
64  Compose(ipath, sfst1, &cfst1);
65  ArcSort(&cfst1, ocomp);
66  Compose(cfst1, opath, &pfst1);
67  // Gives up if there are epsilon cycles in a non-idempotent semiring.
68  if (!(Weight::Properties() & kIdempotent) &&
69  pfst1.Properties(kCyclic, true)) {
70  continue;
71  }
72  const auto sum1 = ShortestDistance(pfst1);
73  VectorFst<Arc> cfst2;
74  Compose(ipath, sfst2, &cfst2);
75  ArcSort(&cfst2, ocomp);
76  VectorFst<Arc> pfst2;
77  Compose(cfst2, opath, &pfst2);
78  // Gives up if there are epsilon cycles in a non-idempotent semiring.
79  if (!(Weight::Properties() & kIdempotent) &&
80  pfst2.Properties(kCyclic, true)) {
81  continue;
82  }
83  const auto sum2 = ShortestDistance(pfst2);
84  if (!ApproxEqual(sum1, sum2, delta)) {
85  VLOG(1) << "Sum1 = " << sum1;
86  VLOG(1) << "Sum2 = " << sum2;
87  result = false;
88  break;
89  }
90  }
91  if (fst1.Properties(kError, false) || fst2.Properties(kError, false)) {
92  if (error) *error = true;
93  return false;
94  }
95  return result;
96 }
97 
98 // Tests if two FSTs are equivalent by randomly generating a nnum_paths paths
99 // (no longer than the path_length) using a user-specified seed, optionally
100 // indicating an error setting an optional error argument to true.
101 template <class Arc>
102 bool RandEquivalent(const Fst<Arc> &fst1, const Fst<Arc> &fst2, int32 num_paths,
103  float delta = kDelta, time_t seed = time(nullptr),
104  int32 max_length = std::numeric_limits<int32>::max(),
105  bool *error = nullptr) {
106  const UniformArcSelector<Arc> uniform_selector(seed);
107  const RandGenOptions<UniformArcSelector<Arc>> opts(uniform_selector,
108  max_length);
109  return RandEquivalent(fst1, fst2, num_paths, delta, opts, error);
110 }
111 
112 } // namespace fst
113 
114 #endif // FST_RANDEQUIVALENT_H_
bool RandEquivalent(const Fst< Arc > &fst1, const Fst< Arc > &fst2, int32 num_paths, float delta, const RandGenOptions< ArcSelector > &opts, bool *error=nullptr)
void Connect(MutableFst< Arc > *fst)
Definition: connect.h:268
virtual uint64 Properties(uint64 mask, bool test) const =0
#define FSTERROR()
Definition: util.h:35
void RandGen(const Fst< FromArc > &ifst, MutableFst< ToArc > *ofst, const RandGenOptions< Selector > &opts)
Definition: randgen.h:730
void ArcSort(MutableFst< Arc > *fst, Compare comp)
Definition: arcsort.h:87
void Compose(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const ComposeOptions &opts=ComposeOptions())
Definition: compose.h:981
constexpr uint64 kIdempotent
Definition: weight.h:123
#define VLOG(level)
Definition: log.h:49
void Project(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, ProjectType project_type)
Definition: project.h:65
virtual const SymbolTable * InputSymbols() const =0
int32_t int32
Definition: types.h:26
constexpr bool ApproxEqual(const FloatWeightTpl< T > &w1, const FloatWeightTpl< T > &w2, float delta=kDelta)
Definition: float-weight.h:140
constexpr uint64 kError
Definition: properties.h:33
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
uint64 Properties(uint64 mask, bool test) const override
Definition: fst.h:889
constexpr float kDelta
Definition: weight.h:109
constexpr uint64 kCyclic
Definition: properties.h:90
virtual const SymbolTable * OutputSymbols() const =0