20 #ifndef FST_FACTOR_WEIGHT_H_ 21 #define FST_FACTOR_WEIGHT_H_ 34 #include <unordered_map> 43 using Label =
typename Arc::Label;
53 uint8_t mode = kFactorArcWeights |
55 Label final_ilabel = 0,
Label final_olabel = 0,
56 bool increment_final_ilabel =
false,
57 bool increment_final_olabel =
false)
61 final_ilabel(final_ilabel),
62 final_olabel(final_olabel),
63 increment_final_ilabel(increment_final_ilabel),
64 increment_final_olabel(increment_final_olabel) {}
67 uint8_t mode = kFactorArcWeights |
69 Label final_ilabel = 0,
Label final_olabel = 0,
70 bool increment_final_ilabel =
false,
71 bool increment_final_olabel =
false)
74 final_ilabel(final_ilabel),
75 final_olabel(final_olabel),
76 increment_final_ilabel(increment_final_ilabel),
77 increment_final_olabel(increment_final_olabel) {}
104 bool Done()
const {
return true; }
108 std::pair<W, W>
Value()
const {
return std::make_pair(W::One(), W::One()); }
119 explicit OneFactor(
const W &w) : weight_(w), done_(w == W::One()) {}
121 bool Done()
const {
return done_; }
125 std::pair<W, W>
Value()
const {
return std::make_pair(W::One(), weight_); }
127 void Reset() { done_ = weight_ == W::One(); }
135 template <
typename Label, StringType S = STRING_LEFT>
139 : weight_(weight), done_(weight.Size() <= 1) {}
141 bool Done()
const {
return done_; }
147 typename Weight::Iterator siter(weight_);
148 Weight w1(siter.Value());
150 for (siter.Next(); !siter.Done(); siter.Next()) w2.PushBack(siter.Value());
151 return std::make_pair(w1, w2);
154 void Reset() { done_ = weight_.Size() <= 1; }
162 template <
class Label,
class W, GallicType G = GALLIC_LEFT>
168 : weight_(weight), done_(weight.Value1().Size() <= 1) {}
170 bool Done()
const {
return done_; }
176 GW w1(siter.Value().first, weight_.Value2());
177 GW w2(siter.Value().second, W::One());
178 return std::make_pair(w1, w2);
181 void Reset() { done_ = weight_.Value1().Size() <= 1; }
189 template <
class Label,
class W>
197 done_(weight.Size() == 0 ||
198 (weight.Size() == 1 && weight.Back().Value1().Size() <= 1)) {}
200 bool Done()
const {
return done_ || iter_.Done(); }
207 const auto weight = iter_.Value();
210 GRW w1(siter.Value().first, weight.Value2());
211 GRW w2(siter.Value().second, W::One());
212 return std::make_pair(
GW(w1),
GW(w2));
223 template <
class Arc,
class FactorIterator>
267 LOG(WARNING) <<
"FactorWeightFst: Factor mode is set to 0; " 268 <<
"factoring neither arc weights nor final weights";
274 fst_(impl.fst_->Copy(true)),
277 final_ilabel_(impl.final_ilabel_),
278 final_olabel_(impl.final_olabel_),
279 increment_final_ilabel_(impl.increment_final_ilabel_),
280 increment_final_olabel_(impl.increment_final_olabel_) {
289 const auto s = fst_->Start();
291 SetStart(FindState(
Element(fst_->Start(), Weight::One())));
298 const auto &element = elements_[s];
302 :
Times(element.weight, fst_->Final(element.state));
303 FactorIterator siter(weight);
304 if (!(mode_ & kFactorFinalWeights) || siter.Done()) {
307 SetFinal(s, Weight::Zero());
314 if (!HasArcs(s))
Expand(s);
319 if (!HasArcs(s))
Expand(s);
324 if (!HasArcs(s))
Expand(s);
332 if ((mask &
kError) && fst_->Properties(kError,
false)) {
333 SetProperties(kError, kError);
339 if (!HasArcs(s))
Expand(s);
346 if (!(mode_ & kFactorArcWeights) && element.
weight == Weight::One() &&
348 while (unfactored_.size() <= element.
state)
351 unfactored_[element.
state] = elements_.size();
352 elements_.push_back(element);
354 return unfactored_[element.
state];
356 const auto insert_result =
357 element_map_.emplace(element, elements_.size());
358 if (insert_result.second) {
359 elements_.push_back(element);
361 return insert_result.first->second;
368 const auto element = elements_[s];
372 const auto &arc = ait.Value();
373 auto weight =
Times(element.weight, arc.weight);
374 FactorIterator fiter(weight);
375 if (!(mode_ & kFactorArcWeights) || fiter.Done()) {
376 const auto dest = FindState(
Element(arc.nextstate, Weight::One()));
377 EmplaceArc(s, arc.ilabel, arc.olabel, std::move(weight), dest);
379 for (; !fiter.Done(); fiter.Next()) {
380 auto pair = fiter.Value();
382 FindState(
Element(arc.nextstate, pair.second.Quantize(delta_)));
383 EmplaceArc(s, arc.ilabel, arc.olabel, std::move(pair.first), dest);
388 if ((mode_ & kFactorFinalWeights) &&
390 (fst_->Final(element.state) != Weight::Zero()))) {
394 :
Times(element.weight, fst_->Final(element.state));
395 auto ilabel = final_ilabel_;
396 auto olabel = final_olabel_;
397 for (FactorIterator fiter(weight); !fiter.Done(); fiter.Next()) {
398 auto pair = fiter.Value();
401 EmplaceArc(s, ilabel, olabel, std::move(pair.first), dest);
402 if (increment_final_ilabel_) ++ilabel;
403 if (increment_final_olabel_) ++olabel;
421 size_t operator()(
const Element &x)
const {
422 static constexpr
auto prime = 7853;
423 return static_cast<size_t>(x.
state * prime + x.
weight.Hash());
428 std::unordered_map<Element, StateId, ElementKey, ElementEqual>;
430 std::unique_ptr<const Fst<Arc>> fst_;
435 bool increment_final_ilabel_;
436 bool increment_final_olabel_;
437 std::vector<Element> elements_;
438 ElementMap element_map_;
441 std::vector<StateId> unfactored_;
455 template <
class A,
class FactorIterator>
457 :
public ImplToFst<internal::FactorWeightFstImpl<A, FactorIterator>> {
489 GetMutableImpl()->InitArcIterator(s, data);
500 template <
class Arc,
class FactorIterator>
506 fst, fst.GetMutableImpl()) {}
510 template <
class Arc,
class FactorIterator>
518 fst.GetMutableImpl(), s) {
523 template <
class Arc,
class FactorIterator>
527 std::make_unique<StateIterator<FactorWeightFst<Arc, FactorIterator>>>(
533 #endif // FST_FACTOR_WEIGHT_H_ ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
FactorWeightFst(const Fst< Arc > &fst, const FactorWeightOptions< Arc > &opts)
bool increment_final_olabel
GallicFactor(const GW &weight)
FactorWeightFst(const FactorWeightFst &fst, bool copy)
std::pair< W, W > Value() const
virtual uint64_t Properties(uint64_t mask, bool test) const =0
uint64_t Properties(uint64_t mask) const override
uint64_t Properties() const override
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
size_t NumInputEpsilons(StateId s)
const SymbolTable * OutputSymbols() const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
constexpr uint64_t kError
StateId FindState(const Element &element)
size_t NumOutputEpsilons(StateId s)
typename FactorWeightFst< Arc, FactorIterator >::Arc Arc
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Element(StateId s, Weight weight_)
std::pair< W, W > Value() const
bool increment_final_ilabel
StringFactor(const StringWeight< Label, S > &weight)
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
FactorWeightFstImpl(const FactorWeightFstImpl< Arc, FactorIterator > &impl)
constexpr uint64_t kCopyProperties
uint64_t FactorWeightProperties(uint64_t inprops)
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
typename Arc::Weight Weight
std::pair< GW, GW > Value() const
typename Arc::Weight Weight
std::unique_ptr< StateIteratorBase< Arc > > base
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
FactorWeightOptions(float delta=kDelta, uint8_t mode=kFactorArcWeights|kFactorFinalWeights, Label final_ilabel=0, Label final_olabel=0, bool increment_final_ilabel=false, bool increment_final_olabel=false)
typename Store::State State
FactorWeightOptions(const CacheOptions &opts, float delta=kDelta, uint8_t mode=kFactorArcWeights|kFactorFinalWeights, Label final_ilabel=0, Label final_olabel=0, bool increment_final_ilabel=false, bool increment_final_olabel=false)
FactorWeightFst(const Fst< Arc > &fst)
size_t NumArcs(StateId s)
typename Arc::StateId StateId
constexpr uint64_t kFstProperties
typename Arc::Label Label
constexpr uint8_t kFactorFinalWeights
typename Arc::StateId StateId
typename FactorWeightFst< Arc, FactorIterator >::Arc Arc
virtual const SymbolTable * InputSymbols() const =0
const SymbolTable * InputSymbols() const
std::pair< StringWeight< Label, S >, StringWeight< Label, S > > Value() const
FactorWeightFstImpl(const Fst< Arc > &fst, const FactorWeightOptions< Arc > &opts)
FactorWeightFst * Copy(bool copy=false) const override
typename Arc::Label Label
void InitStateIterator(StateIteratorData< Arc > *data) const override
constexpr uint8_t kFactorArcWeights
GallicFactor(const GW &weight)
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
typename CacheState< Arc >::Arc Arc
Impl * GetMutableImpl() const
typename Arc::StateId StateId
ArcIterator(const FactorWeightFst< Arc, FactorIterator > &fst, StateId s)
IdentityFactor(const W &weight)
std::pair< GW, GW > Value() const
StateIterator(const FactorWeightFst< Arc, FactorIterator > &fst)
const Impl * GetImpl() const
virtual const SymbolTable * OutputSymbols() const =0