22 #ifndef FST_SET_WEIGHT_H_ 23 #define FST_SET_WEIGHT_H_ 34 #include <string_view> 63 template <
typename L, SetType S = SET_INTERSECT_UNION>
71 template <
typename L2, SetType S2>
77 template <
typename Iterator>
79 for (
auto iter = begin; iter != end; ++iter)
PushBack(*iter);
88 : first_(w.first_), rest_(w.rest_) {}
92 : first_(w.first_), rest_(std::move(w.rest_)) {
103 template <SetType S2>
106 rest_ = std::move(w.rest_);
120 static const auto *
const no_weight =
new SetWeight(
Label(kSetBad));
124 static const std::string &
Type() {
125 static const std::string *
const type =
127 ?
"union_intersect_set" 129 ?
"intersect_union_set" 131 ?
"restricted_set_intersect_union" 138 std::istream &
Read(std::istream &strm);
140 std::ostream &
Write(std::ostream &strm)
const;
173 size_t Size()
const {
return first_ == kSetEmpty ? 0 : rest_.size() + 1; }
187 if (first_ == kSetEmpty) {
190 if (label <=
Back() || label <= 0) {
191 FSTERROR() <<
"SetWeight: labels must be positive, added" 192 <<
" in sort order and be unique.";
193 rest_.push_back(
Label(kSetBad));
195 rest_.push_back(label);
201 std::list<Label> rest_;
205 template <
class SetWeight_>
209 using Label =
typename Weight::Label;
212 : first_(w.first_), rest_(w.rest_), init_(true), iter_(rest_.begin()) {}
218 return iter_ == rest_.end();
222 const Label &
Value()
const {
return init_ ? first_ : *iter_; }
234 iter_ = rest_.begin();
239 const decltype(Weight::rest_) &rest_;
241 typename decltype(Weight::rest_)::const_iterator iter_;
246 template <
typename Label, SetType S>
251 for (int32_t i = 0; i < size; ++i) {
259 template <
typename Label, SetType S>
261 const int32_t size =
Size();
269 template <
typename Label, SetType S>
275 template <
typename Label, SetType S>
281 template <
typename Label, SetType S>
285 return *
this == Weight::Zero() ? 0 : 1;
289 h ^= h << 1 ^ iter.Value();
296 template <
typename Label, SetType S>
299 if (w1.
Size() != w2.
Size())
return false;
303 for (; !iter1.Done(); iter1.Next(), iter2.Next()) {
304 if (iter1.Value() != iter2.Value())
return false;
310 template <
typename Label>
318 Label label1 = iter1.Done() ? kSetEmpty : iter1.Value();
319 Label label2 = iter2.Done() ? kSetEmpty : iter2.Value();
320 if (label1 == kSetUniv)
return label2 ==
kSetUniv;
321 if (label2 == kSetUniv)
return label1 ==
kSetUniv;
325 template <
typename Label, SetType S>
331 template <
typename Label, SetType S>
337 template <
typename Label, SetType S>
342 return strm <<
"EmptySet";
344 return strm <<
"UnivSet";
346 return strm <<
"BadSet";
348 for (
size_t i = 0; !iter.
Done(); ++i, iter.
Next()) {
350 strm << iter.
Value();
356 template <
typename Label, SetType S>
362 if (str ==
"EmptySet") {
363 weight = Weight(
Label(kSetEmpty));
364 }
else if (str ==
"UnivSet") {
365 weight = Weight(
Label(kSetUniv));
368 for (std::string_view sv :
StrSplit(str, kSetSeparator)) {
370 if (!maybe_label.has_value()) {
371 strm.clear(std::ios::badbit);
380 template <
typename Label, SetType S>
385 if (!w1.
Member() || !w2.
Member())
return Weight::NoWeight();
386 if (w1 == Weight::EmptySet())
return w2;
387 if (w2 == Weight::EmptySet())
return w1;
388 if (w1 == Weight::UnivSet())
return w1;
389 if (w2 == Weight::UnivSet())
return w2;
393 while (!it1.Done() && !it2.Done()) {
394 const auto v1 = it1.Value();
395 const auto v2 = it2.Value();
399 }
else if (v1 > v2) {
408 for (; !it1.Done(); it1.Next()) result.PushBack(it1.Value());
409 for (; !it2.Done(); it2.Next()) result.PushBack(it2.Value());
413 template <
typename Label, SetType S>
418 if (!w1.
Member() || !w2.
Member())
return Weight::NoWeight();
419 if (w1 == Weight::EmptySet())
return w1;
420 if (w2 == Weight::EmptySet())
return w2;
421 if (w1 == Weight::UnivSet())
return w2;
422 if (w2 == Weight::UnivSet())
return w1;
426 while (!it1.Done() && !it2.Done()) {
427 const auto v1 = it1.Value();
428 const auto v2 = it2.Value();
431 }
else if (v1 > v2) {
442 template <
typename Label, SetType S>
447 if (!w1.
Member() || !w2.
Member())
return Weight::NoWeight();
448 if (w1 == Weight::EmptySet())
return w1;
449 if (w2 == Weight::EmptySet())
return w1;
450 if (w2 == Weight::UnivSet())
return Weight::EmptySet();
454 while (!it1.Done() && !it2.Done()) {
455 const auto v1 = it1.Value();
456 const auto v2 = it2.Value();
460 }
else if (v1 > v2) {
467 for (; !it1.Done(); it1.Next()) result.PushBack(it1.Value());
472 template <
typename Label, SetType S>
479 template <
typename Label>
483 return Union(w1, w2);
489 template <
typename Label>
494 if (!w1.
Member() || !w2.
Member())
return Weight::NoWeight();
495 if (w1 == Weight::Zero())
return w2;
496 if (w2 == Weight::Zero())
return w1;
498 FSTERROR() <<
"SetWeight::Plus: Unequal arguments " 499 <<
"(non-unique labelled path weights?)" 500 <<
" w1 = " << w1 <<
" w2 = " << w2;
501 return Weight::NoWeight();
507 template <
typename Label>
512 if (!w1.
Member() || !w2.
Member())
return Weight::NoWeight();
513 if (w1 == Weight::One())
return w1;
514 if (w2 == Weight::One())
return w2;
515 return Weight::Zero();
519 template <
typename Label, SetType S>
522 return Union(w1, w2);
526 template <
typename Label>
534 template <
typename Label>
539 if (!w1.
Member() || !w2.
Member())
return Weight::NoWeight();
540 if (w1 == Weight::One())
return w2;
545 template <
typename Label, SetType S>
554 template <
typename Label>
560 if (!w1.
Member() || !w2.
Member())
return Weight::NoWeight();
561 if (w1 == w2)
return Weight::UnivSet();
566 template <
typename Label>
572 if (!w1.
Member() || !w2.
Member())
return Weight::NoWeight();
573 if (w1 == Weight::One())
return w1;
574 if (w2 == Weight::Zero())
return Weight::One();
575 return Weight::Zero();
579 template <
typename Label, SetType S1, SetType S2>
584 for (
Iterator iter(w1); !iter.Done(); iter.Next())
593 template <
class Label, SetType S>
599 bool allow_zero =
true,
602 : allow_zero_(allow_zero),
603 alphabet_size_(alphabet_size),
604 max_set_length_(max_set_length) {}
607 const int n = std::uniform_int_distribution<>(
608 0, max_set_length_ + allow_zero_ - 1)(rand_);
609 if (allow_zero_ && n == max_set_length_)
return Weight::Zero();
610 std::vector<Label> labels;
612 for (
int i = 0; i < n; ++i) {
614 std::uniform_int_distribution<>(0, alphabet_size_)(rand_));
616 std::sort(labels.begin(), labels.end());
617 const auto labels_end = std::unique(labels.begin(), labels.end());
618 labels.resize(labels_end - labels.begin());
619 return Weight(labels.begin(), labels.end());
623 mutable std::mt19937_64 rand_;
624 const bool allow_zero_;
625 const size_t alphabet_size_;
626 const size_t max_set_length_;
631 #endif // FST_SET_WEIGHT_H_ static const std::string & Type()
std::ostream & Write(std::ostream &strm) const
static const SetWeight & EmptySet()
constexpr char kSetSeparator
static constexpr uint64_t Properties()
SetWeight & operator=(const SetWeight< Label, S2 > &w)
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
ReverseWeight Reverse() const
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t alphabet_size=kNumRandomWeights, size_t max_set_length=kNumRandomWeights)
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
SetWeight & operator=(SetWeight< Label, S2 > &&w)
void Intersect(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const IntersectOptions &opts=IntersectOptions())
constexpr uint64_t kIdempotent
static const SetWeight & NoWeight()
internal::StringSplitter StrSplit(std::string_view full, ByAnyChar delim)
std::istream & Read(std::istream &strm)
constexpr uint64_t kRightSemiring
std::ostream & WriteType(std::ostream &strm, const T t)
void Difference(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const DifferenceOptions &opts=DifferenceOptions())
const Label & Value() const
std::optional< int64_t > ParseInt64(std::string_view s, int base=10)
typename Weight::Label Label
std::istream & operator>>(std::istream &strm, FloatWeightTpl< T > &w)
SetWeightIterator< SetWeight > Iterator
void Union(RationalFst< Arc > *fst1, const Fst< Arc > &fst2)
bool operator!=(const ErrorWeight &, const ErrorWeight &)
SetWeightIterator(const Weight &w)
SetWeight(const Iterator begin, const Iterator end)
constexpr uint64_t kCommutative
std::ostream & operator<<(std::ostream &strm, const ErrorWeight &)
Weight operator()() const
SetWeight(SetWeight< Label, S2 > &&w)
static const SetWeight & UnivSet()
SetWeight< Label, S2 > operator()(const SetWeight< Label, S1 > &w1) const
static const SetWeight & Zero()
void PushBack(Label label)
ErrorWeight Divide(const ErrorWeight &, const ErrorWeight &)
bool operator==(const ErrorWeight &, const ErrorWeight &)
constexpr size_t kNumRandomWeights
static const SetWeight & One()
SetWeight Quantize(float delta=kDelta) const
std::istream & ReadType(std::istream &strm, T *t)
constexpr uint64_t kLeftSemiring
bool ApproxEqual(const ErrorWeight &, const ErrorWeight &, float)
SetWeight(const SetWeight< Label, S2 > &w)