FST  openfst-1.8.3
OpenFst Library
randequivalent.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Tests if two FSTs are equivalent by checking if random strings from one FST
19 // are transduced the same by both FSTs.
20 
21 #ifndef FST_RANDEQUIVALENT_H_
22 #define FST_RANDEQUIVALENT_H_
23 
24 #include <cstdint>
25 #include <limits>
26 #include <random>
27 
28 #include <fst/log.h>
29 #include <fst/arcsort.h>
30 #include <fst/compose.h>
31 #include <fst/connect.h>
32 #include <fst/fst.h>
33 #include <fst/project.h>
34 #include <fst/properties.h>
35 #include <fst/randgen.h>
36 #include <fst/shortest-distance.h>
37 #include <fst/symbol-table.h>
38 #include <fst/util.h>
39 #include <fst/vector-fst.h>
40 #include <fst/weight.h>
41 
42 namespace fst {
43 
44 // Test if two FSTs are stochastically equivalent by randomly generating
45 // random paths through the FSTs.
46 //
47 // For each randomly generated path, the algorithm computes for each
48 // of the two FSTs the sum of the weights of all the successful paths
49 // sharing the same input and output labels as the considered randomly
50 // generated path and checks that these two values are within a user-specified
51 // delta. Returns optional error value (when FST_FLAGS_error_fatal = false).
52 template <class Arc, class ArcSelector>
53 bool RandEquivalent(const Fst<Arc> &fst1, const Fst<Arc> &fst2, int32_t npath,
54  const RandGenOptions<ArcSelector> &opts,
55  float delta = kDelta,
56  uint64_t seed = std::random_device()(),
57  bool *error = nullptr) {
58  using Weight = typename Arc::Weight;
59  if (error) *error = false;
60  // Checks that the symbol table are compatible.
61  if (!CompatSymbols(fst1.InputSymbols(), fst2.InputSymbols()) ||
62  !CompatSymbols(fst1.OutputSymbols(), fst2.OutputSymbols())) {
63  FSTERROR() << "RandEquivalent: Input/output symbol tables of 1st "
64  << "argument do not match input/output symbol tables of 2nd "
65  << "argument";
66  if (error) *error = true;
67  return false;
68  }
69  static const ILabelCompare<Arc> icomp;
70  static const OLabelCompare<Arc> ocomp;
71  VectorFst<Arc> sfst1(fst1);
72  VectorFst<Arc> sfst2(fst2);
73  Connect(&sfst1);
74  Connect(&sfst2);
75  ArcSort(&sfst1, icomp);
76  ArcSort(&sfst2, icomp);
77  bool result = true;
78  std::mt19937 rand(seed);
79  std::bernoulli_distribution coin(.5);
80  for (int32_t n = 0; n < npath; ++n) {
81  VectorFst<Arc> path;
82  const auto &fst = coin(rand) ? sfst1 : sfst2;
83  RandGen(fst, &path, opts);
84  VectorFst<Arc> ipath(path);
85  VectorFst<Arc> opath(path);
86  Project(&ipath, ProjectType::INPUT);
88  VectorFst<Arc> cfst1, pfst1;
89  Compose(ipath, sfst1, &cfst1);
90  ArcSort(&cfst1, ocomp);
91  Compose(cfst1, opath, &pfst1);
92  // Gives up if there are epsilon cycles in a non-idempotent semiring.
93  if (!IsIdempotent<Weight>::value && pfst1.Properties(kCyclic, true)) {
94  continue;
95  }
96  const auto sum1 = ShortestDistance(pfst1);
97  VectorFst<Arc> cfst2;
98  Compose(ipath, sfst2, &cfst2);
99  ArcSort(&cfst2, ocomp);
100  VectorFst<Arc> pfst2;
101  Compose(cfst2, opath, &pfst2);
102  // Gives up if there are epsilon cycles in a non-idempotent semiring.
103  if (!IsIdempotent<Weight>::value && pfst2.Properties(kCyclic, true)) {
104  continue;
105  }
106  const auto sum2 = ShortestDistance(pfst2);
107  if (!ApproxEqual(sum1, sum2, delta)) {
108  VLOG(1) << "Sum1 = " << sum1;
109  VLOG(1) << "Sum2 = " << sum2;
110  result = false;
111  break;
112  }
113  }
114  if (fst1.Properties(kError, false) || fst2.Properties(kError, false)) {
115  if (error) *error = true;
116  return false;
117  }
118  return result;
119 }
120 
121 // Tests if two FSTs are equivalent by randomly generating a nnpath paths
122 // (no longer than the path_length) using a user-specified seed, optionally
123 // indicating an error setting an optional error argument to true.
124 template <class Arc>
125 bool RandEquivalent(const Fst<Arc> &fst1, const Fst<Arc> &fst2, int32_t npath,
126  float delta = kDelta,
127  uint64_t seed = std::random_device()(),
128  int32_t max_length = std::numeric_limits<int32_t>::max(),
129  bool *error = nullptr) {
130  const UniformArcSelector<Arc> uniform_selector(seed);
131  const RandGenOptions<UniformArcSelector<Arc>> opts(uniform_selector,
132  max_length);
133  return RandEquivalent(fst1, fst2, npath, opts, delta, seed, error);
134 }
135 
136 } // namespace fst
137 
138 #endif // FST_RANDEQUIVALENT_H_
constexpr uint64_t kCyclic
Definition: properties.h:109
bool RandEquivalent(const Fst< Arc > &fst1, const Fst< Arc > &fst2, int32_t npath, const RandGenOptions< ArcSelector > &opts, float delta=kDelta, uint64_t seed=std::random_device()(), bool *error=nullptr)
virtual uint64_t Properties(uint64_t mask, bool test) const =0
constexpr uint64_t kError
Definition: properties.h:52
void Connect(MutableFst< Arc > *fst)
Definition: connect.h:47
std::bool_constant<(W::Properties()&kIdempotent)!=0 > IsIdempotent
Definition: weight.h:159
#define FSTERROR()
Definition: util.h:56
void RandGen(const Fst< FromArc > &ifst, MutableFst< ToArc > *ofst, const RandGenOptions< Selector > &opts)
Definition: randgen.h:751
void ArcSort(MutableFst< Arc > *fst, Compare comp)
Definition: arcsort.h:109
void Compose(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const ComposeOptions &opts=ComposeOptions())
Definition: compose.h:1005
#define VLOG(level)
Definition: log.h:54
void Project(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, ProjectType project_type)
Definition: project.h:89
virtual const SymbolTable * InputSymbols() const =0
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
constexpr float kDelta
Definition: weight.h:133
bool ApproxEqual(const ErrorWeight &, const ErrorWeight &, float)
Definition: error-weight.h:58
uint64_t Properties(uint64_t mask, bool test) const override
Definition: impl-to-fst.h:63
virtual const SymbolTable * OutputSymbols() const =0