FST  openfst-1.8.2.post1
OpenFst Library
union-weight.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Union weight set and associated semiring operation definitions.
19 //
20 // TODO(riley): add in normalizer functor.
21 
22 #ifndef FST_UNION_WEIGHT_H_
23 #define FST_UNION_WEIGHT_H_
24 
25 #include <cstdint>
26 #include <iostream>
27 #include <list>
28 #include <random>
29 #include <sstream>
30 #include <string>
31 #include <utility>
32 
33 
34 #include <fst/weight.h>
35 
36 
37 namespace fst {
38 
39 // Example UnionWeightOptions for UnionWeight template below. The Merge
40 // operation is used to collapse elements of the set and the Compare function
41 // to efficiently implement the merge. In the simplest case, merge would just
42 // apply with equality of set elements so the result is a set (and not a
43 // multiset). More generally, this can be used to maintain the multiplicity or
44 // other such weight associated with the set elements (cf. Gallic weights).
45 
46 // template <class W>
47 // struct UnionWeightOptions {
48 // // Comparison function C is a total order on W that is monotonic w.r.t.
49 // // Times: for all a, b, c != Zero(): C(a, b) => C(ca, cb) and is
50 // // anti-monotonic w.r.t. Divide: C(a, b) => C(c/b, c/a).
51 // //
52 // // For all a, b: exactly one of C(a, b), C(b, a) or a ~ b must be true,
53 // // where ~ is an equivalence relation on W. Also we require a ~ b iff
54 // // a.Reverse() ~ b.Reverse().
55 // using Compare = NaturalLess<W>;
56 //
57 // // How to combine two weights if a ~ b as above. For all a, b: a ~ b =>
58 // // merge(a, b) ~ a, Merge must define a semiring endomorphism from the
59 // // unmerged weight sets to the merged weight sets.
60 // struct Merge {
61 // W operator()(const W &w1, const W &w2) const { return w1; }
62 // };
63 //
64 // // For ReverseWeight.
65 // using ReverseOptions = UnionWeightOptions<ReverseWeight>;
66 // };
67 
// Forward declarations. UnionWeight befriends both iterator templates, and the
// equality operator is declared here ahead of its definition further down.
template <class W, class O>
class UnionWeight;

template <class W, class O>
class UnionWeightIterator;

template <class W, class O>
class UnionWeightReverseIterator;

template <class W, class O>
bool operator==(const UnionWeight<W, O> &, const UnionWeight<W, O> &);
79 
80 // Semiring that uses Times() and One() from W and union and the empty set
81 // for Plus() and Zero(), respectively. Template argument O specifies the union
82 // weight options as above.
83 template <class W, class O>
84 class UnionWeight {
85  public:
86  using Weight = W;
87  using Compare = typename O::Compare;
88  using Merge = typename O::Merge;
89 
90  using ReverseWeight =
92 
93  friend class UnionWeightIterator<W, O>;
94  friend class UnionWeightReverseIterator<W, O>;
95 
96  // Sets represented as first_ weight + rest_ weights. Uses first_ as
97  // NoWeight() to indicate the union weight Zero() as the empty set. Uses
98  // rest_ containing NoWeight() to indicate the union weight NoWeight().
99  UnionWeight() : first_(W::NoWeight()) {}
100 
101  explicit UnionWeight(W weight) : first_(weight) {
102  if (!weight.Member()) rest_.push_back(W::NoWeight());
103  }
104 
105  static const UnionWeight &Zero() {
106  static const auto *const zero = new UnionWeight;
107  return *zero;
108  }
109 
110  static const UnionWeight &One() {
111  static const auto *const one = new UnionWeight(W::One());
112  return *one;
113  }
114 
115  static const UnionWeight &NoWeight() {
116  static const auto *const no_weight =
117  new UnionWeight(W::Zero(), W::NoWeight());
118  return *no_weight;
119  }
120 
121  static const std::string &Type() {
122  static const std::string *const type =
123  new std::string(W::Type() + "_union");
124  return *type;
125  }
126 
127  static constexpr uint64_t Properties() {
128  return W::Properties() &
130  }
131 
132  bool Member() const;
133 
134  std::istream &Read(std::istream &strm);
135 
136  std::ostream &Write(std::ostream &strm) const;
137 
138  size_t Hash() const;
139 
140  UnionWeight Quantize(float delta = kDelta) const;
141 
142  ReverseWeight Reverse() const;
143 
144  // These operations combined with the UnionWeightIterator and
145  // UnionWeightReverseIterator provide the access and mutation of the union
146  // weight internal elements.
147 
148  // Common initializer among constructors; clears existing UnionWeight.
149  void Clear() {
150  first_ = W::NoWeight();
151  rest_.clear();
152  }
153 
154  size_t Size() const { return first_.Member() ? rest_.size() + 1 : 0; }
155 
156  const W &Back() const { return rest_.empty() ? first_ : rest_.back(); }
157 
158  // When srt is true, assumes elements added sorted w.r.t Compare and merging
159  // of weights performed as needed. Otherwise, just ensures first_ is the
160  // least element wrt Compare.
161  void PushBack(W weight, bool srt);
162 
163  // Sorts the elements of the set. Assumes that first_, if present, is the
164  // least element.
165  void Sort() { rest_.sort(comp_); }
166 
167  private:
168  W &Back() {
169  if (rest_.empty()) {
170  return first_;
171  } else {
172  return rest_.back();
173  }
174  }
175 
176  UnionWeight(W w1, W w2) : first_(std::move(w1)), rest_(1, std::move(w2)) {}
177 
178  W first_; // First weight in set.
179  std::list<W> rest_; // Remaining weights in set.
180  Compare comp_;
181  Merge merge_;
182 };
183 
184 template <class W, class O>
185 void UnionWeight<W, O>::PushBack(W weight, bool srt) {
186  if (!weight.Member()) {
187  rest_.push_back(std::move(weight));
188  } else if (!first_.Member()) {
189  first_ = std::move(weight);
190  } else if (srt) {
191  auto &back = Back();
192  if (comp_(back, weight)) {
193  rest_.push_back(std::move(weight));
194  } else {
195  back = merge_(back, std::move(weight));
196  }
197  } else {
198  if (comp_(first_, weight)) {
199  rest_.push_back(std::move(weight));
200  } else {
201  rest_.push_back(first_);
202  first_ = std::move(weight);
203  }
204  }
205 }
206 
207 // Traverses union weight in the forward direction.
208 template <class W, class O>
209 class UnionWeightIterator {
210  public:
211  explicit UnionWeightIterator(const UnionWeight<W, O> &weight)
212  : first_(weight.first_),
213  rest_(weight.rest_),
214  init_(true),
215  it_(rest_.begin()) {}
216 
217  bool Done() const { return init_ ? !first_.Member() : it_ == rest_.end(); }
218 
219  const W &Value() const { return init_ ? first_ : *it_; }
220 
221  void Next() {
222  if (init_) {
223  init_ = false;
224  } else {
225  ++it_;
226  }
227  }
228 
229  void Reset() {
230  init_ = true;
231  it_ = rest_.begin();
232  }
233 
234  private:
235  const W &first_;
236  const std::list<W> &rest_;
237  bool init_; // in the initialized state?
238  typename std::list<W>::const_iterator it_;
239 };
240 
241 // Traverses union weight in backward direction.
242 template <typename L, class O>
244  public:
246  : first_(weight.first_),
247  rest_(weight.rest_),
248  fin_(!first_.Member()),
249  it_(rest_.rbegin()) {}
250 
251  bool Done() const { return fin_; }
252 
253  const L &Value() const { return it_ == rest_.rend() ? first_ : *it_; }
254 
255  void Next() {
256  if (it_ == rest_.rend()) {
257  fin_ = true;
258  } else {
259  ++it_;
260  }
261  }
262 
263  void Reset() {
264  fin_ = !first_.Member();
265  it_ = rest_.rbegin();
266  }
267 
268  private:
269  const L &first_;
270  const std::list<L> &rest_;
271  bool fin_; // in the final state?
272  typename std::list<L>::const_reverse_iterator it_;
273 };
274 
275 // UnionWeight member functions follow that require UnionWeightIterator.
276 template <class W, class O>
277 inline std::istream &UnionWeight<W, O>::Read(std::istream &istrm) {
278  Clear();
279  int32_t size;
280  ReadType(istrm, &size);
281  for (int i = 0; i < size; ++i) {
282  W weight;
283  ReadType(istrm, &weight);
284  PushBack(weight, true);
285  }
286  return istrm;
287 }
288 
289 template <class W, class O>
290 inline std::ostream &UnionWeight<W, O>::Write(std::ostream &ostrm) const {
291  const int32_t size = Size();
292  WriteType(ostrm, size);
293  for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
294  WriteType(ostrm, it.Value());
295  }
296  return ostrm;
297 }
298 
299 template <class W, class O>
300 inline bool UnionWeight<W, O>::Member() const {
301  if (Size() <= 1) return true;
302  for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
303  if (!it.Value().Member()) return false;
304  }
305  return true;
306 }
307 
308 template <class W, class O>
310  UnionWeight weight;
311  for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
312  weight.PushBack(it.Value().Quantize(delta), true);
313  }
314  return weight;
315 }
316 
317 template <class W, class O>
319  const {
320  ReverseWeight weight;
321  for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
322  weight.PushBack(it.Value().Reverse(), false);
323  }
324  weight.Sort();
325  return weight;
326 }
327 
328 template <class W, class O>
329 inline size_t UnionWeight<W, O>::Hash() const {
330  size_t h = 0;
331  static constexpr int lshift = 5;
332  static constexpr int rshift = CHAR_BIT * sizeof(size_t) - lshift;
333  for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
334  h = h << lshift ^ h >> rshift ^ it.Value().Hash();
335  }
336  return h;
337 }
338 
339 // Requires union weight has been canonicalized.
340 template <class W, class O>
341 inline bool operator==(const UnionWeight<W, O> &w1,
342  const UnionWeight<W, O> &w2) {
343  if (w1.Size() != w2.Size()) return false;
346  for (; !it1.Done(); it1.Next(), it2.Next()) {
347  if (it1.Value() != it2.Value()) return false;
348  }
349  return true;
350 }
351 
352 // Requires union weight has been canonicalized.
353 template <class W, class O>
354 inline bool operator!=(const UnionWeight<W, O> &w1,
355  const UnionWeight<W, O> &w2) {
356  return !(w1 == w2);
357 }
358 
359 // Requires union weight has been canonicalized.
360 template <class W, class O>
361 inline bool ApproxEqual(const UnionWeight<W, O> &w1,
362  const UnionWeight<W, O> &w2, float delta = kDelta) {
363  if (w1.Size() != w2.Size()) return false;
366  for (; !it1.Done(); it1.Next(), it2.Next()) {
367  if (!ApproxEqual(it1.Value(), it2.Value(), delta)) return false;
368  }
369  return true;
370 }
371 
372 template <class W, class O>
373 inline std::ostream &operator<<(std::ostream &ostrm,
374  const UnionWeight<W, O> &weight) {
375  UnionWeightIterator<W, O> it(weight);
376  if (it.Done()) {
377  return ostrm << "EmptySet";
378  } else if (!weight.Member()) {
379  return ostrm << "BadSet";
380  } else {
381  CompositeWeightWriter writer(ostrm);
382  writer.WriteBegin();
383  for (; !it.Done(); it.Next()) writer.WriteElement(it.Value());
384  writer.WriteEnd();
385  }
386  return ostrm;
387 }
388 
389 template <class W, class O>
390 inline std::istream &operator>>(std::istream &istrm,
391  UnionWeight<W, O> &weight) {
392  std::string s;
393  istrm >> s;
394  if (s == "EmptySet") {
395  weight = UnionWeight<W, O>::Zero();
396  } else if (s == "BadSet") {
397  weight = UnionWeight<W, O>::NoWeight();
398  } else {
399  weight = UnionWeight<W, O>::Zero();
400  std::istringstream sstrm(s);
401  CompositeWeightReader reader(sstrm);
402  reader.ReadBegin();
403  bool more = true;
404  while (more) {
405  W v;
406  more = reader.ReadElement(&v);
407  weight.PushBack(v, true);
408  }
409  reader.ReadEnd();
410  }
411  return istrm;
412 }
413 
414 template <class W, class O>
416  const UnionWeight<W, O> &w2) {
417  if (!w1.Member() || !w2.Member()) return UnionWeight<W, O>::NoWeight();
418  if (w1 == UnionWeight<W, O>::Zero()) return w2;
419  if (w2 == UnionWeight<W, O>::Zero()) return w1;
422  UnionWeight<W, O> sum;
423  typename O::Compare comp;
424  while (!it1.Done() && !it2.Done()) {
425  const auto v1 = it1.Value();
426  const auto v2 = it2.Value();
427  if (comp(v1, v2)) {
428  sum.PushBack(v1, true);
429  it1.Next();
430  } else {
431  sum.PushBack(v2, true);
432  it2.Next();
433  }
434  }
435  for (; !it1.Done(); it1.Next()) sum.PushBack(it1.Value(), true);
436  for (; !it2.Done(); it2.Next()) sum.PushBack(it2.Value(), true);
437  return sum;
438 }
439 
440 template <class W, class O>
442  const UnionWeight<W, O> &w2) {
443  if (!w1.Member() || !w2.Member()) return UnionWeight<W, O>::NoWeight();
444  if (w1 == UnionWeight<W, O>::Zero() || w2 == UnionWeight<W, O>::Zero()) {
445  return UnionWeight<W, O>::Zero();
446  }
449  UnionWeight<W, O> prod1;
450  for (; !it1.Done(); it1.Next()) {
451  UnionWeight<W, O> prod2;
452  for (; !it2.Done(); it2.Next()) {
453  prod2.PushBack(Times(it1.Value(), it2.Value()), true);
454  }
455  prod1 = Plus(prod1, prod2);
456  it2.Reset();
457  }
458  return prod1;
459 }
460 
461 template <class W, class O>
463  const UnionWeight<W, O> &w2, DivideType typ) {
464  if (!w1.Member() || !w2.Member()) return UnionWeight<W, O>::NoWeight();
465  if (w1 == UnionWeight<W, O>::Zero() || w2 == UnionWeight<W, O>::Zero()) {
466  return UnionWeight<W, O>::Zero();
467  }
470  UnionWeight<W, O> quot;
471  if (w1.Size() == 1) {
472  for (; !it2.Done(); it2.Next()) {
473  quot.PushBack(Divide(it1.Value(), it2.Value(), typ), true);
474  }
475  } else if (w2.Size() == 1) {
476  for (; !it1.Done(); it1.Next()) {
477  quot.PushBack(Divide(it1.Value(), it2.Value(), typ), true);
478  }
479  } else {
481  }
482  return quot;
483 }
484 
485 // This function object generates weights over the union of weights for the
486 // underlying generators for the template weight types. This is intended
487 // primarily for testing.
488 template <class W, class O>
490  public:
493 
494  explicit WeightGenerate(uint64_t seed = std::random_device()(),
495  bool allow_zero = true,
496  size_t num_random_weights = kNumRandomWeights)
497  : rand_(seed),
498  allow_zero_(allow_zero),
499  num_random_weights_(num_random_weights),
500  generate_(seed, false) {}
501 
502  Weight operator()() const {
503  const int sample = std::uniform_int_distribution<>(
504  0, num_random_weights_ + allow_zero_ - 1)(rand_);
505  if (allow_zero_ && sample == num_random_weights_) {
506  return Weight::Zero();
507  } else if (std::bernoulli_distribution(.5)(rand_)) {
508  return Weight(generate_());
509  } else {
510  return Plus(Weight(generate_()), Weight(generate_()));
511  }
512  }
513 
514  private:
515  mutable std::mt19937_64 rand_;
516  const bool allow_zero_;
517  const size_t num_random_weights_;
518  const Generate generate_;
519 };
520 
521 } // namespace fst
522 
523 #endif // FST_UNION_WEIGHT_H_
typename O::Compare Compare
Definition: union-weight.h:87
static constexpr uint64_t Properties()
Definition: union-weight.h:127
static const UnionWeight & One()
Definition: union-weight.h:110
UnionWeight(W weight)
Definition: union-weight.h:101
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:60
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:63
const W & Value() const
Definition: union-weight.h:219
size_t Hash() const
Definition: union-weight.h:329
std::ostream & Write(std::ostream &strm) const
Definition: union-weight.h:290
constexpr uint64_t kIdempotent
Definition: weight.h:144
std::istream & Read(std::istream &strm)
Definition: union-weight.h:277
static const std::string & Type()
Definition: union-weight.h:121
constexpr uint64_t kRightSemiring
Definition: weight.h:136
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:214
ReverseWeight Reverse() const
Definition: union-weight.h:318
bool Member() const
Definition: union-weight.h:300
void PushBack(W weight, bool srt)
Definition: union-weight.h:185
std::istream & operator>>(std::istream &strm, FloatWeightTpl< T > &w)
Definition: float-weight.h:181
bool operator!=(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:53
constexpr uint64_t kCommutative
Definition: weight.h:141
std::ostream & operator<<(std::ostream &strm, const ErrorWeight &)
Definition: error-weight.h:70
static const UnionWeight & Zero()
Definition: union-weight.h:105
UnionWeightReverseIterator(const UnionWeight< L, O > &weight)
Definition: union-weight.h:245
UnionWeightIterator(const UnionWeight< W, O > &weight)
Definition: union-weight.h:211
static const UnionWeight & NoWeight()
Definition: union-weight.h:115
typename O::Merge Merge
Definition: union-weight.h:88
bool ReadElement(T *comp, bool last=false)
Definition: weight.h:366
ErrorWeight Divide(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:66
bool operator==(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:50
UnionWeight Quantize(float delta=kDelta) const
Definition: union-weight.h:309
constexpr size_t kNumRandomWeights
Definition: weight.h:151
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t num_random_weights=kNumRandomWeights)
Definition: union-weight.h:494
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:68
void WriteElement(const T &comp)
Definition: weight.h:316
DivideType
Definition: weight.h:162
const W & Back() const
Definition: union-weight.h:156
constexpr uint64_t kLeftSemiring
Definition: weight.h:133
constexpr float kDelta
Definition: weight.h:130
bool ApproxEqual(const ErrorWeight &, const ErrorWeight &, float)
Definition: error-weight.h:57
size_t Size() const
Definition: union-weight.h:154