FST  openfst-1.8.3
OpenFst Library
set-weight.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Weights consisting of sets (of integral Labels) and
19 // associated semiring operation definitions using intersect
20 // and union.
21 
22 #ifndef FST_SET_WEIGHT_H_
23 #define FST_SET_WEIGHT_H_
24 
25 #include <algorithm>
26 #include <cstddef>
27 #include <cstdint>
28 #include <ios>
29 #include <istream>
30 #include <list>
31 #include <optional>
32 #include <ostream>
33 #include <random>
34 #include <string>
35 #include <utility>
36 #include <vector>
37 
38 #include <fst/log.h>
39 #include <fst/union-weight.h>
40 #include <fst/util.h>
41 #include <fst/weight.h>
42 #include <string_view>
43 
44 namespace fst {
45 
46 inline constexpr int kSetEmpty = 0; // Label for the empty set.
47 inline constexpr int kSetUniv = -1; // Label for the universal set.
48 inline constexpr int kSetBad = -2; // Label for a non-set.
49 inline constexpr char kSetSeparator = '_'; // Label separator in sets.
50 
51 // Determines whether to use (intersect, union) or (union, intersect)
52 // as (+, *) for the semiring. SET_INTERSECT_UNION_RESTRICTED is a
53 // restricted version of (intersect, union) that requires summed
54 // arguments to be equal (or an error is signalled), useful for
55 // algorithms that require a unique labelled path weight. SET_BOOLEAN
56 // treats all non-Zero() elements as equivalent (with Zero() ==
57 // UnivSet()), useful for algorithms that don't really depend on the
58 // detailed sets.
59 enum SetType {
64 };
65 
66 template <class>
68 
69 // Set semiring of integral labels.
70 template <typename L, SetType S = SET_INTERSECT_UNION>
71 class SetWeight {
72  public:
73  using Label = L;
77  // Allow type-converting copy and move constructors private access.
78  template <typename L2, SetType S2>
79  friend class SetWeight;
80 
81  SetWeight() = default;
82 
83  // Input should be positive, sorted and unique.
84  template <typename Iterator>
85  SetWeight(const Iterator begin, const Iterator end) {
86  for (auto iter = begin; iter != end; ++iter) PushBack(*iter);
87  }
88 
89  // Input should be positive. (Non-positive value has
90  // special internal meaning w.r.t. integral constants above.)
91  explicit SetWeight(Label label) { PushBack(label); }
92 
93  template <SetType S2>
94  explicit SetWeight(const SetWeight<Label, S2> &w)
95  : first_(w.first_), rest_(w.rest_) {}
96 
97  template <SetType S2>
99  : first_(w.first_), rest_(std::move(w.rest_)) {
100  w.Clear();
101  }
102 
103  template <SetType S2>
105  first_ = w.first_;
106  rest_ = w.rest_;
107  return *this;
108  }
109 
110  template <SetType S2>
112  first_ = w.first_;
113  rest_ = std::move(w.rest_);
114  w.Clear();
115  return *this;
116  }
117 
118  static const SetWeight &Zero() {
119  return S == SET_UNION_INTERSECT ? EmptySet() : UnivSet();
120  }
121 
122  static const SetWeight &One() {
123  return S == SET_UNION_INTERSECT ? UnivSet() : EmptySet();
124  }
125 
126  static const SetWeight &NoWeight() {
127  static const auto *const no_weight = new SetWeight(Label(kSetBad));
128  return *no_weight;
129  }
130 
131  static const std::string &Type() {
132  static const std::string *const type =
133  new std::string(S == SET_UNION_INTERSECT
134  ? "union_intersect_set"
135  : (S == SET_INTERSECT_UNION
136  ? "intersect_union_set"
138  ? "restricted_set_intersect_union"
139  : "boolean_set")));
140  return *type;
141  }
142 
143  bool Member() const;
144 
145  std::istream &Read(std::istream &strm);
146 
147  std::ostream &Write(std::ostream &strm) const;
148 
149  size_t Hash() const;
150 
151  SetWeight Quantize(float delta = kDelta) const { return *this; }
152 
153  ReverseWeight Reverse() const;
154 
155  static constexpr uint64_t Properties() {
157  }
158 
159  // These operations combined with the SetWeightIterator
160  // provide the access and mutation of the set internal elements.
161 
162  // The empty set.
163  static const SetWeight &EmptySet() {
164  static const auto *const empty = new SetWeight(Label(kSetEmpty));
165  return *empty;
166  }
167 
168  // The univeral set.
169  static const SetWeight &UnivSet() {
170  static const auto *const univ = new SetWeight(Label(kSetUniv));
171  return *univ;
172  }
173 
174  // Clear existing SetWeight.
175  void Clear() {
176  first_ = kSetEmpty;
177  rest_.clear();
178  }
179 
180  size_t Size() const { return first_ == kSetEmpty ? 0 : rest_.size() + 1; }
181 
183  if (rest_.empty()) {
184  return first_;
185  } else {
186  return rest_.back();
187  }
188  }
189 
190  // Caller must add in sort order and be unique (or error signalled).
191  // Input should also be positive. Non-positive value (for the first
192  // push) has special internal meaning w.r.t. integral constants above.
193  void PushBack(Label label) {
194  if (first_ == kSetEmpty) {
195  first_ = label;
196  } else {
197  if (label <= Back() || label <= 0) {
198  FSTERROR() << "SetWeight: labels must be positive, added"
199  << " in sort order and be unique.";
200  rest_.push_back(Label(kSetBad));
201  }
202  rest_.push_back(label);
203  }
204  }
205 
206  private:
207  Label first_ = kSetEmpty; // First label in set (kSetEmpty if empty).
208  std::list<Label> rest_; // Remaining labels in set.
209 };
210 
211 // Traverses set in forward direction.
212 template <class SetWeight_>
213 class SetWeightIterator {
214  public:
215  using Weight = SetWeight_;
216  using Label = typename Weight::Label;
217 
218  explicit SetWeightIterator(const Weight &w)
219  : first_(w.first_), rest_(w.rest_), init_(true), iter_(rest_.begin()) {}
220 
221  bool Done() const {
222  if (init_) {
223  return first_ == kSetEmpty;
224  } else {
225  return iter_ == rest_.end();
226  }
227  }
228 
229  const Label &Value() const { return init_ ? first_ : *iter_; }
230 
231  void Next() {
232  if (init_) {
233  init_ = false;
234  } else {
235  ++iter_;
236  }
237  }
238 
239  void Reset() {
240  init_ = true;
241  iter_ = rest_.begin();
242  }
243 
244  private:
245  const Label &first_;
246  const decltype(Weight::rest_) &rest_;
247  bool init_; // In the initialized state?
248  typename decltype(Weight::rest_)::const_iterator iter_;
249 };
250 
251 // SetWeight member functions follow that require SetWeightIterator
252 
253 template <typename Label, SetType S>
254 inline std::istream &SetWeight<Label, S>::Read(std::istream &strm) {
255  Clear();
256  int32_t size;
257  ReadType(strm, &size);
258  for (int32_t i = 0; i < size; ++i) {
259  Label label;
260  ReadType(strm, &label);
261  PushBack(label);
262  }
263  return strm;
264 }
265 
266 template <typename Label, SetType S>
267 inline std::ostream &SetWeight<Label, S>::Write(std::ostream &strm) const {
268  const int32_t size = Size();
269  WriteType(strm, size);
270  for (Iterator iter(*this); !iter.Done(); iter.Next()) {
271  WriteType(strm, iter.Value());
272  }
273  return strm;
274 }
275 
276 template <typename Label, SetType S>
277 inline bool SetWeight<Label, S>::Member() const {
278  Iterator iter(*this);
279  return iter.Value() != Label(kSetBad);
280 }
281 
282 template <typename Label, SetType S>
285  return *this;
286 }
287 
288 template <typename Label, SetType S>
289 inline size_t SetWeight<Label, S>::Hash() const {
290  using Weight = SetWeight<Label, S>;
291  if (S == SET_BOOLEAN) {
292  return *this == Weight::Zero() ? 0 : 1;
293  } else {
294  size_t h = 0;
295  for (Iterator iter(*this); !iter.Done(); iter.Next()) {
296  h ^= h << 1 ^ iter.Value();
297  }
298  return h;
299  }
300 }
301 
302 // Default ==
303 template <typename Label, SetType S>
304 inline bool operator==(const SetWeight<Label, S> &w1,
305  const SetWeight<Label, S> &w2) {
306  if (w1.Size() != w2.Size()) return false;
307  using Iterator = typename SetWeight<Label, S>::Iterator;
308  Iterator iter1(w1);
309  Iterator iter2(w2);
310  for (; !iter1.Done(); iter1.Next(), iter2.Next()) {
311  if (iter1.Value() != iter2.Value()) return false;
312  }
313  return true;
314 }
315 
316 // Boolean ==
317 template <typename Label>
319  const SetWeight<Label, SET_BOOLEAN> &w2) {
320  // x == kSetEmpty if x \nin {kUnivSet, kSetBad}
321  if (!w1.Member() || !w2.Member()) return false;
323  Iterator iter1(w1);
324  Iterator iter2(w2);
325  Label label1 = iter1.Done() ? kSetEmpty : iter1.Value();
326  Label label2 = iter2.Done() ? kSetEmpty : iter2.Value();
327  if (label1 == kSetUniv) return label2 == kSetUniv;
328  if (label2 == kSetUniv) return label1 == kSetUniv;
329  return true;
330 }
331 
332 template <typename Label, SetType S>
333 inline bool operator!=(const SetWeight<Label, S> &w1,
334  const SetWeight<Label, S> &w2) {
335  return !(w1 == w2);
336 }
337 
338 template <typename Label, SetType S>
339 inline bool ApproxEqual(const SetWeight<Label, S> &w1,
340  const SetWeight<Label, S> &w2, float delta = kDelta) {
341  return w1 == w2;
342 }
343 
344 template <typename Label, SetType S>
345 inline std::ostream &operator<<(std::ostream &strm,
346  const SetWeight<Label, S> &weight) {
347  typename SetWeight<Label, S>::Iterator iter(weight);
348  if (iter.Done()) {
349  return strm << "EmptySet";
350  } else if (iter.Value() == Label(kSetUniv)) {
351  return strm << "UnivSet";
352  } else if (iter.Value() == Label(kSetBad)) {
353  return strm << "BadSet";
354  } else {
355  for (size_t i = 0; !iter.Done(); ++i, iter.Next()) {
356  if (i > 0) strm << kSetSeparator;
357  strm << iter.Value();
358  }
359  }
360  return strm;
361 }
362 
363 template <typename Label, SetType S>
364 inline std::istream &operator>>(std::istream &strm,
365  SetWeight<Label, S> &weight) {
366  std::string str;
367  strm >> str;
368  using Weight = SetWeight<Label, S>;
369  if (str == "EmptySet") {
370  weight = Weight(Label(kSetEmpty));
371  } else if (str == "UnivSet") {
372  weight = Weight(Label(kSetUniv));
373  } else {
374  weight.Clear();
375  for (std::string_view sv : StrSplit(str, kSetSeparator)) {
376  auto maybe_label = ParseInt64(sv);
377  if (!maybe_label.has_value()) {
378  strm.clear(std::ios::badbit);
379  break;
380  }
381  weight.PushBack(*maybe_label);
382  }
383  }
384  return strm;
385 }
386 
387 template <typename Label, SetType S>
389  const SetWeight<Label, S> &w2) {
390  using Weight = SetWeight<Label, S>;
391  using Iterator = typename SetWeight<Label, S>::Iterator;
392  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
393  if (w1 == Weight::EmptySet()) return w2;
394  if (w2 == Weight::EmptySet()) return w1;
395  if (w1 == Weight::UnivSet()) return w1;
396  if (w2 == Weight::UnivSet()) return w2;
397  Iterator it1(w1);
398  Iterator it2(w2);
399  Weight result;
400  while (!it1.Done() && !it2.Done()) {
401  const auto v1 = it1.Value();
402  const auto v2 = it2.Value();
403  if (v1 < v2) {
404  result.PushBack(v1);
405  it1.Next();
406  } else if (v1 > v2) {
407  result.PushBack(v2);
408  it2.Next();
409  } else {
410  result.PushBack(v1);
411  it1.Next();
412  it2.Next();
413  }
414  }
415  for (; !it1.Done(); it1.Next()) result.PushBack(it1.Value());
416  for (; !it2.Done(); it2.Next()) result.PushBack(it2.Value());
417  return result;
418 }
419 
420 template <typename Label, SetType S>
422  const SetWeight<Label, S> &w2) {
423  using Weight = SetWeight<Label, S>;
424  using Iterator = typename SetWeight<Label, S>::Iterator;
425  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
426  if (w1 == Weight::EmptySet()) return w1;
427  if (w2 == Weight::EmptySet()) return w2;
428  if (w1 == Weight::UnivSet()) return w2;
429  if (w2 == Weight::UnivSet()) return w1;
430  Iterator it1(w1);
431  Iterator it2(w2);
432  Weight result;
433  while (!it1.Done() && !it2.Done()) {
434  const auto v1 = it1.Value();
435  const auto v2 = it2.Value();
436  if (v1 < v2) {
437  it1.Next();
438  } else if (v1 > v2) {
439  it2.Next();
440  } else {
441  result.PushBack(v1);
442  it1.Next();
443  it2.Next();
444  }
445  }
446  return result;
447 }
448 
449 template <typename Label, SetType S>
451  const SetWeight<Label, S> &w2) {
452  using Weight = SetWeight<Label, S>;
453  using Iterator = typename SetWeight<Label, S>::Iterator;
454  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
455  if (w1 == Weight::EmptySet()) return w1;
456  if (w2 == Weight::EmptySet()) return w1;
457  if (w2 == Weight::UnivSet()) return Weight::EmptySet();
458  Iterator it1(w1);
459  Iterator it2(w2);
460  Weight result;
461  while (!it1.Done() && !it2.Done()) {
462  const auto v1 = it1.Value();
463  const auto v2 = it2.Value();
464  if (v1 < v2) {
465  result.PushBack(v1);
466  it1.Next();
467  } else if (v1 > v2) {
468  it2.Next();
469  } else {
470  it1.Next();
471  it2.Next();
472  }
473  }
474  for (; !it1.Done(); it1.Next()) result.PushBack(it1.Value());
475  return result;
476 }
477 
478 // Default: Plus = Intersect.
479 template <typename Label, SetType S>
481  const SetWeight<Label, S> &w2) {
482  return Intersect(w1, w2);
483 }
484 
485 // Plus = Union.
486 template <typename Label>
490  return Union(w1, w2);
491 }
492 
493 // Plus = Set equality is required (for non-Zero() input). The
494 // restriction is useful (e.g., in determinization) to ensure the input
495 // has a unique labelled path weight.
496 template <typename Label>
501  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
502  if (w1 == Weight::Zero()) return w2;
503  if (w2 == Weight::Zero()) return w1;
504  if (w1 != w2) {
505  FSTERROR() << "SetWeight::Plus: Unequal arguments "
506  << "(non-unique labelled path weights?)"
507  << " w1 = " << w1 << " w2 = " << w2;
508  return Weight::NoWeight();
509  }
510  return w1;
511 }
512 
513 // Plus = Or.
514 template <typename Label>
517  const SetWeight<Label, SET_BOOLEAN> &w2) {
518  using Weight = SetWeight<Label, SET_BOOLEAN>;
519  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
520  if (w1 == Weight::One()) return w1;
521  if (w2 == Weight::One()) return w2;
522  return Weight::Zero();
523 }
524 
525 // Default: Times = Union.
526 template <typename Label, SetType S>
528  const SetWeight<Label, S> &w2) {
529  return Union(w1, w2);
530 }
531 
532 // Times = Intersect.
533 template <typename Label>
537  return Intersect(w1, w2);
538 }
539 
540 // Times = And.
541 template <typename Label>
544  const SetWeight<Label, SET_BOOLEAN> &w2) {
545  using Weight = SetWeight<Label, SET_BOOLEAN>;
546  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
547  if (w1 == Weight::One()) return w2;
548  return w1;
549 }
550 
551 // Divide = Difference.
552 template <typename Label, SetType S>
554  const SetWeight<Label, S> &w2,
555  DivideType divide_type = DIVIDE_ANY) {
556  return Difference(w1, w2);
557 }
558 
559 // Divide = dividend (or the universal set if the
560 // dividend == divisor).
561 template <typename Label>
565  DivideType divide_type = DIVIDE_ANY) {
567  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
568  if (w1 == w2) return Weight::UnivSet();
569  return w1;
570 }
571 
572 // Divide = Or Not.
573 template <typename Label>
577  DivideType divide_type = DIVIDE_ANY) {
578  using Weight = SetWeight<Label, SET_BOOLEAN>;
579  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
580  if (w1 == Weight::One()) return w1;
581  if (w2 == Weight::Zero()) return Weight::One();
582  return Weight::Zero();
583 }
584 
585 // Converts between different set types.
586 template <typename Label, SetType S1, SetType S2>
591  for (Iterator iter(w1); !iter.Done(); iter.Next())
592  w2.PushBack(iter.Value());
593  return w2;
594  }
595 };
596 
597 // This function object generates SetWeights that are random integer sets
598 // from {1, ... , alphabet_size}^{0, max_set_length} U { Zero }. This is
599 // intended primarily for testing.
600 template <class Label, SetType S>
602  public:
604 
605  explicit WeightGenerate(uint64_t seed = std::random_device()(),
606  bool allow_zero = true,
607  size_t alphabet_size = kNumRandomWeights,
608  size_t max_set_length = kNumRandomWeights)
609  : allow_zero_(allow_zero),
610  alphabet_size_(alphabet_size),
611  max_set_length_(max_set_length) {}
612 
613  Weight operator()() const {
614  const int n = std::uniform_int_distribution<>(
615  0, max_set_length_ + allow_zero_ - 1)(rand_);
616  if (allow_zero_ && n == max_set_length_) return Weight::Zero();
617  std::vector<Label> labels;
618  labels.reserve(n);
619  for (int i = 0; i < n; ++i) {
620  labels.push_back(
621  std::uniform_int_distribution<>(0, alphabet_size_)(rand_));
622  }
623  std::sort(labels.begin(), labels.end());
624  const auto labels_end = std::unique(labels.begin(), labels.end());
625  labels.resize(labels_end - labels.begin());
626  return Weight(labels.begin(), labels.end());
627  }
628 
629  private:
630  mutable std::mt19937_64 rand_;
631  const bool allow_zero_;
632  const size_t alphabet_size_;
633  const size_t max_set_length_;
634 };
635 
636 } // namespace fst
637 
638 #endif // FST_SET_WEIGHT_H_
static const std::string & Type()
Definition: set-weight.h:131
std::ostream & Write(std::ostream &strm) const
Definition: set-weight.h:267
static const SetWeight & EmptySet()
Definition: set-weight.h:163
constexpr char kSetSeparator
Definition: set-weight.h:49
static constexpr uint64_t Properties()
Definition: set-weight.h:155
SetWeight & operator=(const SetWeight< Label, S2 > &w)
Definition: set-weight.h:104
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:61
ReverseWeight Reverse() const
Definition: set-weight.h:284
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t alphabet_size=kNumRandomWeights, size_t max_set_length=kNumRandomWeights)
Definition: set-weight.h:605
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:64
SetWeight & operator=(SetWeight< Label, S2 > &&w)
Definition: set-weight.h:111
SetType
Definition: set-weight.h:59
void Intersect(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const IntersectOptions &opts=IntersectOptions())
Definition: intersect.h:150
constexpr uint64_t kIdempotent
Definition: weight.h:147
constexpr int kSetUniv
Definition: set-weight.h:47
SetWeight()=default
static const SetWeight & NoWeight()
Definition: set-weight.h:126
internal::StringSplitter StrSplit(std::string_view full, ByAnyChar delim)
Definition: compat.cc:77
SetWeight(Label label)
Definition: set-weight.h:91
std::istream & Read(std::istream &strm)
Definition: set-weight.h:254
constexpr uint64_t kRightSemiring
Definition: weight.h:139
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:228
void Difference(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const DifferenceOptions &opts=DifferenceOptions())
Definition: difference.h:175
const Label & Value() const
Definition: set-weight.h:229
#define FSTERROR()
Definition: util.h:56
std::optional< int64_t > ParseInt64(std::string_view s, int base=10)
Definition: util.cc:46
typename Weight::Label Label
Definition: set-weight.h:216
Label Back()
Definition: set-weight.h:182
size_t Hash() const
Definition: set-weight.h:289
bool Member() const
Definition: set-weight.h:277
std::istream & operator>>(std::istream &strm, FloatWeightTpl< T > &w)
Definition: float-weight.h:185
SetWeightIterator< SetWeight > Iterator
Definition: set-weight.h:75
void Union(RationalFst< Arc > *fst1, const Fst< Arc > &fst2)
Definition: union.h:117
bool operator!=(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:54
SetWeightIterator(const Weight &w)
Definition: set-weight.h:218
SetWeight(const Iterator begin, const Iterator end)
Definition: set-weight.h:85
constexpr uint64_t kCommutative
Definition: weight.h:144
std::ostream & operator<<(std::ostream &strm, const ErrorWeight &)
Definition: error-weight.h:71
SetWeight(SetWeight< Label, S2 > &&w)
Definition: set-weight.h:98
static const SetWeight & UnivSet()
Definition: set-weight.h:169
SetWeight< Label, S2 > operator()(const SetWeight< Label, S1 > &w1) const
Definition: set-weight.h:588
static const SetWeight & Zero()
Definition: set-weight.h:118
void PushBack(Label label)
Definition: set-weight.h:193
constexpr int kSetBad
Definition: set-weight.h:48
ErrorWeight Divide(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:67
bool operator==(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:51
constexpr size_t kNumRandomWeights
Definition: weight.h:154
static const SetWeight & One()
Definition: set-weight.h:122
SetWeight Quantize(float delta=kDelta) const
Definition: set-weight.h:151
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:80
size_t Size() const
Definition: set-weight.h:180
DivideType
Definition: weight.h:165
constexpr uint64_t kLeftSemiring
Definition: weight.h:136
constexpr float kDelta
Definition: weight.h:133
bool Done() const
Definition: set-weight.h:221
bool ApproxEqual(const ErrorWeight &, const ErrorWeight &, float)
Definition: error-weight.h:58
constexpr int kSetEmpty
Definition: set-weight.h:46
SetWeight(const SetWeight< Label, S2 > &w)
Definition: set-weight.h:94