FST  openfst-1.7.1
OpenFst Library
equal.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Function to test equality of two FSTs.
5 
6 #ifndef FST_EQUAL_H_
7 #define FST_EQUAL_H_
8 
9 #include <fst/log.h>
10 
11 #include <fst/fst.h>
12 #include <fst/test-properties.h>
13 
14 
15 namespace fst {
16 
17 constexpr uint32 kEqualFsts = 0x0001;
18 constexpr uint32 kEqualFstTypes = 0x0002;
19 constexpr uint32 kEqualCompatProperties = 0x0004;
20 constexpr uint32 kEqualCompatSymbols = 0x0008;
21 constexpr uint32 kEqualAll =
22  kEqualFsts | kEqualFstTypes | kEqualCompatProperties | kEqualCompatSymbols;
23 
// Weight-comparison functor for Equal(): treats two weights as equal when
// ApproxEqual(w1, w2, delta) holds for the tolerance given at construction.
class WeightApproxEqual {
 public:
  // delta: tolerance forwarded unchanged to ApproxEqual().
  explicit WeightApproxEqual(float delta) : delta_(delta) {}

  // Returns ApproxEqual(w1, w2, delta_). Which ApproxEqual overload is used
  // depends on the Weight type (resolved at instantiation).
  template <class Weight>
  bool operator()(const Weight &w1, const Weight &w2) const {
    return ApproxEqual(w1, w2, delta_);
  }

 private:
  float delta_;
};
36 
37 // Tests if two Fsts have the same states and arcs in the same order (when
38 // etype & kEqualFst).
39 // Also optional checks equality of Fst types (etype & kEqualFstTypes) and
40 // compatibility of stored properties (etype & kEqualCompatProperties) and
41 // of symbol tables (etype & kEqualCompatSymbols).
42 template <class Arc, class WeightEqual>
43 bool Equal(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
44  WeightEqual weight_equal, uint32 etype = kEqualFsts) {
45  if ((etype & kEqualFstTypes) && (fst1.Type() != fst2.Type())) {
46  VLOG(1) << "Equal: Mismatched FST types (" << fst1.Type() << " != "
47  << fst2.Type() << ")";
48  return false;
49  }
50  if ((etype & kEqualCompatProperties) &&
52  fst2.Properties(kCopyProperties, false))) {
53  VLOG(1) << "Equal: Properties not compatible";
54  return false;
55  }
56  if (etype & kEqualCompatSymbols) {
57  if (!CompatSymbols(fst1.InputSymbols(), fst2.InputSymbols(), false)) {
58  VLOG(1) << "Equal: Input symbols not compatible";
59  return false;
60  }
61  if (!CompatSymbols(fst1.OutputSymbols(), fst2.OutputSymbols(), false)) {
62  VLOG(1) << "Equal: Output symbols not compatible";
63  return false;
64  }
65  }
66  if (!(etype & kEqualFsts)) return true;
67  if (fst1.Start() != fst2.Start()) {
68  VLOG(1) << "Equal: Mismatched start states (" << fst1.Start() << " != "
69  << fst2.Start() << ")";
70  return false;
71  }
72  StateIterator<Fst<Arc>> siter1(fst1);
73  StateIterator<Fst<Arc>> siter2(fst2);
74  while (!siter1.Done() || !siter2.Done()) {
75  if (siter1.Done() || siter2.Done()) {
76  VLOG(1) << "Equal: Mismatched number of states";
77  return false;
78  }
79  const auto s1 = siter1.Value();
80  const auto s2 = siter2.Value();
81  if (s1 != s2) {
82  VLOG(1) << "Equal: Mismatched states (" << s1 << "!= "
83  << s2 << ")";
84  return false;
85  }
86  const auto &final1 = fst1.Final(s1);
87  const auto &final2 = fst2.Final(s2);
88  if (!weight_equal(final1, final2)) {
89  VLOG(1) << "Equal: Mismatched final weights at state " << s1
90  << " (" << final1 << " != " << final2 << ")";
91  return false;
92  }
93  ArcIterator<Fst<Arc>> aiter1(fst1, s1);
94  ArcIterator<Fst<Arc>> aiter2(fst2, s2);
95  for (auto a = 0; !aiter1.Done() || !aiter2.Done(); ++a) {
96  if (aiter1.Done() || aiter2.Done()) {
97  VLOG(1) << "Equal: Mismatched number of arcs at state " << s1;
98  return false;
99  }
100  const auto &arc1 = aiter1.Value();
101  const auto &arc2 = aiter2.Value();
102  if (arc1.ilabel != arc2.ilabel) {
103  VLOG(1) << "Equal: Mismatched arc input labels at state " << s1
104  << ", arc " << a << " (" << arc1.ilabel << " != "
105  << arc2.ilabel << ")";
106  return false;
107  } else if (arc1.olabel != arc2.olabel) {
108  VLOG(1) << "Equal: Mismatched arc output labels at state " << s1
109  << ", arc " << a << " (" << arc1.olabel << " != "
110  << arc2.olabel << ")";
111  return false;
112  } else if (!weight_equal(arc1.weight, arc2.weight)) {
113  VLOG(1) << "Equal: Mismatched arc weights at state " << s1
114  << ", arc " << a << " (" << arc1.weight << " != "
115  << arc2.weight << ")";
116  return false;
117  } else if (arc1.nextstate != arc2.nextstate) {
118  VLOG(1) << "Equal: Mismatched next state at state " << s1
119  << ", arc " << a << " (" << arc1.nextstate << " != "
120  << arc2.nextstate << ")";
121  return false;
122  }
123  aiter1.Next();
124  aiter2.Next();
125  }
126  // Sanity checks: should never fail.
127  if (fst1.NumArcs(s1) != fst2.NumArcs(s2)) {
128  FSTERROR() << "Equal: Inconsistent arc counts at state " << s1
129  << " (" << fst1.NumArcs(s1) << " != "
130  << fst2.NumArcs(s2) << ")";
131  return false;
132  }
133  if (fst1.NumInputEpsilons(s1) != fst2.NumInputEpsilons(s2)) {
134  FSTERROR() << "Equal: Inconsistent input epsilon counts at state " << s1
135  << " (" << fst1.NumInputEpsilons(s1) << " != "
136  << fst2.NumInputEpsilons(s2) << ")";
137  return false;
138  }
139  if (fst1.NumOutputEpsilons(s1) != fst2.NumOutputEpsilons(s2)) {
140  FSTERROR() << "Equal: Inconsistent output epsilon counts at state " << s1
141  << " (" << fst1.NumOutputEpsilons(s1) << " != "
142  << fst2.NumOutputEpsilons(s2) << ")";
143  }
144  siter1.Next();
145  siter2.Next();
146  }
147  return true;
148 }
149 
150 template <class Arc>
151 bool Equal(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
152  float delta = kDelta, uint32 etype = kEqualFsts) {
153  return Equal(fst1, fst2, WeightApproxEqual(delta), etype);
154 }
155 
156 // Support double deltas without forcing all clients to cast to float.
157 // Without this overload, Equal<Arc, WeightEqual=double> will be chosen,
158 // since it is a better match than double -> float narrowing, but
159 // the instantiation will fail.
160 template <class Arc>
161 bool Equal(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
162  double delta, uint32 etype = kEqualFsts) {
163  return Equal(fst1, fst2, WeightApproxEqual(static_cast<float>(delta)), etype);
164 }
165 
166 
167 } // namespace fst
168 
169 #endif // FST_EQUAL_H_
constexpr uint32 kEqualAll
Definition: equal.h:21
constexpr uint32 kEqualFstTypes
Definition: equal.h:18
bool operator()(const Weight &w1, const Weight &w2) const
Definition: equal.h:29
virtual size_t NumArcs(StateId) const =0
constexpr uint32 kEqualCompatProperties
Definition: equal.h:19
virtual Weight Final(StateId) const =0
WeightApproxEqual(float delta)
Definition: equal.h:26
bool CompatProperties(uint64 props1, uint64 props2)
constexpr uint32 kEqualCompatSymbols
Definition: equal.h:20
constexpr uint64 kCopyProperties
Definition: properties.h:138
const Arc & Value() const
Definition: fst.h:503
virtual uint64 Properties(uint64 mask, bool test) const =0
virtual size_t NumInputEpsilons(StateId) const =0
#define FSTERROR()
Definition: util.h:35
#define VLOG(level)
Definition: log.h:49
virtual StateId Start() const =0
StateId Value() const
Definition: fst.h:387
bool Done() const
Definition: fst.h:499
virtual const string & Type() const =0
uint32_t uint32
Definition: types.h:31
virtual const SymbolTable * InputSymbols() const =0
void Next()
Definition: fst.h:389
bool Done() const
Definition: fst.h:383
constexpr bool ApproxEqual(const FloatWeightTpl< T > &w1, const FloatWeightTpl< T > &w2, float delta=kDelta)
Definition: float-weight.h:140
bool Equal(const Fst< Arc > &fst1, const Fst< Arc > &fst2, WeightEqual weight_equal, uint32 etype=kEqualFsts)
Definition: equal.h:43
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
constexpr uint32 kEqualFsts
Definition: equal.h:17
virtual size_t NumOutputEpsilons(StateId) const =0
constexpr float kDelta
Definition: weight.h:109
virtual const SymbolTable * OutputSymbols() const =0
void Next()
Definition: fst.h:507