22 #ifndef FST_SET_WEIGHT_H_ 23 #define FST_SET_WEIGHT_H_ 42 #include <string_view> 70 template <
typename L, SetType S = SET_INTERSECT_UNION>
78 template <
typename L2, SetType S2>
84 template <
typename Iterator>
86 for (
auto iter = begin; iter != end; ++iter)
PushBack(*iter);
95 : first_(w.first_), rest_(w.rest_) {}
99 : first_(w.first_), rest_(std::move(w.rest_)) {
103 template <SetType S2>
110 template <SetType S2>
113 rest_ = std::move(w.rest_);
127 static const auto *
const no_weight =
new SetWeight(
Label(kSetBad));
131 static const std::string &
Type() {
132 static const std::string *
const type =
134 ?
"union_intersect_set" 136 ?
"intersect_union_set" 138 ?
"restricted_set_intersect_union" 145 std::istream &
Read(std::istream &strm);
147 std::ostream &
Write(std::ostream &strm)
const;
180 size_t Size()
const {
return first_ == kSetEmpty ? 0 : rest_.size() + 1; }
194 if (first_ == kSetEmpty) {
197 if (label <=
Back() || label <= 0) {
198 FSTERROR() <<
"SetWeight: labels must be positive, added" 199 <<
" in sort order and be unique.";
200 rest_.push_back(
Label(kSetBad));
202 rest_.push_back(label);
208 std::list<Label> rest_;
212 template <
class SetWeight_>
216 using Label =
typename Weight::Label;
219 : first_(w.first_), rest_(w.rest_), init_(true), iter_(rest_.begin()) {}
225 return iter_ == rest_.end();
229 const Label &
Value()
const {
return init_ ? first_ : *iter_; }
241 iter_ = rest_.begin();
246 const decltype(Weight::rest_) &rest_;
248 typename decltype(Weight::rest_)::const_iterator iter_;
253 template <
typename Label, SetType S>
258 for (int32_t i = 0; i < size; ++i) {
266 template <
typename Label, SetType S>
268 const int32_t size =
Size();
276 template <
typename Label, SetType S>
282 template <
typename Label, SetType S>
288 template <
typename Label, SetType S>
292 return *
this == Weight::Zero() ? 0 : 1;
296 h ^= h << 1 ^ iter.Value();
303 template <
typename Label, SetType S>
306 if (w1.
Size() != w2.
Size())
return false;
310 for (; !iter1.Done(); iter1.Next(), iter2.Next()) {
311 if (iter1.Value() != iter2.Value())
return false;
317 template <
typename Label>
325 Label label1 = iter1.Done() ? kSetEmpty : iter1.Value();
326 Label label2 = iter2.Done() ? kSetEmpty : iter2.Value();
327 if (label1 == kSetUniv)
return label2 ==
kSetUniv;
328 if (label2 == kSetUniv)
return label1 ==
kSetUniv;
332 template <
typename Label, SetType S>
338 template <
typename Label, SetType S>
344 template <
typename Label, SetType S>
349 return strm <<
"EmptySet";
351 return strm <<
"UnivSet";
353 return strm <<
"BadSet";
355 for (
size_t i = 0; !iter.
Done(); ++i, iter.
Next()) {
357 strm << iter.
Value();
363 template <
typename Label, SetType S>
369 if (str ==
"EmptySet") {
370 weight = Weight(
Label(kSetEmpty));
371 }
else if (str ==
"UnivSet") {
372 weight = Weight(
Label(kSetUniv));
375 for (std::string_view sv :
StrSplit(str, kSetSeparator)) {
377 if (!maybe_label.has_value()) {
378 strm.clear(std::ios::badbit);
387 template <
typename Label, SetType S>
392 if (!w1.
Member() || !w2.
Member())
return Weight::NoWeight();
393 if (w1 == Weight::EmptySet())
return w2;
394 if (w2 == Weight::EmptySet())
return w1;
395 if (w1 == Weight::UnivSet())
return w1;
396 if (w2 == Weight::UnivSet())
return w2;
400 while (!it1.Done() && !it2.Done()) {
401 const auto v1 = it1.Value();
402 const auto v2 = it2.Value();
406 }
else if (v1 > v2) {
415 for (; !it1.Done(); it1.Next()) result.PushBack(it1.Value());
416 for (; !it2.Done(); it2.Next()) result.PushBack(it2.Value());
420 template <
typename Label, SetType S>
425 if (!w1.
Member() || !w2.
Member())
return Weight::NoWeight();
426 if (w1 == Weight::EmptySet())
return w1;
427 if (w2 == Weight::EmptySet())
return w2;
428 if (w1 == Weight::UnivSet())
return w2;
429 if (w2 == Weight::UnivSet())
return w1;
433 while (!it1.Done() && !it2.Done()) {
434 const auto v1 = it1.Value();
435 const auto v2 = it2.Value();
438 }
else if (v1 > v2) {
449 template <
typename Label, SetType S>
454 if (!w1.
Member() || !w2.
Member())
return Weight::NoWeight();
455 if (w1 == Weight::EmptySet())
return w1;
456 if (w2 == Weight::EmptySet())
return w1;
457 if (w2 == Weight::UnivSet())
return Weight::EmptySet();
461 while (!it1.Done() && !it2.Done()) {
462 const auto v1 = it1.Value();
463 const auto v2 = it2.Value();
467 }
else if (v1 > v2) {
474 for (; !it1.Done(); it1.Next()) result.PushBack(it1.Value());
479 template <
typename Label, SetType S>
486 template <
typename Label>
490 return Union(w1, w2);
496 template <
typename Label>
501 if (!w1.
Member() || !w2.
Member())
return Weight::NoWeight();
502 if (w1 == Weight::Zero())
return w2;
503 if (w2 == Weight::Zero())
return w1;
505 FSTERROR() <<
"SetWeight::Plus: Unequal arguments " 506 <<
"(non-unique labelled path weights?)" 507 <<
" w1 = " << w1 <<
" w2 = " << w2;
508 return Weight::NoWeight();
514 template <
typename Label>
519 if (!w1.
Member() || !w2.
Member())
return Weight::NoWeight();
520 if (w1 == Weight::One())
return w1;
521 if (w2 == Weight::One())
return w2;
522 return Weight::Zero();
526 template <
typename Label, SetType S>
529 return Union(w1, w2);
533 template <
typename Label>
541 template <
typename Label>
546 if (!w1.
Member() || !w2.
Member())
return Weight::NoWeight();
547 if (w1 == Weight::One())
return w2;
552 template <
typename Label, SetType S>
561 template <
typename Label>
567 if (!w1.
Member() || !w2.
Member())
return Weight::NoWeight();
568 if (w1 == w2)
return Weight::UnivSet();
573 template <
typename Label>
579 if (!w1.
Member() || !w2.
Member())
return Weight::NoWeight();
580 if (w1 == Weight::One())
return w1;
581 if (w2 == Weight::Zero())
return Weight::One();
582 return Weight::Zero();
586 template <
typename Label, SetType S1, SetType S2>
591 for (
Iterator iter(w1); !iter.Done(); iter.Next())
600 template <
class Label, SetType S>
606 bool allow_zero =
true,
609 : allow_zero_(allow_zero),
610 alphabet_size_(alphabet_size),
611 max_set_length_(max_set_length) {}
614 const int n = std::uniform_int_distribution<>(
615 0, max_set_length_ + allow_zero_ - 1)(rand_);
616 if (allow_zero_ && n == max_set_length_)
return Weight::Zero();
617 std::vector<Label> labels;
619 for (
int i = 0; i < n; ++i) {
621 std::uniform_int_distribution<>(0, alphabet_size_)(rand_));
623 std::sort(labels.begin(), labels.end());
624 const auto labels_end = std::unique(labels.begin(), labels.end());
625 labels.resize(labels_end - labels.begin());
626 return Weight(labels.begin(), labels.end());
630 mutable std::mt19937_64 rand_;
631 const bool allow_zero_;
632 const size_t alphabet_size_;
633 const size_t max_set_length_;
638 #endif // FST_SET_WEIGHT_H_ static const std::string & Type()
std::ostream & Write(std::ostream &strm) const
static const SetWeight & EmptySet()
constexpr char kSetSeparator
static constexpr uint64_t Properties()
SetWeight & operator=(const SetWeight< Label, S2 > &w)
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
ReverseWeight Reverse() const
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t alphabet_size=kNumRandomWeights, size_t max_set_length=kNumRandomWeights)
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
SetWeight & operator=(SetWeight< Label, S2 > &&w)
void Intersect(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const IntersectOptions &opts=IntersectOptions())
constexpr uint64_t kIdempotent
static const SetWeight & NoWeight()
internal::StringSplitter StrSplit(std::string_view full, ByAnyChar delim)
std::istream & Read(std::istream &strm)
constexpr uint64_t kRightSemiring
std::ostream & WriteType(std::ostream &strm, const T t)
void Difference(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const DifferenceOptions &opts=DifferenceOptions())
const Label & Value() const
std::optional< int64_t > ParseInt64(std::string_view s, int base=10)
typename Weight::Label Label
std::istream & operator>>(std::istream &strm, FloatWeightTpl< T > &w)
SetWeightIterator< SetWeight > Iterator
void Union(RationalFst< Arc > *fst1, const Fst< Arc > &fst2)
bool operator!=(const ErrorWeight &, const ErrorWeight &)
SetWeightIterator(const Weight &w)
SetWeight(const Iterator begin, const Iterator end)
constexpr uint64_t kCommutative
std::ostream & operator<<(std::ostream &strm, const ErrorWeight &)
Weight operator()() const
SetWeight(SetWeight< Label, S2 > &&w)
static const SetWeight & UnivSet()
SetWeight< Label, S2 > operator()(const SetWeight< Label, S1 > &w1) const
static const SetWeight & Zero()
void PushBack(Label label)
ErrorWeight Divide(const ErrorWeight &, const ErrorWeight &)
bool operator==(const ErrorWeight &, const ErrorWeight &)
constexpr size_t kNumRandomWeights
static const SetWeight & One()
SetWeight Quantize(float delta=kDelta) const
std::istream & ReadType(std::istream &strm, T *t)
constexpr uint64_t kLeftSemiring
bool ApproxEqual(const ErrorWeight &, const ErrorWeight &, float)
SetWeight(const SetWeight< Label, S2 > &w)