FST  openfst-1.8.2
OpenFst Library
randequivalent.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Tests if two FSTs are equivalent by checking if random strings from one FST
19 // are transduced the same by both FSTs.
20 
21 #ifndef FST_RANDEQUIVALENT_H_
22 #define FST_RANDEQUIVALENT_H_
23 
24 #include <cstdint>
25 #include <random>
26 
27 #include <fst/log.h>
28 
29 #include <fst/arcsort.h>
30 #include <fst/compose.h>
31 #include <fst/project.h>
32 #include <fst/randgen.h>
33 #include <fst/shortest-distance.h>
34 #include <fst/vector-fst.h>
35 #include <fst/weight.h>
36 
37 
38 namespace fst {
39 
40 // Test if two FSTs are stochastically equivalent by randomly generating
41 // random paths through the FSTs.
42 //
43 // For each randomly generated path, the algorithm computes for each
44 // of the two FSTs the sum of the weights of all the successful paths
45 // sharing the same input and output labels as the considered randomly
46 // generated path and checks that these two values are within a user-specified
47 // delta. Returns optional error value (when FST_FLAGS_error_fatal = false).
48 template <class Arc, class ArcSelector>
49 bool RandEquivalent(const Fst<Arc> &fst1, const Fst<Arc> &fst2, int32_t npath,
50  const RandGenOptions<ArcSelector> &opts,
51  float delta = kDelta,
52  uint64_t seed = std::random_device()(),
53  bool *error = nullptr) {
54  using Weight = typename Arc::Weight;
55  if (error) *error = false;
56  // Checks that the symbol table are compatible.
57  if (!CompatSymbols(fst1.InputSymbols(), fst2.InputSymbols()) ||
58  !CompatSymbols(fst1.OutputSymbols(), fst2.OutputSymbols())) {
59  FSTERROR() << "RandEquivalent: Input/output symbol tables of 1st "
60  << "argument do not match input/output symbol tables of 2nd "
61  << "argument";
62  if (error) *error = true;
63  return false;
64  }
65  static const ILabelCompare<Arc> icomp;
66  static const OLabelCompare<Arc> ocomp;
67  VectorFst<Arc> sfst1(fst1);
68  VectorFst<Arc> sfst2(fst2);
69  Connect(&sfst1);
70  Connect(&sfst2);
71  ArcSort(&sfst1, icomp);
72  ArcSort(&sfst2, icomp);
73  bool result = true;
74  std::mt19937 rand(seed);
75  std::bernoulli_distribution coin(.5);
76  for (int32_t n = 0; n < npath; ++n) {
77  VectorFst<Arc> path;
78  const auto &fst = coin(rand) ? sfst1 : sfst2;
79  RandGen(fst, &path, opts);
80  VectorFst<Arc> ipath(path);
81  VectorFst<Arc> opath(path);
82  Project(&ipath, ProjectType::INPUT);
84  VectorFst<Arc> cfst1, pfst1;
85  Compose(ipath, sfst1, &cfst1);
86  ArcSort(&cfst1, ocomp);
87  Compose(cfst1, opath, &pfst1);
88  // Gives up if there are epsilon cycles in a non-idempotent semiring.
89  if (!IsIdempotent<Weight>::value && pfst1.Properties(kCyclic, true)) {
90  continue;
91  }
92  const auto sum1 = ShortestDistance(pfst1);
93  VectorFst<Arc> cfst2;
94  Compose(ipath, sfst2, &cfst2);
95  ArcSort(&cfst2, ocomp);
96  VectorFst<Arc> pfst2;
97  Compose(cfst2, opath, &pfst2);
98  // Gives up if there are epsilon cycles in a non-idempotent semiring.
99  if (!IsIdempotent<Weight>::value && pfst2.Properties(kCyclic, true)) {
100  continue;
101  }
102  const auto sum2 = ShortestDistance(pfst2);
103  if (!ApproxEqual(sum1, sum2, delta)) {
104  VLOG(1) << "Sum1 = " << sum1;
105  VLOG(1) << "Sum2 = " << sum2;
106  result = false;
107  break;
108  }
109  }
110  if (fst1.Properties(kError, false) || fst2.Properties(kError, false)) {
111  if (error) *error = true;
112  return false;
113  }
114  return result;
115 }
116 
117 // Tests if two FSTs are equivalent by randomly generating a nnpath paths
118 // (no longer than the path_length) using a user-specified seed, optionally
119 // indicating an error setting an optional error argument to true.
120 template <class Arc>
121 bool RandEquivalent(const Fst<Arc> &fst1, const Fst<Arc> &fst2, int32_t npath,
122  float delta = kDelta,
123  uint64_t seed = std::random_device()(),
124  int32_t max_length = std::numeric_limits<int32_t>::max(),
125  bool *error = nullptr) {
126  const UniformArcSelector<Arc> uniform_selector(seed);
127  const RandGenOptions<UniformArcSelector<Arc>> opts(uniform_selector,
128  max_length);
129  return RandEquivalent(fst1, fst2, npath, opts, delta, seed, error);
130 }
131 
132 } // namespace fst
133 
134 #endif // FST_RANDEQUIVALENT_H_
constexpr uint64_t kCyclic
Definition: properties.h:108
bool RandEquivalent(const Fst< Arc > &fst1, const Fst< Arc > &fst2, int32_t npath, const RandGenOptions< ArcSelector > &opts, float delta=kDelta, uint64_t seed=std::random_device()(), bool *error=nullptr)
virtual uint64_t Properties(uint64_t mask, bool test) const =0
constexpr uint64_t kError
Definition: properties.h:51
void Connect(MutableFst< Arc > *fst)
Definition: connect.h:278
std::bool_constant<(W::Properties()&kIdempotent)!=0 > IsIdempotent
Definition: weight.h:156
#define FSTERROR()
Definition: util.h:53
void RandGen(const Fst< FromArc > &ifst, MutableFst< ToArc > *ofst, const RandGenOptions< Selector > &opts)
Definition: randgen.h:752
void ArcSort(MutableFst< Arc > *fst, Compare comp)
Definition: arcsort.h:102
void Compose(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const ComposeOptions &opts=ComposeOptions())
Definition: compose.h:995
#define VLOG(level)
Definition: log.h:50
void Project(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, ProjectType project_type)
Definition: project.h:84
virtual const SymbolTable * InputSymbols() const =0
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
constexpr float kDelta
Definition: weight.h:130
bool ApproxEqual(const ErrorWeight &, const ErrorWeight &, float)
Definition: error-weight.h:57
uint64_t Properties(uint64_t mask, bool test) const override
Definition: fst.h:966
virtual const SymbolTable * OutputSymbols() const =0