FST  openfst-1.8.2
OpenFst Library
set-weight.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Weights consisting of sets (of integral Labels) and
19 // associated semiring operation definitions using intersect
20 // and union.
21 
22 #ifndef FST_SET_WEIGHT_H_
23 #define FST_SET_WEIGHT_H_
24 
25 #include <algorithm>
26 #include <cstdint>
27 #include <list>
28 #include <random>
29 #include <string>
30 #include <vector>
31 
32 #include <fst/union-weight.h>
33 #include <fst/weight.h>
34 #include <string_view>
35 
36 
37 namespace fst {
38 
39 inline constexpr int kSetEmpty = 0; // Label for the empty set.
40 inline constexpr int kSetUniv = -1; // Label for the universal set.
41 inline constexpr int kSetBad = -2; // Label for a non-set.
42 inline constexpr char kSetSeparator = '_'; // Label separator in sets.
43 
44 // Determines whether to use (intersect, union) or (union, intersect)
45 // as (+, *) for the semiring. SET_INTERSECT_UNION_RESTRICTED is a
46 // restricted version of (intersect, union) that requires summed
47 // arguments to be equal (or an error is signalled), useful for
48 // algorithms that require a unique labelled path weight. SET_BOOLEAN
49 // treats all non-Zero() elements as equivalent (with Zero() ==
50 // UnivSet()), useful for algorithms that don't really depend on the
51 // detailed sets.
52 enum SetType {
57 };
58 
59 template <class>
61 
62 // Set semiring of integral labels.
63 template <typename L, SetType S = SET_INTERSECT_UNION>
64 class SetWeight {
65  public:
66  using Label = L;
70  // Allow type-converting copy and move constructors private access.
71  template <typename L2, SetType S2>
72  friend class SetWeight;
73 
74  SetWeight() {}
75 
76  // Input should be positive, sorted and unique.
77  template <typename Iterator>
78  SetWeight(const Iterator begin, const Iterator end) {
79  for (auto iter = begin; iter != end; ++iter) PushBack(*iter);
80  }
81 
82  // Input should be positive. (Non-positive value has
83  // special internal meaning w.r.t. integral constants above.)
84  explicit SetWeight(Label label) { PushBack(label); }
85 
86  template <SetType S2>
87  explicit SetWeight(const SetWeight<Label, S2> &w)
88  : first_(w.first_), rest_(w.rest_) {}
89 
90  template <SetType S2>
92  : first_(w.first_), rest_(std::move(w.rest_)) {
93  w.Clear();
94  }
95 
96  template <SetType S2>
98  first_ = w.first_;
99  rest_ = w.rest_;
100  return *this;
101  }
102 
103  template <SetType S2>
105  first_ = w.first_;
106  rest_ = std::move(w.rest_);
107  w.Clear();
108  return *this;
109  }
110 
111  static const SetWeight &Zero() {
112  return S == SET_UNION_INTERSECT ? EmptySet() : UnivSet();
113  }
114 
115  static const SetWeight &One() {
116  return S == SET_UNION_INTERSECT ? UnivSet() : EmptySet();
117  }
118 
119  static const SetWeight &NoWeight() {
120  static const auto *const no_weight = new SetWeight(Label(kSetBad));
121  return *no_weight;
122  }
123 
124  static const std::string &Type() {
125  static const std::string *const type =
126  new std::string(S == SET_UNION_INTERSECT
127  ? "union_intersect_set"
128  : (S == SET_INTERSECT_UNION
129  ? "intersect_union_set"
131  ? "restricted_set_intersect_union"
132  : "boolean_set")));
133  return *type;
134  }
135 
136  bool Member() const;
137 
138  std::istream &Read(std::istream &strm);
139 
140  std::ostream &Write(std::ostream &strm) const;
141 
142  size_t Hash() const;
143 
144  SetWeight Quantize(float delta = kDelta) const { return *this; }
145 
146  ReverseWeight Reverse() const;
147 
148  static constexpr uint64_t Properties() {
150  }
151 
152  // These operations combined with the SetWeightIterator
153  // provide the access and mutation of the set internal elements.
154 
155  // The empty set.
156  static const SetWeight &EmptySet() {
157  static const auto *const empty = new SetWeight(Label(kSetEmpty));
158  return *empty;
159  }
160 
161  // The universal set.
162  static const SetWeight &UnivSet() {
163  static const auto *const univ = new SetWeight(Label(kSetUniv));
164  return *univ;
165  }
166 
167  // Clear existing SetWeight.
168  void Clear() {
169  first_ = kSetEmpty;
170  rest_.clear();
171  }
172 
173  size_t Size() const { return first_ == kSetEmpty ? 0 : rest_.size() + 1; }
174 
176  if (rest_.empty()) {
177  return first_;
178  } else {
179  return rest_.back();
180  }
181  }
182 
183  // Caller must add in sort order and be unique (or error signalled).
184  // Input should also be positive. Non-positive value (for the first
185  // push) has special internal meaning w.r.t. integral constants above.
186  void PushBack(Label label) {
187  if (first_ == kSetEmpty) {
188  first_ = label;
189  } else {
190  if (label <= Back() || label <= 0) {
191  FSTERROR() << "SetWeight: labels must be positive, added"
192  << " in sort order and be unique.";
193  rest_.push_back(Label(kSetBad));
194  }
195  rest_.push_back(label);
196  }
197  }
198 
199  private:
200  Label first_ = kSetEmpty; // First label in set (kSetEmpty if empty).
201  std::list<Label> rest_; // Remaining labels in set.
202 };
203 
204 // Traverses set in forward direction.
205 template <class SetWeight_>
206 class SetWeightIterator {
207  public:
208  using Weight = SetWeight_;
209  using Label = typename Weight::Label;
210 
211  explicit SetWeightIterator(const Weight &w)
212  : first_(w.first_), rest_(w.rest_), init_(true), iter_(rest_.begin()) {}
213 
214  bool Done() const {
215  if (init_) {
216  return first_ == kSetEmpty;
217  } else {
218  return iter_ == rest_.end();
219  }
220  }
221 
222  const Label &Value() const { return init_ ? first_ : *iter_; }
223 
224  void Next() {
225  if (init_) {
226  init_ = false;
227  } else {
228  ++iter_;
229  }
230  }
231 
232  void Reset() {
233  init_ = true;
234  iter_ = rest_.begin();
235  }
236 
237  private:
238  const Label &first_;
239  const decltype(Weight::rest_) &rest_;
240  bool init_; // In the initialized state?
241  typename decltype(Weight::rest_)::const_iterator iter_;
242 };
243 
244 // SetWeight member functions follow that require SetWeightIterator
245 
246 template <typename Label, SetType S>
247 inline std::istream &SetWeight<Label, S>::Read(std::istream &strm) {
248  Clear();
249  int32_t size;
250  ReadType(strm, &size);
251  for (int32_t i = 0; i < size; ++i) {
252  Label label;
253  ReadType(strm, &label);
254  PushBack(label);
255  }
256  return strm;
257 }
258 
259 template <typename Label, SetType S>
260 inline std::ostream &SetWeight<Label, S>::Write(std::ostream &strm) const {
261  const int32_t size = Size();
262  WriteType(strm, size);
263  for (Iterator iter(*this); !iter.Done(); iter.Next()) {
264  WriteType(strm, iter.Value());
265  }
266  return strm;
267 }
268 
269 template <typename Label, SetType S>
270 inline bool SetWeight<Label, S>::Member() const {
271  Iterator iter(*this);
272  return iter.Value() != Label(kSetBad);
273 }
274 
275 template <typename Label, SetType S>
278  return *this;
279 }
280 
281 template <typename Label, SetType S>
282 inline size_t SetWeight<Label, S>::Hash() const {
283  using Weight = SetWeight<Label, S>;
284  if (S == SET_BOOLEAN) {
285  return *this == Weight::Zero() ? 0 : 1;
286  } else {
287  size_t h = 0;
288  for (Iterator iter(*this); !iter.Done(); iter.Next()) {
289  h ^= h << 1 ^ iter.Value();
290  }
291  return h;
292  }
293 }
294 
295 // Default ==
296 template <typename Label, SetType S>
297 inline bool operator==(const SetWeight<Label, S> &w1,
298  const SetWeight<Label, S> &w2) {
299  if (w1.Size() != w2.Size()) return false;
300  using Iterator = typename SetWeight<Label, S>::Iterator;
301  Iterator iter1(w1);
302  Iterator iter2(w2);
303  for (; !iter1.Done(); iter1.Next(), iter2.Next()) {
304  if (iter1.Value() != iter2.Value()) return false;
305  }
306  return true;
307 }
308 
309 // Boolean ==
310 template <typename Label>
312  const SetWeight<Label, SET_BOOLEAN> &w2) {
313  // x is treated as kSetEmpty if x is not in {kSetUniv, kSetBad}.
314  if (!w1.Member() || !w2.Member()) return false;
316  Iterator iter1(w1);
317  Iterator iter2(w2);
318  Label label1 = iter1.Done() ? kSetEmpty : iter1.Value();
319  Label label2 = iter2.Done() ? kSetEmpty : iter2.Value();
320  if (label1 == kSetUniv) return label2 == kSetUniv;
321  if (label2 == kSetUniv) return label1 == kSetUniv;
322  return true;
323 }
324 
325 template <typename Label, SetType S>
326 inline bool operator!=(const SetWeight<Label, S> &w1,
327  const SetWeight<Label, S> &w2) {
328  return !(w1 == w2);
329 }
330 
331 template <typename Label, SetType S>
332 inline bool ApproxEqual(const SetWeight<Label, S> &w1,
333  const SetWeight<Label, S> &w2, float delta = kDelta) {
334  return w1 == w2;
335 }
336 
337 template <typename Label, SetType S>
338 inline std::ostream &operator<<(std::ostream &strm,
339  const SetWeight<Label, S> &weight) {
340  typename SetWeight<Label, S>::Iterator iter(weight);
341  if (iter.Done()) {
342  return strm << "EmptySet";
343  } else if (iter.Value() == Label(kSetUniv)) {
344  return strm << "UnivSet";
345  } else if (iter.Value() == Label(kSetBad)) {
346  return strm << "BadSet";
347  } else {
348  for (size_t i = 0; !iter.Done(); ++i, iter.Next()) {
349  if (i > 0) strm << kSetSeparator;
350  strm << iter.Value();
351  }
352  }
353  return strm;
354 }
355 
356 template <typename Label, SetType S>
357 inline std::istream &operator>>(std::istream &strm,
358  SetWeight<Label, S> &weight) {
359  std::string str;
360  strm >> str;
361  using Weight = SetWeight<Label, S>;
362  if (str == "EmptySet") {
363  weight = Weight(Label(kSetEmpty));
364  } else if (str == "UnivSet") {
365  weight = Weight(Label(kSetUniv));
366  } else {
367  weight.Clear();
368  for (std::string_view sv : StrSplit(str, kSetSeparator)) {
369  auto maybe_label = ParseInt64(sv);
370  if (!maybe_label.has_value()) {
371  strm.clear(std::ios::badbit);
372  break;
373  }
374  weight.PushBack(*maybe_label);
375  }
376  }
377  return strm;
378 }
379 
380 template <typename Label, SetType S>
382  const SetWeight<Label, S> &w2) {
383  using Weight = SetWeight<Label, S>;
384  using Iterator = typename SetWeight<Label, S>::Iterator;
385  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
386  if (w1 == Weight::EmptySet()) return w2;
387  if (w2 == Weight::EmptySet()) return w1;
388  if (w1 == Weight::UnivSet()) return w1;
389  if (w2 == Weight::UnivSet()) return w2;
390  Iterator it1(w1);
391  Iterator it2(w2);
392  Weight result;
393  while (!it1.Done() && !it2.Done()) {
394  const auto v1 = it1.Value();
395  const auto v2 = it2.Value();
396  if (v1 < v2) {
397  result.PushBack(v1);
398  it1.Next();
399  } else if (v1 > v2) {
400  result.PushBack(v2);
401  it2.Next();
402  } else {
403  result.PushBack(v1);
404  it1.Next();
405  it2.Next();
406  }
407  }
408  for (; !it1.Done(); it1.Next()) result.PushBack(it1.Value());
409  for (; !it2.Done(); it2.Next()) result.PushBack(it2.Value());
410  return result;
411 }
412 
413 template <typename Label, SetType S>
415  const SetWeight<Label, S> &w2) {
416  using Weight = SetWeight<Label, S>;
417  using Iterator = typename SetWeight<Label, S>::Iterator;
418  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
419  if (w1 == Weight::EmptySet()) return w1;
420  if (w2 == Weight::EmptySet()) return w2;
421  if (w1 == Weight::UnivSet()) return w2;
422  if (w2 == Weight::UnivSet()) return w1;
423  Iterator it1(w1);
424  Iterator it2(w2);
425  Weight result;
426  while (!it1.Done() && !it2.Done()) {
427  const auto v1 = it1.Value();
428  const auto v2 = it2.Value();
429  if (v1 < v2) {
430  it1.Next();
431  } else if (v1 > v2) {
432  it2.Next();
433  } else {
434  result.PushBack(v1);
435  it1.Next();
436  it2.Next();
437  }
438  }
439  return result;
440 }
441 
442 template <typename Label, SetType S>
444  const SetWeight<Label, S> &w2) {
445  using Weight = SetWeight<Label, S>;
446  using Iterator = typename SetWeight<Label, S>::Iterator;
447  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
448  if (w1 == Weight::EmptySet()) return w1;
449  if (w2 == Weight::EmptySet()) return w1;
450  if (w2 == Weight::UnivSet()) return Weight::EmptySet();
451  Iterator it1(w1);
452  Iterator it2(w2);
453  Weight result;
454  while (!it1.Done() && !it2.Done()) {
455  const auto v1 = it1.Value();
456  const auto v2 = it2.Value();
457  if (v1 < v2) {
458  result.PushBack(v1);
459  it1.Next();
460  } else if (v1 > v2) {
461  it2.Next();
462  } else {
463  it1.Next();
464  it2.Next();
465  }
466  }
467  for (; !it1.Done(); it1.Next()) result.PushBack(it1.Value());
468  return result;
469 }
470 
471 // Default: Plus = Intersect.
472 template <typename Label, SetType S>
474  const SetWeight<Label, S> &w2) {
475  return Intersect(w1, w2);
476 }
477 
478 // Plus = Union.
479 template <typename Label>
483  return Union(w1, w2);
484 }
485 
486 // Plus = Set equality is required (for non-Zero() input). The
487 // restriction is useful (e.g., in determinization) to ensure the input
488 // has a unique labelled path weight.
489 template <typename Label>
494  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
495  if (w1 == Weight::Zero()) return w2;
496  if (w2 == Weight::Zero()) return w1;
497  if (w1 != w2) {
498  FSTERROR() << "SetWeight::Plus: Unequal arguments "
499  << "(non-unique labelled path weights?)"
500  << " w1 = " << w1 << " w2 = " << w2;
501  return Weight::NoWeight();
502  }
503  return w1;
504 }
505 
506 // Plus = Or.
507 template <typename Label>
510  const SetWeight<Label, SET_BOOLEAN> &w2) {
511  using Weight = SetWeight<Label, SET_BOOLEAN>;
512  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
513  if (w1 == Weight::One()) return w1;
514  if (w2 == Weight::One()) return w2;
515  return Weight::Zero();
516 }
517 
518 // Default: Times = Union.
519 template <typename Label, SetType S>
521  const SetWeight<Label, S> &w2) {
522  return Union(w1, w2);
523 }
524 
525 // Times = Intersect.
526 template <typename Label>
530  return Intersect(w1, w2);
531 }
532 
533 // Times = And.
534 template <typename Label>
537  const SetWeight<Label, SET_BOOLEAN> &w2) {
538  using Weight = SetWeight<Label, SET_BOOLEAN>;
539  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
540  if (w1 == Weight::One()) return w2;
541  return w1;
542 }
543 
544 // Divide = Difference.
545 template <typename Label, SetType S>
547  const SetWeight<Label, S> &w2,
548  DivideType divide_type = DIVIDE_ANY) {
549  return Difference(w1, w2);
550 }
551 
552 // Divide = dividend (or the universal set if the
553 // dividend == divisor).
554 template <typename Label>
558  DivideType divide_type = DIVIDE_ANY) {
560  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
561  if (w1 == w2) return Weight::UnivSet();
562  return w1;
563 }
564 
565 // Divide = Or Not.
566 template <typename Label>
570  DivideType divide_type = DIVIDE_ANY) {
571  using Weight = SetWeight<Label, SET_BOOLEAN>;
572  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
573  if (w1 == Weight::One()) return w1;
574  if (w2 == Weight::Zero()) return Weight::One();
575  return Weight::Zero();
576 }
577 
578 // Converts between different set types.
579 template <typename Label, SetType S1, SetType S2>
584  for (Iterator iter(w1); !iter.Done(); iter.Next())
585  w2.PushBack(iter.Value());
586  return w2;
587  }
588 };
589 
590 // This function object generates SetWeights that are random integer sets
591 // from {1, ... , alphabet_size}^{0, max_set_length} U { Zero }. This is
592 // intended primarily for testing.
593 template <class Label, SetType S>
595  public:
597 
598  explicit WeightGenerate(uint64_t seed = std::random_device()(),
599  bool allow_zero = true,
600  size_t alphabet_size = kNumRandomWeights,
601  size_t max_set_length = kNumRandomWeights)
602  : allow_zero_(allow_zero),
603  alphabet_size_(alphabet_size),
604  max_set_length_(max_set_length) {}
605 
606  Weight operator()() const {
607  const int n = std::uniform_int_distribution<>(
608  0, max_set_length_ + allow_zero_ - 1)(rand_);
609  if (allow_zero_ && n == max_set_length_) return Weight::Zero();
610  std::vector<Label> labels;
611  labels.reserve(n);
612  for (int i = 0; i < n; ++i) {
613  labels.push_back(
614  std::uniform_int_distribution<>(0, alphabet_size_)(rand_));
615  }
616  std::sort(labels.begin(), labels.end());
617  const auto labels_end = std::unique(labels.begin(), labels.end());
618  labels.resize(labels_end - labels.begin());
619  return Weight(labels.begin(), labels.end());
620  }
621 
622  private:
623  mutable std::mt19937_64 rand_;
624  const bool allow_zero_;
625  const size_t alphabet_size_;
626  const size_t max_set_length_;
627 };
628 
629 } // namespace fst
630 
631 #endif // FST_SET_WEIGHT_H_
static const std::string & Type()
Definition: set-weight.h:124
std::ostream & Write(std::ostream &strm) const
Definition: set-weight.h:260
static const SetWeight & EmptySet()
Definition: set-weight.h:156
constexpr char kSetSeparator
Definition: set-weight.h:42
static constexpr uint64_t Properties()
Definition: set-weight.h:148
SetWeight & operator=(const SetWeight< Label, S2 > &w)
Definition: set-weight.h:97
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:60
ReverseWeight Reverse() const
Definition: set-weight.h:277
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t alphabet_size=kNumRandomWeights, size_t max_set_length=kNumRandomWeights)
Definition: set-weight.h:598
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:63
SetWeight & operator=(SetWeight< Label, S2 > &&w)
Definition: set-weight.h:104
SetType
Definition: set-weight.h:52
void Intersect(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const IntersectOptions &opts=IntersectOptions())
Definition: intersect.h:141
constexpr uint64_t kIdempotent
Definition: weight.h:144
constexpr int kSetUniv
Definition: set-weight.h:40
static const SetWeight & NoWeight()
Definition: set-weight.h:119
internal::StringSplitter StrSplit(std::string_view full, ByAnyChar delim)
Definition: compat.cc:81
SetWeight(Label label)
Definition: set-weight.h:84
std::istream & Read(std::istream &strm)
Definition: set-weight.h:247
constexpr uint64_t kRightSemiring
Definition: weight.h:136
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:211
void Difference(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const DifferenceOptions &opts=DifferenceOptions())
Definition: difference.h:165
const Label & Value() const
Definition: set-weight.h:222
#define FSTERROR()
Definition: util.h:53
std::optional< int64_t > ParseInt64(std::string_view s, int base=10)
Definition: util.cc:42
typename Weight::Label Label
Definition: set-weight.h:209
Label Back()
Definition: set-weight.h:175
size_t Hash() const
Definition: set-weight.h:282
bool Member() const
Definition: set-weight.h:270
std::istream & operator>>(std::istream &strm, FloatWeightTpl< T > &w)
Definition: float-weight.h:181
SetWeightIterator< SetWeight > Iterator
Definition: set-weight.h:68
void Union(RationalFst< Arc > *fst1, const Fst< Arc > &fst2)
Definition: union.h:110
bool operator!=(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:53
SetWeightIterator(const Weight &w)
Definition: set-weight.h:211
SetWeight(const Iterator begin, const Iterator end)
Definition: set-weight.h:78
constexpr uint64_t kCommutative
Definition: weight.h:141
std::ostream & operator<<(std::ostream &strm, const ErrorWeight &)
Definition: error-weight.h:70
SetWeight(SetWeight< Label, S2 > &&w)
Definition: set-weight.h:91
static const SetWeight & UnivSet()
Definition: set-weight.h:162
SetWeight< Label, S2 > operator()(const SetWeight< Label, S1 > &w1) const
Definition: set-weight.h:581
static const SetWeight & Zero()
Definition: set-weight.h:111
void PushBack(Label label)
Definition: set-weight.h:186
constexpr int kSetBad
Definition: set-weight.h:41
ErrorWeight Divide(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:66
bool operator==(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:50
constexpr size_t kNumRandomWeights
Definition: weight.h:151
static const SetWeight & One()
Definition: set-weight.h:115
SetWeight Quantize(float delta=kDelta) const
Definition: set-weight.h:144
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:65
size_t Size() const
Definition: set-weight.h:173
DivideType
Definition: weight.h:162
constexpr uint64_t kLeftSemiring
Definition: weight.h:133
constexpr float kDelta
Definition: weight.h:130
bool Done() const
Definition: set-weight.h:214
bool ApproxEqual(const ErrorWeight &, const ErrorWeight &, float)
Definition: error-weight.h:57
constexpr int kSetEmpty
Definition: set-weight.h:39
SetWeight(const SetWeight< Label, S2 > &w)
Definition: set-weight.h:87