FST  openfst-1.8.3
OpenFst Library
union-weight.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Union weight set and associated semiring operation definitions.
19 //
20 // TODO(riley): add in normalizer functor.
21 
22 #ifndef FST_UNION_WEIGHT_H_
23 #define FST_UNION_WEIGHT_H_
24 
25 #include <climits>
26 #include <cstddef>
27 #include <cstdint>
28 #include <iostream>
29 #include <istream>
30 #include <list>
31 #include <ostream>
32 #include <random>
33 #include <sstream>
34 #include <string>
35 #include <utility>
36 
37 #include <fst/util.h>
38 #include <fst/weight.h>
39 
40 namespace fst {
41 
42 // Example UnionWeightOptions for UnionWeight template below. The Merge
43 // operation is used to collapse elements of the set and the Compare function
44 // to efficiently implement the merge. In the simplest case, merge would just
45 // apply with equality of set elements so the result is a set (and not a
46 // multiset). More generally, this can be used to maintain the multiplicity or
47 // other such weight associated with the set elements (cf. Gallic weights).
48 
49 // template <class W>
50 // struct UnionWeightOptions {
51 // // Comparison function C is a total order on W that is monotonic w.r.t.
52 // // Times: for all a, b, c != Zero(): C(a, b) => C(ca, cb) and is
53 // // anti-monotonic w.r.t. Divide: C(a, b) => C(c/b, c/a).
54 // //
55 // // For all a, b: exactly one of C(a, b), C(b, a) or a ~ b must be true, where
56 // // ~ is an equivalence relation on W. Also we require a ~ b iff
57 // // a.Reverse() ~ b.Reverse().
58 // using Compare = NaturalLess<W>;
59 //
60 // // How to combine two weights if a ~ b as above. For all a, b: a ~ b =>
61 // // merge(a, b) ~ a. Merge must define a semiring endomorphism from the
62 // // unmerged weight sets to the merged weight sets.
63 // struct Merge {
64 // W operator()(const W &w1, const W &w2) const { return w1; }
65 // };
66 //
67 // // For ReverseWeight.
68 // using ReverseOptions = UnionWeightOptions<ReverseWeight>;
69 // };
70 
71 template <class W, class O>
73 template <class W, class O>
75 template <class W, class O>
77 
78 template <class W, class O>
79 bool operator==(const UnionWeight<W, O> &, const UnionWeight<W, O> &);
80 
81 // Semiring that uses Times() and One() from W and union and the empty set
82 // for Plus() and Zero(), respectively. Template argument O specifies the union
83 // weight options as above.
//
// NOTE(review): this listing skips its own lines 92 and 130, so the
// ReverseWeight alias and the Properties() return expression below are
// truncated as shown. Restore both from the upstream header before compiling.
84 template <class W, class O>
85 class UnionWeight {
86  public:
87  using Weight = W;
88  using Compare = typename O::Compare;
89  using Merge = typename O::Merge;
90 
// NOTE(review): right-hand side of this alias (listing line 92) is missing.
91  using ReverseWeight =
93 
// The iterators read first_ and rest_ directly, hence the friendship.
94  friend class UnionWeightIterator<W, O>;
95  friend class UnionWeightReverseIterator<W, O>;
96 
97  // Sets represented as first_ weight + rest_ weights. Uses first_ as
98  // NoWeight() to indicate the union weight Zero() as the empty set. Uses
99  // rest_ containing NoWeight() to indicate the union weight NoWeight().
100  UnionWeight() : first_(W::NoWeight()) {}
101 
// Builds the singleton set {weight}. When weight is not a member,
// W::NoWeight() is additionally recorded in rest_.
102  explicit UnionWeight(W weight) : first_(weight) {
103  if (!weight.Member()) rest_.push_back(W::NoWeight());
104  }
105 
// Zero(), One() and NoWeight() each return a reference to a function-local
// static object that is allocated with new on first use.
106  static const UnionWeight &Zero() {
107  static const auto *const zero = new UnionWeight;
108  return *zero;
109  }
110 
111  static const UnionWeight &One() {
112  static const auto *const one = new UnionWeight(W::One());
113  return *one;
114  }
115 
// Uses the private two-weight constructor: first_ = W::Zero() and
// rest_ = {W::NoWeight()}.
116  static const UnionWeight &NoWeight() {
117  static const auto *const no_weight =
118  new UnionWeight(W::Zero(), W::NoWeight());
119  return *no_weight;
120  }
121 
// Type name: W::Type() followed by "_union".
122  static const std::string &Type() {
123  static const std::string *const type =
124  new std::string(W::Type() + "_union");
125  return *type;
126  }
127 
// NOTE(review): the mask that W::Properties() is and-ed with (listing line
// 130) is missing from this extract.
128  static constexpr uint64_t Properties() {
129  return W::Properties() &
131  }
132 
133  bool Member() const;
134 
135  std::istream &Read(std::istream &strm);
136 
137  std::ostream &Write(std::ostream &strm) const;
138 
139  size_t Hash() const;
140 
141  UnionWeight Quantize(float delta = kDelta) const;
142 
143  ReverseWeight Reverse() const;
144 
145  // These operations combined with the UnionWeightIterator and
146  // UnionWeightReverseIterator provide the access and mutation of the union
147  // weight internal elements.
148 
149  // Common initializer among constructors; clears existing UnionWeight.
150  void Clear() {
151  first_ = W::NoWeight();
152  rest_.clear();
153  }
154 
// Element count: 0 when first_ is not a member, otherwise rest_.size() + 1.
155  size_t Size() const { return first_.Member() ? rest_.size() + 1 : 0; }
156 
// Last stored weight: first_ when rest_ is empty, otherwise rest_.back().
157  const W &Back() const { return rest_.empty() ? first_ : rest_.back(); }
158 
159  // When srt is true, assumes elements added sorted w.r.t Compare and merging
160  // of weights performed as needed. Otherwise, just ensures first_ is the
161  // least element wrt Compare.
162  void PushBack(W weight, bool srt);
163 
164  // Sorts the elements of the set. Assumes that first_, if present, is the
165  // least element.
// Only rest_ is sorted; first_ is left in place.
166  void Sort() { rest_.sort(comp_); }
167 
168  private:
// Mutable counterpart of the public Back(); PushBack() uses it to overwrite
// the last element with a merged weight.
169  W &Back() {
170  if (rest_.empty()) {
171  return first_;
172  } else {
173  return rest_.back();
174  }
175  }
176 
// Two-element constructor: w1 becomes first_, w2 the single entry of rest_.
// Within this listing it is only called by NoWeight().
177  UnionWeight(W w1, W w2) : first_(std::move(w1)), rest_(1, std::move(w2)) {}
178 
179  W first_; // First weight in set.
180  std::list<W> rest_; // Remaining weights in set.
181  Compare comp_;
182  Merge merge_;
183 };
184 
185 template <class W, class O>
186 void UnionWeight<W, O>::PushBack(W weight, bool srt) {
187  if (!weight.Member()) {
188  rest_.push_back(std::move(weight));
189  } else if (!first_.Member()) {
190  first_ = std::move(weight);
191  } else if (srt) {
192  auto &back = Back();
193  if (comp_(back, weight)) {
194  rest_.push_back(std::move(weight));
195  } else {
196  back = merge_(back, std::move(weight));
197  }
198  } else {
199  if (comp_(first_, weight)) {
200  rest_.push_back(std::move(weight));
201  } else {
202  rest_.push_back(first_);
203  first_ = std::move(weight);
204  }
205  }
206 }
207 
208 // Traverses union weight in the forward direction.
209 template <class W, class O>
210 class UnionWeightIterator {
211  public:
212  explicit UnionWeightIterator(const UnionWeight<W, O> &weight)
213  : first_(weight.first_),
214  rest_(weight.rest_),
215  init_(true),
216  it_(rest_.begin()) {}
217 
218  bool Done() const { return init_ ? !first_.Member() : it_ == rest_.end(); }
219 
220  const W &Value() const { return init_ ? first_ : *it_; }
221 
222  void Next() {
223  if (init_) {
224  init_ = false;
225  } else {
226  ++it_;
227  }
228  }
229 
230  void Reset() {
231  init_ = true;
232  it_ = rest_.begin();
233  }
234 
235  private:
236  const W &first_;
237  const std::list<W> &rest_;
238  bool init_; // in the initialized state?
239  typename std::list<W>::const_iterator it_;
240 };
241 
242 // Traverses union weight in backward direction.
// The entries of rest_ are visited from last to first and first_ is visited
// last. References into the iterated weight are kept; nothing is copied.
//
// NOTE(review): listing lines 244 and 246 are missing from this extract, so
// the class head and the constructor head are not shown. The index at the end
// of the listing records line 246 as
// `UnionWeightReverseIterator(const UnionWeight< L, O > &weight)`.
243 template <typename L, class O>
245  public:
247  : first_(weight.first_),
248  rest_(weight.rest_),
249  fin_(!first_.Member()),
250  it_(rest_.rbegin()) {}
251 
// True once first_ has been passed, or immediately when first_ is not a
// member.
252  bool Done() const { return fin_; }
253 
// Current element; once rest_ is exhausted this is first_.
254  const L &Value() const { return it_ == rest_.rend() ? first_ : *it_; }
255 
// Steps backwards through rest_; the step taken while positioned on first_
// marks the traversal as finished.
256  void Next() {
257  if (it_ == rest_.rend()) {
258  fin_ = true;
259  } else {
260  ++it_;
261  }
262  }
263 
// Rewinds to the last entry of rest_ and recomputes the finished flag.
264  void Reset() {
265  fin_ = !first_.Member();
266  it_ = rest_.rbegin();
267  }
268 
269  private:
270  const L &first_;
271  const std::list<L> &rest_;
272  bool fin_; // in the final state?
273  typename std::list<L>::const_reverse_iterator it_;
274 };
275 
276 // UnionWeight member functions follow that require UnionWeightIterator.
277 template <class W, class O>
278 inline std::istream &UnionWeight<W, O>::Read(std::istream &istrm) {
279  Clear();
280  int32_t size;
281  ReadType(istrm, &size);
282  for (int i = 0; i < size; ++i) {
283  W weight;
284  ReadType(istrm, &weight);
285  PushBack(weight, true);
286  }
287  return istrm;
288 }
289 
290 template <class W, class O>
291 inline std::ostream &UnionWeight<W, O>::Write(std::ostream &ostrm) const {
292  const int32_t size = Size();
293  WriteType(ostrm, size);
294  for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
295  WriteType(ostrm, it.Value());
296  }
297  return ostrm;
298 }
299 
300 template <class W, class O>
301 inline bool UnionWeight<W, O>::Member() const {
302  if (Size() <= 1) return true;
303  for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
304  if (!it.Value().Member()) return false;
305  }
306  return true;
307 }
308 
// Builds a new union weight holding Quantize(delta) of every element, visited
// in forward order. Elements are added with srt == true, so they are treated
// as already ordered and merged as PushBack() sees fit.
//
// NOTE(review): listing line 310, which carries this function's signature, is
// missing from this extract. The class declares the member as
// `UnionWeight Quantize(float delta = kDelta) const`.
309 template <class W, class O>
311  UnionWeight weight;
312  for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
313  weight.PushBack(it.Value().Quantize(delta), true);
314  }
315  return weight;
316 }
317 
// Builds the reversed union weight: Reverse() of every element is added with
// srt == false (which only keeps the least element in front), then the
// remaining elements are ordered with Sort().
//
// NOTE(review): listing line 319, which carries the return type and qualified
// name, is missing from this extract. The class declares the member as
// `ReverseWeight Reverse() const`.
318 template <class W, class O>
320  const {
321  ReverseWeight weight;
322  for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
323  weight.PushBack(it.Value().Reverse(), false);
324  }
325  weight.Sort();
326  return weight;
327 }
328 
329 template <class W, class O>
330 inline size_t UnionWeight<W, O>::Hash() const {
331  size_t h = 0;
332  static constexpr int lshift = 5;
333  static constexpr int rshift = CHAR_BIT * sizeof(size_t) - lshift;
334  for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
335  h = h << lshift ^ h >> rshift ^ it.Value().Hash();
336  }
337  return h;
338 }
339 
340 // Requires union weight has been canonicalized.
// Two union weights are equal when their Size() matches and their elements
// compare equal pairwise in iteration order.
//
// NOTE(review): listing lines 345-346 are missing from this extract. They
// presumably declare the iterators it1 (over w1) and it2 (over w2) used
// below - confirm against the upstream header.
341 template <class W, class O>
342 inline bool operator==(const UnionWeight<W, O> &w1,
343  const UnionWeight<W, O> &w2) {
344  if (w1.Size() != w2.Size()) return false;
// Sizes match, so testing it1 alone is enough to bound both traversals.
347  for (; !it1.Done(); it1.Next(), it2.Next()) {
348  if (it1.Value() != it2.Value()) return false;
349  }
350  return true;
351 }
352 
353 // Requires union weight has been canonicalized.
354 template <class W, class O>
355 inline bool operator!=(const UnionWeight<W, O> &w1,
356  const UnionWeight<W, O> &w2) {
357  return !(w1 == w2);
358 }
359 
360 // Requires union weight has been canonicalized.
// Approximate equality: the sizes must match exactly and every pair of
// elements, taken in iteration order, must be ApproxEqual within `delta`.
//
// NOTE(review): listing lines 365-366 are missing from this extract. They
// presumably declare the iterators it1 (over w1) and it2 (over w2) used
// below - confirm against the upstream header.
361 template <class W, class O>
362 inline bool ApproxEqual(const UnionWeight<W, O> &w1,
363  const UnionWeight<W, O> &w2, float delta = kDelta) {
364  if (w1.Size() != w2.Size()) return false;
// Sizes match, so testing it1 alone is enough to bound both traversals.
367  for (; !it1.Done(); it1.Next(), it2.Next()) {
368  if (!ApproxEqual(it1.Value(), it2.Value(), delta)) return false;
369  }
370  return true;
371 }
372 
373 template <class W, class O>
374 inline std::ostream &operator<<(std::ostream &ostrm,
375  const UnionWeight<W, O> &weight) {
376  UnionWeightIterator<W, O> it(weight);
377  if (it.Done()) {
378  return ostrm << "EmptySet";
379  } else if (!weight.Member()) {
380  return ostrm << "BadSet";
381  } else {
382  CompositeWeightWriter writer(ostrm);
383  writer.WriteBegin();
384  for (; !it.Done(); it.Next()) writer.WriteElement(it.Value());
385  writer.WriteEnd();
386  }
387  return ostrm;
388 }
389 
390 template <class W, class O>
391 inline std::istream &operator>>(std::istream &istrm,
392  UnionWeight<W, O> &weight) {
393  std::string s;
394  istrm >> s;
395  if (s == "EmptySet") {
396  weight = UnionWeight<W, O>::Zero();
397  } else if (s == "BadSet") {
398  weight = UnionWeight<W, O>::NoWeight();
399  } else {
400  weight = UnionWeight<W, O>::Zero();
401  std::istringstream sstrm(s);
402  CompositeWeightReader reader(sstrm);
403  reader.ReadBegin();
404  bool more = true;
405  while (more) {
406  W v;
407  more = reader.ReadElement(&v);
408  weight.PushBack(v, true);
409  }
410  reader.ReadEnd();
411  }
412  return istrm;
413 }
414 
// Semiring Plus: the union of two union weights.
// Returns NoWeight() when either operand is not a member, and the other
// operand when one of them equals Zero(). Otherwise both element sequences
// are walked in step and combined into `sum` with PushBack(..., true).
//
// NOTE(review): listing lines 416 and 421-422 are missing from this extract:
// the first line of the signature and, presumably, the declarations of the
// iterators it1 (over w1) and it2 (over w2) - confirm against the upstream
// header.
415 template <class W, class O>
417  const UnionWeight<W, O> &w2) {
418  if (!w1.Member() || !w2.Member()) return UnionWeight<W, O>::NoWeight();
419  if (w1 == UnionWeight<W, O>::Zero()) return w2;
420  if (w2 == UnionWeight<W, O>::Zero()) return w1;
423  UnionWeight<W, O> sum;
424  typename O::Compare comp;
// Take v1 when it orders strictly before v2, otherwise take v2.
425  while (!it1.Done() && !it2.Done()) {
426  const auto v1 = it1.Value();
427  const auto v2 = it2.Value();
428  if (comp(v1, v2)) {
429  sum.PushBack(v1, true);
430  it1.Next();
431  } else {
432  sum.PushBack(v2, true);
433  it2.Next();
434  }
435  }
// At most one of the two sequences still has elements; append them.
436  for (; !it1.Done(); it1.Next()) sum.PushBack(it1.Value(), true);
437  for (; !it2.Done(); it2.Next()) sum.PushBack(it2.Value(), true);
438  return sum;
439 }
440 
// Semiring Times: the element-wise product of two union weights.
// Returns NoWeight() when either operand is not a member and Zero() when
// either operand equals Zero(). Otherwise, for each element of the first
// operand, the products with every element of the second operand are
// collected in `prod2` and folded into the result with Plus().
//
// NOTE(review): listing lines 442 and 448-449 are missing from this extract:
// the first line of the signature and, presumably, the declarations of the
// iterators it1 (over w1) and it2 (over w2) - confirm against the upstream
// header.
441 template <class W, class O>
443  const UnionWeight<W, O> &w2) {
444  if (!w1.Member() || !w2.Member()) return UnionWeight<W, O>::NoWeight();
445  if (w1 == UnionWeight<W, O>::Zero() || w2 == UnionWeight<W, O>::Zero()) {
446  return UnionWeight<W, O>::Zero();
447  }
450  UnionWeight<W, O> prod1;
451  for (; !it1.Done(); it1.Next()) {
452  UnionWeight<W, O> prod2;
453  for (; !it2.Done(); it2.Next()) {
454  prod2.PushBack(Times(it1.Value(), it2.Value()), true);
455  }
456  prod1 = Plus(prod1, prod2);
// Rewind the inner iterator for the next element of the outer one.
457  it2.Reset();
458  }
459  return prod1;
460 }
461 
// Division of union weights.
// Returns NoWeight() when either operand is not a member and Zero() when
// either operand equals Zero(). When one operand holds a single element, that
// element is divided by, or divides, each element of the other operand.
//
// NOTE(review): listing lines 463, 469-470 and 481 are missing from this
// extract: the first line of the signature, presumably the declarations of
// the iterators it1 and it2, and the body of the final else branch. The
// result for two multi-element operands therefore cannot be read off this
// listing - consult the upstream header.
462 template <class W, class O>
464  const UnionWeight<W, O> &w2, DivideType typ) {
465  if (!w1.Member() || !w2.Member()) return UnionWeight<W, O>::NoWeight();
466  if (w1 == UnionWeight<W, O>::Zero() || w2 == UnionWeight<W, O>::Zero()) {
467  return UnionWeight<W, O>::Zero();
468  }
471  UnionWeight<W, O> quot;
472  if (w1.Size() == 1) {
473  for (; !it2.Done(); it2.Next()) {
474  quot.PushBack(Divide(it1.Value(), it2.Value(), typ), true);
475  }
476  } else if (w2.Size() == 1) {
477  for (; !it1.Done(); it1.Next()) {
478  quot.PushBack(Divide(it1.Value(), it2.Value(), typ), true);
479  }
480  } else {
482  }
483  return quot;
484 }
485 
486 // This function object generates weights over the union of weights for the
487 // underlying generators for the template weight types. This is intended
488 // primarily for testing.
//
// NOTE(review): listing lines 490 and 492-493 are missing from this extract:
// the class head and, presumably, the declarations of the `Weight` and
// `Generate` names used below - confirm against the upstream header.
489 template <class W, class O>
491  public:
494 
// The same seed drives both the local engine and the underlying generator,
// which is constructed with `false` as its second argument.
495  explicit WeightGenerate(uint64_t seed = std::random_device()(),
496  bool allow_zero = true,
497  size_t num_random_weights = kNumRandomWeights)
498  : rand_(seed),
499  allow_zero_(allow_zero),
500  num_random_weights_(num_random_weights),
501  generate_(seed, false) {}
502 
// Draws one weight. When allow_zero_ is set, the extra sample value
// num_random_weights_ maps to Zero(). Otherwise a fair coin picks between a
// single generated element and the Plus() of two generated elements.
503  Weight operator()() const {
504  const int sample = std::uniform_int_distribution<>(
505  0, num_random_weights_ + allow_zero_ - 1)(rand_);
506  if (allow_zero_ && sample == num_random_weights_) {
507  return Weight::Zero();
508  } else if (std::bernoulli_distribution(.5)(rand_)) {
509  return Weight(generate_());
510  } else {
511  return Plus(Weight(generate_()), Weight(generate_()));
512  }
513  }
514 
515  private:
// mutable so that the const operator() can advance the engine.
516  mutable std::mt19937_64 rand_;
517  const bool allow_zero_;
518  const size_t num_random_weights_;
519  const Generate generate_;
520 };
521 
522 } // namespace fst
523 
524 #endif // FST_UNION_WEIGHT_H_
typename O::Compare Compare
Definition: union-weight.h:88
static constexpr uint64_t Properties()
Definition: union-weight.h:128
static const UnionWeight & One()
Definition: union-weight.h:111
UnionWeight(W weight)
Definition: union-weight.h:102
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:61
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:64
const W & Value() const
Definition: union-weight.h:220
size_t Hash() const
Definition: union-weight.h:330
std::ostream & Write(std::ostream &strm) const
Definition: union-weight.h:291
constexpr uint64_t kIdempotent
Definition: weight.h:147
std::istream & Read(std::istream &strm)
Definition: union-weight.h:278
static const std::string & Type()
Definition: union-weight.h:122
constexpr uint64_t kRightSemiring
Definition: weight.h:139
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:228
ReverseWeight Reverse() const
Definition: union-weight.h:319
bool Member() const
Definition: union-weight.h:301
void PushBack(W weight, bool srt)
Definition: union-weight.h:186
std::istream & operator>>(std::istream &strm, FloatWeightTpl< T > &w)
Definition: float-weight.h:185
bool operator!=(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:54
constexpr uint64_t kCommutative
Definition: weight.h:144
std::ostream & operator<<(std::ostream &strm, const ErrorWeight &)
Definition: error-weight.h:71
static const UnionWeight & Zero()
Definition: union-weight.h:106
UnionWeightReverseIterator(const UnionWeight< L, O > &weight)
Definition: union-weight.h:246
UnionWeightIterator(const UnionWeight< W, O > &weight)
Definition: union-weight.h:212
static const UnionWeight & NoWeight()
Definition: union-weight.h:116
typename O::Merge Merge
Definition: union-weight.h:89
bool ReadElement(T *comp, bool last=false)
Definition: weight.h:367
ErrorWeight Divide(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:67
bool operator==(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:51
UnionWeight Quantize(float delta=kDelta) const
Definition: union-weight.h:310
constexpr size_t kNumRandomWeights
Definition: weight.h:154
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t num_random_weights=kNumRandomWeights)
Definition: union-weight.h:495
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:80
void WriteElement(const T &comp)
Definition: weight.h:317
DivideType
Definition: weight.h:165
const W & Back() const
Definition: union-weight.h:157
constexpr uint64_t kLeftSemiring
Definition: weight.h:136
constexpr float kDelta
Definition: weight.h:133
bool ApproxEqual(const ErrorWeight &, const ErrorWeight &, float)
Definition: error-weight.h:58
size_t Size() const
Definition: union-weight.h:155