FST  openfst-1.8.3
OpenFst Library
isomorphic.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
// Function to test whether two FSTs are isomorphic, i.e., whether they are
// equal up to a reordering of states and arcs. The FSTs should be deterministic
// when viewed as unweighted automata. False negatives (but not false positives)
// are possible when the inputs are nondeterministic (when viewed as unweighted
// automata).
22 
23 #ifndef FST_ISOMORPHIC_H_
24 #define FST_ISOMORPHIC_H_
25 
26 #include <algorithm>
27 #include <cstddef>
28 #include <memory>
29 #include <queue>
30 #include <type_traits>
31 #include <utility>
32 #include <vector>
33 
34 #include <fst/log.h>
35 #include <fst/fst.h>
36 #include <fst/util.h>
37 #include <fst/weight.h>
38 
39 namespace fst {
40 namespace internal {
41 
42 // Orders weights for equality checking; delta is ignored.
43 template <class Weight,
44  typename std::enable_if_t<IsIdempotent<Weight>::value> * = nullptr>
45 bool WeightCompare(const Weight &w1, const Weight &w2, float, bool *) {
46  static const NaturalLess<Weight> less;
47  return less(w1, w2);
48 }
49 
50 template <class Weight,
51  typename std::enable_if_t<!IsIdempotent<Weight>::value> * = nullptr>
52 bool WeightCompare(const Weight &w1, const Weight &w2, float delta,
53  bool *error) {
54  // No natural order; use hash.
55  const auto q1 = w1.Quantize(delta);
56  const auto q2 = w2.Quantize(delta);
57  const auto n1 = q1.Hash();
58  const auto n2 = q2.Hash();
59  // Hash not unique; very unlikely to happen.
60  if (n1 == n2 && q1 != q2) {
61  VLOG(1) << "Isomorphic: Weight hash collision";
62  *error = true;
63  }
64  return n1 < n2;
65 }
66 
67 template <class Arc>
68 class Isomorphism {
69  using StateId = typename Arc::StateId;
70 
71  public:
72  Isomorphism(const Fst<Arc> &fst1, const Fst<Arc> &fst2, float delta)
73  : fst1_(fst1.Copy()),
74  fst2_(fst2.Copy()),
75  delta_(delta),
76  error_(false),
77  nondet_(false),
78  comp_(delta, &error_) {}
79 
80  // Checks if input FSTs are isomorphic.
81  bool IsIsomorphic() {
82  if (fst1_->Start() == kNoStateId && fst2_->Start() == kNoStateId) {
83  return true;
84  }
85  if (fst1_->Start() == kNoStateId || fst2_->Start() == kNoStateId) {
86  VLOG(1) << "Isomorphic: Only one of the FSTs is empty.";
87  return false;
88  }
89  PairState(fst1_->Start(), fst2_->Start());
90  while (!queue_.empty()) {
91  const auto &[state1, state2] = queue_.front();
92  if (!IsIsomorphicState(state1, state2)) {
93  if (nondet_) {
94  VLOG(1) << "Isomorphic: Non-determinism as an unweighted automaton. "
95  << "state1: " << state1 << " state2: " << state2;
96  error_ = true;
97  }
98  return false;
99  }
100  queue_.pop();
101  }
102  return true;
103  }
104 
105  bool Error() const { return error_; }
106 
107  private:
108  // Orders arcs for equality checking.
109  class ArcCompare {
110  public:
111  ArcCompare(float delta, bool *error) : delta_(delta), error_(error) {}
112 
113  bool operator()(const Arc &arc1, const Arc &arc2) const {
114  if (arc1.ilabel < arc2.ilabel) return true;
115  if (arc1.ilabel > arc2.ilabel) return false;
116  if (arc1.olabel < arc2.olabel) return true;
117  if (arc1.olabel > arc2.olabel) return false;
118  if (!ApproxEqual(arc1.weight, arc2.weight, delta_)) {
119  return WeightCompare(arc1.weight, arc2.weight, delta_, error_);
120  } else {
121  return arc1.nextstate < arc2.nextstate;
122  }
123  }
124 
125  private:
126  const float delta_;
127  bool *error_;
128  };
129 
130  // Maintains state correspondences and queue.
131  bool PairState(StateId s1, StateId s2) {
132  if (state_pairs_.size() <= s1) state_pairs_.resize(s1 + 1, kNoStateId);
133  if (state_pairs_[s1] == s2) {
134  return true; // Already seen this pair.
135  } else if (state_pairs_[s1] != kNoStateId) {
136  return false; // s1 already paired with another s2.
137  }
138  VLOG(3) << "Pairing states: (" << s1 << ", " << s2 << ")";
139  state_pairs_[s1] = s2;
140  queue_.emplace(s1, s2);
141  return true;
142  }
143 
144  // Checks if state pair is isomorphic.
145  bool IsIsomorphicState(StateId s1, StateId s2);
146 
147  std::unique_ptr<Fst<Arc>> fst1_;
148  std::unique_ptr<Fst<Arc>> fst2_;
149  float delta_; // Weight equality delta.
150  std::vector<Arc> arcs1_; // For sorting arcs on FST1.
151  std::vector<Arc> arcs2_; // For sorting arcs on FST2.
152  std::vector<StateId> state_pairs_; // Maintains state correspondences.
153  std::queue<std::pair<StateId, StateId>> queue_; // Queue of state pairs.
154  bool error_; // Error flag.
155  bool nondet_; // Nondeterminism detected.
156  ArcCompare comp_;
157 };
158 
159 template <class Arc>
160 bool Isomorphism<Arc>::IsIsomorphicState(StateId s1, StateId s2) {
161  if (!ApproxEqual(fst1_->Final(s1), fst2_->Final(s2), delta_)) {
162  VLOG(1) << "Isomorphic: Final weights not equal to within delta=" << delta_
163  << ": "
164  << "fst1.Final(" << s1 << ") = " << fst1_->Final(s1) << ", "
165  << "fst2.Final(" << s2 << ") = " << fst2_->Final(s2);
166  return false;
167  }
168  const auto narcs1 = fst1_->NumArcs(s1);
169  const auto narcs2 = fst2_->NumArcs(s2);
170  if (narcs1 != narcs2) {
171  VLOG(1) << "Isomorphic: NumArcs not equal. "
172  << "fst1.NumArcs(" << s1 << ") = " << narcs1 << ", "
173  << "fst2.NumArcs(" << s2 << ") = " << narcs2;
174  return false;
175  }
176  ArcIterator<Fst<Arc>> aiter1(*fst1_, s1);
177  ArcIterator<Fst<Arc>> aiter2(*fst2_, s2);
178  arcs1_.clear();
179  arcs1_.reserve(narcs1);
180  arcs2_.clear();
181  arcs2_.reserve(narcs2);
182  for (; !aiter1.Done(); aiter1.Next(), aiter2.Next()) {
183  arcs1_.push_back(aiter1.Value());
184  arcs2_.push_back(aiter2.Value());
185  }
186  std::sort(arcs1_.begin(), arcs1_.end(), comp_);
187  std::sort(arcs2_.begin(), arcs2_.end(), comp_);
188  for (size_t i = 0; i < arcs1_.size(); ++i) {
189  const auto &arc1 = arcs1_[i];
190  const auto &arc2 = arcs2_[i];
191  if (arc1.ilabel != arc2.ilabel) {
192  VLOG(1) << "Isomorphic: ilabels not equal. "
193  << "state1: " << s1 << " arc1: *" << arc1.ilabel << "* "
194  << arc1.olabel << " " << arc1.weight << " " << arc1.nextstate
195  << " state2: " << s2 << " arc2: *" << arc2.ilabel << "* "
196  << arc2.olabel << " " << arc2.weight << " " << arc2.nextstate;
197  return false;
198  }
199  if (arc1.olabel != arc2.olabel) {
200  VLOG(1) << "Isomorphic: olabels not equal. "
201  << "state1: " << s1 << " arc1: " << arc1.ilabel << " *"
202  << arc1.olabel << "* " << arc1.weight << " " << arc1.nextstate
203  << " state2: " << s2 << " arc2: " << arc2.ilabel << " *"
204  << arc2.olabel << "* " << arc2.weight << " " << arc2.nextstate;
205  return false;
206  }
207  if (!ApproxEqual(arc1.weight, arc2.weight, delta_)) {
208  VLOG(1) << "Isomorphic: weights not ApproxEqual. "
209  << "state1: " << s1 << " arc1: " << arc1.ilabel << " "
210  << arc1.olabel << " *" << arc1.weight << "* " << arc1.nextstate
211  << " state2: " << s2 << " arc2: " << arc2.ilabel << " "
212  << arc2.olabel << " *" << arc2.weight << "* " << arc2.nextstate;
213  return false;
214  }
215  if (!PairState(arc1.nextstate, arc2.nextstate)) {
216  VLOG(1) << "Isomorphic: nextstates could not be paired. "
217  << "state1: " << s1 << " arc1: " << arc1.ilabel << " "
218  << arc1.olabel << " " << arc1.weight << " *" << arc1.nextstate
219  << "* "
220  << "state2: " << s2 << " arc2: " << arc2.ilabel << " "
221  << arc2.olabel << " " << arc2.weight << " *" << arc2.nextstate
222  << "*";
223  return false;
224  }
225  if (i > 0) { // Checks for non-determinism.
226  const auto &arc0 = arcs1_[i - 1];
227  if (arc1.ilabel == arc0.ilabel && arc1.olabel == arc0.olabel &&
228  ApproxEqual(arc1.weight, arc0.weight, delta_)) {
229  // Any subsequent matching failure maybe a false negative
230  // since we only consider one permutation when pairing destination
231  // states of nondeterministic transitions.
232  VLOG(1) << "Isomorphic: Detected non-determinism as an unweighted "
233  << "automaton; deferring error. "
234  << "state: " << s1 << " arc1: " << arc1.ilabel << " "
235  << arc1.olabel << " " << arc1.weight << " " << arc1.nextstate
236  << " arc2: " << arc2.ilabel << " " << arc2.olabel << " "
237  << arc2.weight << " " << arc2.nextstate;
238  nondet_ = true;
239  }
240  }
241  }
242  return true;
243 }
244 
245 } // namespace internal
246 
247 // Tests if two FSTs have the same states and arcs up to a reordering.
248 // Inputs should be deterministic when viewed as unweighted automata.
249 // When the inputs are nondeterministic, the algorithm only considers one
250 // permutation for each set of equivalent nondeterministic transitions
251 // (the permutation that preserves state ID ordering) and hence might return
252 // false negatives (but it never returns false positives).
253 template <class Arc>
254 bool Isomorphic(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
255  float delta = kDelta) {
256  internal::Isomorphism<Arc> iso(fst1, fst2, delta);
257  const bool result = iso.IsIsomorphic();
258  if (iso.Error()) {
259  FSTERROR() << "Isomorphic: Cannot determine if inputs are isomorphic";
260  return false;
261  } else {
262  return result;
263  }
264 }
265 
266 } // namespace fst
267 
268 #endif // FST_ISOMORPHIC_H_
Isomorphism(const Fst< Arc > &fst1, const Fst< Arc > &fst2, float delta)
Definition: isomorphic.h:72
bool WeightCompare(const Weight &w1, const Weight &w2, float, bool *)
Definition: isomorphic.h:45
constexpr int kNoStateId
Definition: fst.h:196
const Arc & Value() const
Definition: fst.h:536
#define FSTERROR()
Definition: util.h:56
#define VLOG(level)
Definition: log.h:54
bool Done() const
Definition: fst.h:532
bool Isomorphic(FarReader< Arc > &reader1, FarReader< Arc > &reader2, float delta=kDelta, std::string_view begin_key="", std::string_view end_key="")
Definition: isomorphic.h:32
constexpr float kDelta
Definition: weight.h:133
bool ApproxEqual(const ErrorWeight &, const ErrorWeight &, float)
Definition: error-weight.h:58
void Next()
Definition: fst.h:540