20 #ifndef FST_FACTOR_WEIGHT_H_ 21 #define FST_FACTOR_WEIGHT_H_ 39 #include <unordered_map> 48 using Label =
typename Arc::Label;
58 uint8_t mode = kFactorArcWeights |
60 Label final_ilabel = 0,
Label final_olabel = 0,
61 bool increment_final_ilabel =
false,
62 bool increment_final_olabel =
false)
66 final_ilabel(final_ilabel),
67 final_olabel(final_olabel),
68 increment_final_ilabel(increment_final_ilabel),
69 increment_final_olabel(increment_final_olabel) {}
72 uint8_t mode = kFactorArcWeights |
74 Label final_ilabel = 0,
Label final_olabel = 0,
75 bool increment_final_ilabel =
false,
76 bool increment_final_olabel =
false)
79 final_ilabel(final_ilabel),
80 final_olabel(final_olabel),
81 increment_final_ilabel(increment_final_ilabel),
82 increment_final_olabel(increment_final_olabel) {}
109 bool Done()
const {
return true; }
113 std::pair<W, W>
Value()
const {
return std::make_pair(W::One(), W::One()); }
124 explicit OneFactor(
const W &w) : weight_(w), done_(w == W::One()) {}
126 bool Done()
const {
return done_; }
130 std::pair<W, W>
Value()
const {
return std::make_pair(W::One(), weight_); }
132 void Reset() { done_ = weight_ == W::One(); }
140 template <
typename Label, StringType S = STRING_LEFT>
144 : weight_(weight), done_(weight.Size() <= 1) {}
146 bool Done()
const {
return done_; }
152 typename Weight::Iterator siter(weight_);
153 Weight w1(siter.Value());
155 for (siter.Next(); !siter.Done(); siter.Next()) w2.PushBack(siter.Value());
156 return std::make_pair(w1, w2);
159 void Reset() { done_ = weight_.Size() <= 1; }
167 template <
class Label,
class W, GallicType G = GALLIC_LEFT>
173 : weight_(weight), done_(weight.Value1().Size() <= 1) {}
175 bool Done()
const {
return done_; }
181 GW w1(siter.Value().first, weight_.Value2());
182 GW w2(siter.Value().second, W::One());
183 return std::make_pair(w1, w2);
186 void Reset() { done_ = weight_.Value1().Size() <= 1; }
194 template <
class Label,
class W>
202 done_(weight.Size() == 0 ||
203 (weight.Size() == 1 && weight.Back().Value1().Size() <= 1)) {}
205 bool Done()
const {
return done_ || iter_.Done(); }
212 const auto weight = iter_.Value();
215 GRW w1(siter.Value().first, weight.Value2());
216 GRW w2(siter.Value().second, W::One());
217 return std::make_pair(
GW(w1),
GW(w2));
228 template <
class Arc,
class FactorIterator>
272 LOG(WARNING) <<
"FactorWeightFst: Factor mode is set to 0; " 273 <<
"factoring neither arc weights nor final weights";
279 fst_(impl.fst_->Copy(true)),
282 final_ilabel_(impl.final_ilabel_),
283 final_olabel_(impl.final_olabel_),
284 increment_final_ilabel_(impl.increment_final_ilabel_),
285 increment_final_olabel_(impl.increment_final_olabel_) {
294 const auto s = fst_->Start();
296 SetStart(FindState(
Element(fst_->Start(), Weight::One())));
303 const auto &element = elements_[s];
307 :
Times(element.weight, fst_->Final(element.state));
308 FactorIterator siter(weight);
309 if (!(mode_ & kFactorFinalWeights) || siter.Done()) {
312 SetFinal(s, Weight::Zero());
319 if (!HasArcs(s))
Expand(s);
324 if (!HasArcs(s))
Expand(s);
329 if (!HasArcs(s))
Expand(s);
337 if ((mask &
kError) && fst_->Properties(kError,
false)) {
338 SetProperties(kError, kError);
344 if (!HasArcs(s))
Expand(s);
351 if (!(mode_ & kFactorArcWeights) && element.
weight == Weight::One() &&
353 while (unfactored_.size() <= element.
state)
356 unfactored_[element.
state] = elements_.size();
357 elements_.push_back(element);
359 return unfactored_[element.
state];
361 const auto insert_result =
362 element_map_.emplace(element, elements_.size());
363 if (insert_result.second) {
364 elements_.push_back(element);
366 return insert_result.first->second;
373 const auto element = elements_[s];
377 const auto &arc = ait.Value();
378 auto weight =
Times(element.weight, arc.weight);
379 FactorIterator fiter(weight);
380 if (!(mode_ & kFactorArcWeights) || fiter.Done()) {
381 const auto dest = FindState(
Element(arc.nextstate, Weight::One()));
382 EmplaceArc(s, arc.ilabel, arc.olabel, std::move(weight), dest);
384 for (; !fiter.Done(); fiter.Next()) {
385 auto pair = fiter.Value();
387 FindState(
Element(arc.nextstate, pair.second.Quantize(delta_)));
388 EmplaceArc(s, arc.ilabel, arc.olabel, std::move(pair.first), dest);
393 if ((mode_ & kFactorFinalWeights) &&
395 (fst_->Final(element.state) != Weight::Zero()))) {
399 :
Times(element.weight, fst_->Final(element.state));
400 auto ilabel = final_ilabel_;
401 auto olabel = final_olabel_;
402 for (FactorIterator fiter(weight); !fiter.Done(); fiter.Next()) {
403 auto pair = fiter.Value();
406 EmplaceArc(s, ilabel, olabel, std::move(pair.first), dest);
407 if (increment_final_ilabel_) ++ilabel;
408 if (increment_final_olabel_) ++olabel;
426 size_t operator()(
const Element &x)
const {
427 static constexpr
auto prime = 7853;
428 return static_cast<size_t>(x.
state * prime + x.
weight.Hash());
433 std::unordered_map<Element, StateId, ElementKey, ElementEqual>;
435 std::unique_ptr<const Fst<Arc>> fst_;
440 bool increment_final_ilabel_;
441 bool increment_final_olabel_;
442 std::vector<Element> elements_;
443 ElementMap element_map_;
446 std::vector<StateId> unfactored_;
460 template <
class A,
class FactorIterator>
462 :
public ImplToFst<internal::FactorWeightFstImpl<A, FactorIterator>> {
494 GetMutableImpl()->InitArcIterator(s, data);
505 template <
class Arc,
class FactorIterator>
511 fst, fst.GetMutableImpl()) {}
515 template <
class Arc,
class FactorIterator>
523 fst.GetMutableImpl(), s) {
528 template <
class Arc,
class FactorIterator>
532 std::make_unique<StateIterator<FactorWeightFst<Arc, FactorIterator>>>(
538 #endif // FST_FACTOR_WEIGHT_H_ ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
FactorWeightFst(const Fst< Arc > &fst, const FactorWeightOptions< Arc > &opts)
bool increment_final_olabel
GallicFactor(const GW &weight)
FactorWeightFst(const FactorWeightFst &fst, bool copy)
std::pair< W, W > Value() const
virtual uint64_t Properties(uint64_t mask, bool test) const =0
uint64_t Properties(uint64_t mask) const override
uint64_t Properties() const override
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
size_t NumInputEpsilons(StateId s)
const SymbolTable * OutputSymbols() const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
constexpr uint64_t kError
StateId FindState(const Element &element)
size_t NumOutputEpsilons(StateId s)
typename FactorWeightFst< Arc, FactorIterator >::Arc Arc
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Element(StateId s, Weight weight_)
std::pair< W, W > Value() const
bool increment_final_ilabel
StringFactor(const StringWeight< Label, S > &weight)
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
FactorWeightFstImpl(const FactorWeightFstImpl< Arc, FactorIterator > &impl)
constexpr uint64_t kCopyProperties
uint64_t FactorWeightProperties(uint64_t inprops)
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
typename Arc::Weight Weight
std::pair< GW, GW > Value() const
typename Arc::Weight Weight
std::unique_ptr< StateIteratorBase< Arc > > base
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
FactorWeightOptions(float delta=kDelta, uint8_t mode=kFactorArcWeights|kFactorFinalWeights, Label final_ilabel=0, Label final_olabel=0, bool increment_final_ilabel=false, bool increment_final_olabel=false)
typename Store::State State
FactorWeightOptions(const CacheOptions &opts, float delta=kDelta, uint8_t mode=kFactorArcWeights|kFactorFinalWeights, Label final_ilabel=0, Label final_olabel=0, bool increment_final_ilabel=false, bool increment_final_olabel=false)
FactorWeightFst(const Fst< Arc > &fst)
size_t NumArcs(StateId s)
typename Arc::StateId StateId
constexpr uint64_t kFstProperties
typename Arc::Label Label
constexpr uint8_t kFactorFinalWeights
typename Arc::StateId StateId
typename FactorWeightFst< Arc, FactorIterator >::Arc Arc
virtual const SymbolTable * InputSymbols() const =0
const SymbolTable * InputSymbols() const
std::pair< StringWeight< Label, S >, StringWeight< Label, S > > Value() const
FactorWeightFstImpl(const Fst< Arc > &fst, const FactorWeightOptions< Arc > &opts)
FactorWeightFst * Copy(bool copy=false) const override
typename Arc::Label Label
void InitStateIterator(StateIteratorData< Arc > *data) const override
constexpr uint8_t kFactorArcWeights
GallicFactor(const GW &weight)
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
typename CacheState< Arc >::Arc Arc
Impl * GetMutableImpl() const
typename Arc::StateId StateId
ArcIterator(const FactorWeightFst< Arc, FactorIterator > &fst, StateId s)
IdentityFactor(const W &weight)
std::pair< GW, GW > Value() const
StateIterator(const FactorWeightFst< Arc, FactorIterator > &fst)
const Impl * GetImpl() const
virtual const SymbolTable * OutputSymbols() const =0