FST  openfst-1.8.2
OpenFst Library
equal.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Function to test equality of two FSTs.
19 
20 #ifndef FST_EQUAL_H_
21 #define FST_EQUAL_H_
22 
23 #include <cstdint>
24 
25 #include <fst/log.h>
26 
27 #include <fst/fst.h>
28 #include <fst/test-properties.h>
29 
30 
31 namespace fst {
32 
// Bit flags selecting which checks Equal() performs; combine with bitwise OR.

// Compare the states and arcs themselves, in order.
inline constexpr uint8_t kEqualFsts = 1 << 0;
// Require both FSTs to report the same Type().
inline constexpr uint8_t kEqualFstTypes = 1 << 1;
// Require the stored properties of both FSTs to be compatible.
inline constexpr uint8_t kEqualCompatProperties = 1 << 2;
// Require the input and output symbol tables to be compatible.
inline constexpr uint8_t kEqualCompatSymbols = 1 << 3;
// All of the checks above.
inline constexpr uint8_t kEqualAll = kEqualFsts | kEqualFstTypes |
                                     kEqualCompatProperties |
                                     kEqualCompatSymbols;
39 
// Weight comparison functor: two weights compare equal when ApproxEqual()
// accepts them with the tolerance supplied at construction.
class WeightApproxEqual {
 public:
  // delta: tolerance forwarded to every ApproxEqual() call.
  explicit WeightApproxEqual(float delta) : delta_(delta) {}

  // We use two weight types to avoid some conflicts caused by
  // conversions.
  template <class Weight1, class Weight2>
  bool operator()(const Weight1 &w1, const Weight2 &w2) const {
    return ApproxEqual(w1, w2, delta_);
  }

 private:
  const float delta_;
};
54 
55 // Tests if two FSTs have the same states and arcs in the same order (when
56 // etype & kEqualFst); optionally, also checks equality of FST types
57 // (etype & kEqualFstTypes) and compatibility of stored properties
58 // (etype & kEqualCompatProperties) and of symbol tables
59 // (etype & kEqualCompatSymbols).
60 template <class Arc, class WeightEqual>
61 bool Equal(const Fst<Arc> &fst1, const Fst<Arc> &fst2, WeightEqual weight_equal,
62  uint8_t etype = kEqualFsts) {
63  if ((etype & kEqualFstTypes) && (fst1.Type() != fst2.Type())) {
64  VLOG(1) << "Equal: Mismatched FST types (" << fst1.Type()
65  << " != " << fst2.Type() << ")";
66  return false;
67  }
68  if ((etype & kEqualCompatProperties) &&
70  fst2.Properties(kCopyProperties, false))) {
71  VLOG(1) << "Equal: Properties not compatible";
72  return false;
73  }
74  if (etype & kEqualCompatSymbols) {
75  if (!CompatSymbols(fst1.InputSymbols(), fst2.InputSymbols(), false)) {
76  VLOG(1) << "Equal: Input symbols not compatible";
77  return false;
78  }
79  if (!CompatSymbols(fst1.OutputSymbols(), fst2.OutputSymbols(), false)) {
80  VLOG(1) << "Equal: Output symbols not compatible";
81  return false;
82  }
83  }
84  if (!(etype & kEqualFsts)) return true;
85  if (fst1.Start() != fst2.Start()) {
86  VLOG(1) << "Equal: Mismatched start states (" << fst1.Start()
87  << " != " << fst2.Start() << ")";
88  return false;
89  }
90  StateIterator<Fst<Arc>> siter1(fst1);
91  StateIterator<Fst<Arc>> siter2(fst2);
92  while (!siter1.Done() || !siter2.Done()) {
93  if (siter1.Done() || siter2.Done()) {
94  VLOG(1) << "Equal: Mismatched number of states";
95  return false;
96  }
97  const auto s1 = siter1.Value();
98  const auto s2 = siter2.Value();
99  if (s1 != s2) {
100  VLOG(1) << "Equal: Mismatched states (" << s1 << "!= " << s2 << ")";
101  return false;
102  }
103  const auto &final1 = fst1.Final(s1);
104  const auto &final2 = fst2.Final(s2);
105  if (!weight_equal(final1, final2)) {
106  VLOG(1) << "Equal: Mismatched final weights at state " << s1 << " ("
107  << final1 << " != " << final2 << ")";
108  return false;
109  }
110  ArcIterator<Fst<Arc>> aiter1(fst1, s1);
111  ArcIterator<Fst<Arc>> aiter2(fst2, s2);
112  for (auto a = 0; !aiter1.Done() || !aiter2.Done(); ++a) {
113  if (aiter1.Done() || aiter2.Done()) {
114  VLOG(1) << "Equal: Mismatched number of arcs at state " << s1;
115  return false;
116  }
117  const auto &arc1 = aiter1.Value();
118  const auto &arc2 = aiter2.Value();
119  if (arc1.ilabel != arc2.ilabel) {
120  VLOG(1) << "Equal: Mismatched arc input labels at state " << s1
121  << ", arc " << a << " (" << arc1.ilabel << " != " << arc2.ilabel
122  << ")";
123  return false;
124  } else if (arc1.olabel != arc2.olabel) {
125  VLOG(1) << "Equal: Mismatched arc output labels at state " << s1
126  << ", arc " << a << " (" << arc1.olabel << " != " << arc2.olabel
127  << ")";
128  return false;
129  } else if (!weight_equal(arc1.weight, arc2.weight)) {
130  VLOG(1) << "Equal: Mismatched arc weights at state " << s1 << ", arc "
131  << a << " (" << arc1.weight << " != " << arc2.weight << ")";
132  return false;
133  } else if (arc1.nextstate != arc2.nextstate) {
134  VLOG(1) << "Equal: Mismatched next state at state " << s1 << ", arc "
135  << a << " (" << arc1.nextstate << " != " << arc2.nextstate
136  << ")";
137  return false;
138  }
139  aiter1.Next();
140  aiter2.Next();
141  }
142  // Sanity checks: should never fail.
143  if (fst1.NumArcs(s1) != fst2.NumArcs(s2)) {
144  FSTERROR() << "Equal: Inconsistent arc counts at state " << s1 << " ("
145  << fst1.NumArcs(s1) << " != " << fst2.NumArcs(s2) << ")";
146  return false;
147  }
148  if (fst1.NumInputEpsilons(s1) != fst2.NumInputEpsilons(s2)) {
149  FSTERROR() << "Equal: Inconsistent input epsilon counts at state " << s1
150  << " (" << fst1.NumInputEpsilons(s1)
151  << " != " << fst2.NumInputEpsilons(s2) << ")";
152  return false;
153  }
154  if (fst1.NumOutputEpsilons(s1) != fst2.NumOutputEpsilons(s2)) {
155  FSTERROR() << "Equal: Inconsistent output epsilon counts at state " << s1
156  << " (" << fst1.NumOutputEpsilons(s1)
157  << " != " << fst2.NumOutputEpsilons(s2) << ")";
158  }
159  siter1.Next();
160  siter2.Next();
161  }
162  return true;
163 }
164 
165 template <class Arc>
166 bool Equal(const Fst<Arc> &fst1, const Fst<Arc> &fst2, float delta = kDelta,
167  uint8_t etype = kEqualFsts) {
168  return Equal(fst1, fst2, WeightApproxEqual(delta), etype);
169 }
170 
171 // Support double deltas without forcing all clients to cast to float.
172 // Without this overload, Equal<Arc, WeightEqual=double> will be chosen,
173 // since it is a better match than double -> float narrowing, but
174 // the instantiation will fail.
175 template <class Arc>
176 bool Equal(const Fst<Arc> &fst1, const Fst<Arc> &fst2, double delta,
177  uint8_t etype = kEqualFsts) {
178  return Equal(fst1, fst2, WeightApproxEqual(static_cast<float>(delta)), etype);
179 }
180 
181 } // namespace fst
182 
183 #endif // FST_EQUAL_H_
virtual uint64_t Properties(uint64_t mask, bool test) const =0
virtual size_t NumArcs(StateId) const =0
bool CompatProperties(uint64_t props1, uint64_t props2)
Definition: properties.h:506
constexpr uint8_t kEqualCompatProperties
Definition: equal.h:35
virtual Weight Final(StateId) const =0
WeightApproxEqual(float delta)
Definition: equal.h:42
constexpr uint8_t kEqualFstTypes
Definition: equal.h:34
const Arc & Value() const
Definition: fst.h:537
virtual size_t NumInputEpsilons(StateId) const =0
#define FSTERROR()
Definition: util.h:53
virtual const std::string & Type() const =0
constexpr uint8_t kEqualAll
Definition: equal.h:37
constexpr uint64_t kCopyProperties
Definition: properties.h:162
#define VLOG(level)
Definition: log.h:50
virtual StateId Start() const =0
StateId Value() const
Definition: fst.h:419
bool Done() const
Definition: fst.h:533
bool Equal(const Fst< Arc > &fst1, const Fst< Arc > &fst2, WeightEqual weight_equal, uint8_t etype=kEqualFsts)
Definition: equal.h:61
virtual const SymbolTable * InputSymbols() const =0
void Next()
Definition: fst.h:421
bool Done() const
Definition: fst.h:415
bool operator()(const Weight1 &w1, const Weight2 &w2) const
Definition: equal.h:47
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
constexpr uint8_t kEqualFsts
Definition: equal.h:33
constexpr uint8_t kEqualCompatSymbols
Definition: equal.h:36
virtual size_t NumOutputEpsilons(StateId) const =0
constexpr float kDelta
Definition: weight.h:130
bool ApproxEqual(const ErrorWeight &, const ErrorWeight &, float)
Definition: error-weight.h:57
virtual const SymbolTable * OutputSymbols() const =0
void Next()
Definition: fst.h:541