FST  openfst-1.7.1 OpenFst Library
union-weight.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Union weight set and associated semiring operation definitions.
5 //
6 // TODO(riley): add in normalizer functor
7
8 #ifndef FST_UNION_WEIGHT_H_
9 #define FST_UNION_WEIGHT_H_
10
11 #include <cstdlib>
12
13 #include <iostream>
14 #include <list>
15 #include <sstream>
16 #include <string>
17 #include <utility>
18
19 #include <fst/weight.h>
20
21
22 namespace fst {
23
24 // Example UnionWeightOptions for UnionWeight template below. The Merge
25 // operation is used to collapse elements of the set and the Compare function
26 // to efficiently implement the merge. In the simplest case, merge would just
27 // apply with equality of set elements so the result is a set (and not a
28 // multiset). More generally, this can be used to maintain the multiplicity or
29 // other such weight associated with the set elements (cf. Gallic weights).
30
31 // template <class W>
32 // struct UnionWeightOptions {
33 // // Comparison function C is a total order on W that is monotonic w.r.t.
34 // // Times: for all a, b, c != Zero(): C(a, b) => C(ca, cb) and is
35 // // anti-monotonic w.r.t. Divide: C(a, b) => C(c/b, c/a).
36 // //
37 // // For all a, b: exactly one of C(a, b), C(b, a) or a ~ b must be true, where
38 // // ~ is an equivalence relation on W. Also we require a ~ b iff
39 // // a.Reverse() ~ b.Reverse().
40 // using Compare = NaturalLess<W>;
41 //
42 // // How to combine two weights if a ~ b as above. For all a, b: a ~ b =>
43 // // merge(a, b) ~ a. Merge must define a semiring endomorphism from the
44 // // unmerged weight sets to the merged weight sets.
45 // struct Merge {
46 // W operator()(const W &w1, const W &w2) const { return w1; }
47 // };
48 //
49 // // For ReverseWeight.
50 // using ReverseOptions = UnionWeightOptions<ReverseWeight>;
51 // };
52
53 template <class W, class O>
55
56 template <class W, class O>
58
59 template <class W, class O>
61
62 template <class W, class O>
63 bool operator==(const UnionWeight<W, O> &, const UnionWeight<W, O> &);
64
65 // Semiring that uses Times() and One() from W and union and the empty set
66 // for Plus() and Zero(), respectively. Template argument O specifies the union
67 // weight options as above.
68 template <class W, class O>
69 class UnionWeight {
70  public:
71  using Weight = W;
72  using Compare = typename O::Compare;
73  using Merge = typename O::Merge;
74
75  using ReverseWeight =
77
78  friend class UnionWeightIterator<W, O>;
79  friend class UnionWeightReverseIterator<W, O>;
80  friend bool operator==
81  <>(const UnionWeight<W, O> &, const UnionWeight<W, O> &);
82
83  // Sets represented as first_ weight + rest_ weights. Uses first_ as
84  // NoWeight() to indicate the union weight Zero() as the empty set. Uses
85  // rest_ containing NoWeight() to indicate the union weight NoWeight().
86  UnionWeight() : first_(W::NoWeight()) {}
87
88  explicit UnionWeight(W weight) : first_(weight) {
89  if (weight == W::NoWeight()) rest_.push_back(weight);
90  }
91
92  static const UnionWeight<W, O> &Zero() {
93  static const UnionWeight<W, O> zero(W::NoWeight());
94  return zero;
95  }
96
97  static const UnionWeight<W, O> &One() {
98  static const UnionWeight<W, O> one(W::One());
99  return one;
100  }
101
102  static const UnionWeight<W, O> &NoWeight() {
103  static const UnionWeight<W, O> no_weight(W::Zero(), W::NoWeight());
104  return no_weight;
105  }
106
107  static const string &Type() {
108  static const string *const type = new string(W::Type() + "_union");
109  return *type;
110  }
111
112  static constexpr uint64 Properties() {
113  return W::Properties() &
115  }
116
117  bool Member() const;
118
120
121  std::ostream &Write(std::ostream &strm) const;
122
123  size_t Hash() const;
124
125  UnionWeight<W, O> Quantize(float delta = kDelta) const;
126
127  ReverseWeight Reverse() const;
128
129  // These operations combined with the UnionWeightIterator and
130  // UnionWeightReverseIterator provide the access and mutation of the union
131  // weight internal elements.
132
133  // Common initializer among constructors; clears existing UnionWeight.
134  void Clear() {
135  first_ = W::NoWeight();
136  rest_.clear();
137  }
138
139  size_t Size() const { return first_.Member() ? rest_.size() + 1 : 0; }
140
141  const W &Back() const { return rest_.empty() ? first_ : rest_.back(); }
142
143  // When srt is true, assumes elements are added sorted w.r.t. Compare and
144  // merges equivalent weights as needed. Otherwise, just ensures first_ is
145  // the least element w.r.t. Compare.
146  void PushBack(W weight, bool srt);
147
148  // Sorts the elements of the set. Assumes that first_, if present, is the
149  // least element.
150  void Sort() { rest_.sort(comp_); }
151
152  private:
153  W &Back() {
154  if (rest_.empty()) {
155  return first_;
156  } else {
157  return rest_.back();
158  }
159  }
160
161  UnionWeight(W w1, W w2) : first_(std::move(w1)), rest_(1, std::move(w2)) {}
162
163  W first_; // First weight in set.
164  std::list<W> rest_; // Remaining weights in set.
165  Compare comp_;
166  Merge merge_;
167 };
168
169 template <class W, class O>
170 void UnionWeight<W, O>::PushBack(W weight, bool srt) {
171  if (!weight.Member()) {
172  rest_.push_back(std::move(weight));
173  } else if (!first_.Member()) {
174  first_ = std::move(weight);
175  } else if (srt) {
176  auto &back = Back();
177  if (comp_(back, weight)) {
178  rest_.push_back(std::move(weight));
179  } else {
180  back = merge_(back, std::move(weight));
181  }
182  } else {
183  if (comp_(first_, weight)) {
184  rest_.push_back(std::move(weight));
185  } else {
186  rest_.push_back(first_);
187  first_ = std::move(weight);
188  }
189  }
190 }
191
192 // Traverses union weight in the forward direction.
193 template <class W, class O>
194 class UnionWeightIterator {
195  public:
196  explicit UnionWeightIterator(const UnionWeight<W, O> &weight)
197  : first_(weight.first_),
198  rest_(weight.rest_),
199  init_(true),
200  it_(rest_.begin()) {}
201
202  bool Done() const { return init_ ? !first_.Member() : it_ == rest_.end(); }
203
204  const W &Value() const { return init_ ? first_ : *it_; }
205
206  void Next() {
207  if (init_) {
208  init_ = false;
209  } else {
210  ++it_;
211  }
212  }
213
214  void Reset() {
215  init_ = true;
216  it_ = rest_.begin();
217  }
218
219  private:
220  const W &first_;
221  const std::list<W> &rest_;
222  bool init_; // in the initialized state?
223  typename std::list<W>::const_iterator it_;
224 };
225
226 // Traverses union weight in backward direction.
227 template <typename L, class O>
229  public:
231  : first_(weight.first_),
232  rest_(weight.rest_),
233  fin_(!first_.Member()),
234  it_(rest_.rbegin()) {}
235
236  bool Done() const { return fin_; }
237
238  const L &Value() const { return it_ == rest_.rend() ? first_ : *it_; }
239
240  void Next() {
241  if (it_ == rest_.rend()) {
242  fin_ = true;
243  } else {
244  ++it_;
245  }
246  }
247
248  void Reset() {
249  fin_ = !first_.Member();
250  it_ = rest_.rbegin();
251  }
252
253  private:
254  const L &first_;
255  const std::list<L> &rest_;
256  bool fin_; // in the final state?
257  typename std::list<L>::const_reverse_iterator it_;
258 };
259
260 // UnionWeight member functions follow that require UnionWeightIterator.
261 template <class W, class O>
262 inline std::istream &UnionWeight<W, O>::Read(std::istream &istrm) {
263  Clear();
264  int32 size;
266  for (int i = 0; i < size; ++i) {
267  W weight;
269  PushBack(weight, true);
270  }
271  return istrm;
272 }
273
274 template <class W, class O>
275 inline std::ostream &UnionWeight<W, O>::Write(std::ostream &ostrm) const {
276  const int32 size = Size();
277  WriteType(ostrm, size);
278  for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
279  WriteType(ostrm, it.Value());
280  }
281  return ostrm;
282 }
283
284 template <class W, class O>
285 inline bool UnionWeight<W, O>::Member() const {
286  if (Size() <= 1) return true;
287  for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
288  if (!it.Value().Member()) return false;
289  }
290  return true;
291 }
292
293 template <class W, class O>
295  UnionWeight<W, O> weight;
296  for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
297  weight.PushBack(it.Value().Quantize(delta), true);
298  }
299  return weight;
300 }
301
302 template <class W, class O>
304  const {
305  ReverseWeight weight;
306  for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
307  weight.PushBack(it.Value().Reverse(), false);
308  }
309  weight.Sort();
310  return weight;
311 }
312
313 template <class W, class O>
314 inline size_t UnionWeight<W, O>::Hash() const {
315  size_t h = 0;
316  static constexpr int lshift = 5;
317  static constexpr int rshift = CHAR_BIT * sizeof(size_t) - lshift;
318  for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
319  h = h << lshift ^ h >> rshift ^ it.Value().Hash();
320  }
321  return h;
322 }
323
324 // Requires that both union weights have been canonicalized.
325 template <class W, class O>
326 inline bool operator==(const UnionWeight<W, O> &w1,
327  const UnionWeight<W, O> &w2) {
328  if (w1.Size() != w2.Size()) return false;
331  for (; !it1.Done(); it1.Next(), it2.Next()) {
332  if (it1.Value() != it2.Value()) return false;
333  }
334  return true;
335 }
336
337 // Requires union weight has been canonicalized.
338 template <class W, class O>
339 inline bool operator!=(const UnionWeight<W, O> &w1,
340  const UnionWeight<W, O> &w2) {
341  return !(w1 == w2);
342 }
343
344 // Requires that both union weights have been canonicalized.
345 template <class W, class O>
346 inline bool ApproxEqual(const UnionWeight<W, O> &w1,
347  const UnionWeight<W, O> &w2, float delta = kDelta) {
348  if (w1.Size() != w2.Size()) return false;
351  for (; !it1.Done(); it1.Next(), it2.Next()) {
352  if (!ApproxEqual(it1.Value(), it2.Value(), delta)) return false;
353  }
354  return true;
355 }
356
357 template <class W, class O>
358 inline std::ostream &operator<<(std::ostream &ostrm,
359  const UnionWeight<W, O> &weight) {
360  UnionWeightIterator<W, O> it(weight);
361  if (it.Done()) {
362  return ostrm << "EmptySet";
363  } else if (!weight.Member()) {
365  } else {
366  CompositeWeightWriter writer(ostrm);
367  writer.WriteBegin();
368  for (; !it.Done(); it.Next()) writer.WriteElement(it.Value());
369  writer.WriteEnd();
370  }
371  return ostrm;
372 }
373
374 template <class W, class O>
375 inline std::istream &operator>>(std::istream &istrm,
376  UnionWeight<W, O> &weight) {
377  string s;
378  istrm >> s;
379  if (s == "EmptySet") {
380  weight = UnionWeight<W, O>::Zero();
381  } else if (s == "BadSet") {
382  weight = UnionWeight<W, O>::NoWeight();
383  } else {
384  weight = UnionWeight<W, O>::Zero();
385  std::istringstream sstrm(s);
388  bool more = true;
389  while (more) {
390  W v;
392  weight.PushBack(v, true);
393  }
395  }
396  return istrm;
397 }
398
399 template <class W, class O>
401  const UnionWeight<W, O> &w2) {
402  if (!w1.Member() || !w2.Member()) return UnionWeight<W, O>::NoWeight();
403  if (w1 == UnionWeight<W, O>::Zero()) return w2;
404  if (w2 == UnionWeight<W, O>::Zero()) return w1;
407  UnionWeight<W, O> sum;
408  typename O::Compare comp;
409  while (!it1.Done() && !it2.Done()) {
410  const auto v1 = it1.Value();
411  const auto v2 = it2.Value();
412  if (comp(v1, v2)) {
413  sum.PushBack(v1, true);
414  it1.Next();
415  } else {
416  sum.PushBack(v2, true);
417  it2.Next();
418  }
419  }
420  for (; !it1.Done(); it1.Next()) sum.PushBack(it1.Value(), true);
421  for (; !it2.Done(); it2.Next()) sum.PushBack(it2.Value(), true);
422  return sum;
423 }
424
425 template <class W, class O>
427  const UnionWeight<W, O> &w2) {
428  if (!w1.Member() || !w2.Member()) return UnionWeight<W, O>::NoWeight();
429  if (w1 == UnionWeight<W, O>::Zero() || w2 == UnionWeight<W, O>::Zero()) {
430  return UnionWeight<W, O>::Zero();
431  }
434  UnionWeight<W, O> prod1;
435  for (; !it1.Done(); it1.Next()) {
436  UnionWeight<W, O> prod2;
437  for (; !it2.Done(); it2.Next()) {
438  prod2.PushBack(Times(it1.Value(), it2.Value()), true);
439  }
440  prod1 = Plus(prod1, prod2);
441  it2.Reset();
442  }
443  return prod1;
444 }
445
446 template <class W, class O>
448  const UnionWeight<W, O> &w2, DivideType typ) {
449  if (!w1.Member() || !w2.Member()) return UnionWeight<W, O>::NoWeight();
450  if (w1 == UnionWeight<W, O>::Zero() || w2 == UnionWeight<W, O>::Zero()) {
451  return UnionWeight<W, O>::Zero();
452  }
455  UnionWeight<W, O> quot;
456  if (w1.Size() == 1) {
457  for (; !it2.Done(); it2.Next()) {
458  quot.PushBack(Divide(it1.Value(), it2.Value(), typ), true);
459  }
460  } else if (w2.Size() == 1) {
461  for (; !it1.Done(); it1.Next()) {
462  quot.PushBack(Divide(it1.Value(), it2.Value(), typ), true);
463  }
464  } else {
466  }
467  return quot;
468 }
469
470 // This function object generates weights over the union of weights for the
471 // underlying generators for the template weight types. This is intended
472 // primarily for testing.
473 template <class W, class O>
475  public:
478
479  explicit WeightGenerate(bool allow_zero = true,
480  size_t num_random_weights = kNumRandomWeights)
481  : generate_(false), allow_zero_(allow_zero),
482  num_random_weights_(num_random_weights) {}
483
484  Weight operator()() const {
485  const int n = rand() % (num_random_weights_ + 1); // NOLINT
486  if (allow_zero_ && n == num_random_weights_) {
487  return Weight::Zero();
488  } else if (n % 2 == 0) {
489  return Weight(generate_());
490  } else {
491  return Plus(Weight(generate_()), Weight(generate_()));
492  }
493  }
494
495  private:
496  Generate generate_;
497  // Permits Zero() and zero divisors.
498  bool allow_zero_;
499  // The number of alternative random weights.
500  const size_t num_random_weights_;
501 };
502
503 } // namespace fst
504
505 #endif // FST_UNION_WEIGHT_H_
ExpectationWeight< X1, X2 > Divide(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2, DivideType typ=DIVIDE_ANY)
uint64_t uint64
Definition: types.h:32
UnionWeight(W weight)
Definition: union-weight.h:88
constexpr uint64 kRightSemiring
Definition: weight.h:115
static constexpr uint64 Properties()
Definition: union-weight.h:112
const W & Value() const
Definition: union-weight.h:204
size_t Hash() const
Definition: union-weight.h:314
static const string & Type()
Definition: union-weight.h:107
ExpectationWeight< X1, X2 > Times(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
std::ostream & Write(std::ostream &strm) const
Definition: union-weight.h:275
constexpr uint64 kCommutative
Definition: weight.h:120
Definition: union-weight.h:262
constexpr uint64 kLeftSemiring
Definition: weight.h:112
static const UnionWeight< W, O > & Zero()
Definition: union-weight.h:92
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:155
ReverseWeight Reverse() const
Definition: union-weight.h:303
bool Member() const
Definition: union-weight.h:285
void PushBack(W weight, bool srt)
Definition: union-weight.h:170
std::istream & operator>>(std::istream &strm, FloatWeightTpl< T > &w)
Definition: float-weight.h:160
bool operator==(const PdtStateTuple< S, K > &x, const PdtStateTuple< S, K > &y)
Definition: pdt.h:133
constexpr uint64 kIdempotent
Definition: weight.h:123
ExpectationWeight< X1, X2 > Plus(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
std::ostream & operator<<(std::ostream &strm, const FloatWeightTpl< T > &w)
Definition: float-weight.h:146
WeightGenerate(bool allow_zero=true, size_t num_random_weights=kNumRandomWeights)
Definition: union-weight.h:479
constexpr bool operator!=(const FloatWeightTpl< T > &w1, const FloatWeightTpl< T > &w2)
Definition: float-weight.h:119
static const UnionWeight< W, O > & NoWeight()
Definition: union-weight.h:102
UnionWeightReverseIterator(const UnionWeight< L, O > &weight)
Definition: union-weight.h:230
friend bool operator==(const UnionWeight< W, O > &, const UnionWeight< W, O > &)
Definition: union-weight.h:326
UnionWeightIterator(const UnionWeight< W, O > &weight)
Definition: union-weight.h:196
Definition: weight.h:349
int32_t int32
Definition: types.h:26
constexpr bool ApproxEqual(const FloatWeightTpl< T > &w1, const FloatWeightTpl< T > &w2, float delta=kDelta)
Definition: float-weight.h:140
UnionWeight< W, O > Quantize(float delta=kDelta) const
Definition: union-weight.h:294
constexpr size_t kNumRandomWeights
Definition: weight.h:130
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:47
void WriteElement(const T &comp)
Definition: weight.h:299
static const UnionWeight< W, O > & One()
Definition: union-weight.h:97
DivideType
Definition: weight.h:142
const W & Back() const
Definition: union-weight.h:141
constexpr float kDelta
Definition: weight.h:109
size_t Size() const
Definition: union-weight.h:139