20 #ifndef FST_DETERMINIZE_H_ 21 #define FST_DETERMINIZE_H_ 26 #include <forward_list> 65 template <
typename Label, StringType S>
74 FSTERROR() <<
"LabelCommonDivisor: Weight needs to be left semiring";
75 return Weight::NoWeight();
76 }
else if (w1.
Size() == 0 || w2.
Size() == 0) {
78 }
else if (w1 == Weight::Zero()) {
80 }
else if (w2 == Weight::Zero()) {
106 CommonDivisor weight_common_divisor_;
110 template <
class Label,
class W,
class CommonDivisor>
119 auto weight = GRWeight::Zero();
121 weight = common_divisor_(weight, iter.Value());
124 weight = common_divisor_(weight, iter.Value());
126 return weight == GRWeight::Zero() ? Weight::Zero() :
Weight(weight);
142 : state_id(s), weight(std::move(weight)) {}
149 return !(*
this == element);
152 inline bool operator<(const DeterminizeElement<Arc> &element)
const {
153 return state_id < element.state_id;
161 template <
typename A,
typename FilterState>
165 using Subset = std::forward_list<Element>;
184 template <
class StateTuple>
186 using Arc =
typename StateTuple::Arc;
200 std::unique_ptr<StateTuple>
222 using LabelMap = std::map<Label, internal::DeterminizeArc<StateTuple>>;
233 template <
class Filter>
235 : fst_(fst.Copy()) {}
240 : fst_(
fst ?
fst->Copy() : filter.fst_->Copy()) {}
252 auto &det_arc = (*label_map)[arc.ilabel];
257 det_arc.dest_tuple->subset.push_front(std::move(dest_element));
267 std::unique_ptr<Fst<Arc>> fst_;
301 template <
class Arc,
class FilterState>
312 template <
class B,
class G>
318 : table_size_(table_size), tuples_(table_size_) {}
322 : table_size_(table.table_size_), tuples_(table_size_) {}
325 for (
StateId s = 0; s < tuples_.Size(); ++s)
delete tuples_.FindEntry(s);
333 const StateId ns = tuples_.Size();
336 const auto s = tuples_.FindId(raw_tuple);
337 if (s != ns)
delete raw_tuple;
345 class StateTupleEqual {
348 return *tuple1 == *tuple2;
353 class StateTupleKey {
357 for (
auto &element : tuple->
subset) {
358 const size_t h1 = element.state_id;
359 static constexpr
auto lshift = 5;
360 static constexpr
auto rshift = CHAR_BIT *
sizeof(size_t) - 5;
361 h ^= h << 1 ^ h1 << lshift ^ h1 >> rshift ^ element.weight.Hash();
412 Label subsequential_label = 0,
414 bool increment_subsequential_label =
false,
415 Filter *filter =
nullptr,
416 StateTable *state_table =
nullptr)
419 subsequential_label(subsequential_label),
421 increment_subsequential_label(increment_subsequential_label),
423 state_table(state_table) {}
426 Label subsequential_label = 0,
428 bool increment_subsequential_label =
false,
429 Filter *filter =
nullptr,
430 StateTable *state_table =
nullptr)
432 subsequential_label(subsequential_label),
434 increment_subsequential_label(increment_subsequential_label),
436 state_table(state_table) {}
466 template <
class CommonDivisor,
class Filter,
class StateTable>
495 const auto start = ComputeStart();
502 if (!HasFinal(s)) SetFinal(s, ComputeFinal(s));
509 if (!HasArcs(s))
Expand(s);
514 if (!HasArcs(s))
Expand(s);
519 if (!HasArcs(s))
Expand(s);
524 if (!HasArcs(s))
Expand(s);
528 virtual StateId ComputeStart() = 0;
535 std::unique_ptr<const Fst<Arc>> fst_;
539 template <
class Arc,
class CommonDivisor,
class Filter,
class StateTable>
557 const Fst<Arc> &
fst,
const std::vector<Weight> *in_dist,
558 std::vector<Weight> *out_dist,
564 filter_(opts.filter ? opts.filter : new Filter(fst)),
565 state_table_(opts.state_table ? opts.state_table : new StateTable()) {
567 FSTERROR() <<
"DeterminizeFst: Argument not an acceptor";
571 FSTERROR() <<
"DeterminizeFst: Weight must be left distributive: " 575 if (out_dist_) out_dist_->clear();
583 filter_(new Filter(*impl.filter_, &GetFst())),
584 state_table_(new StateTable(*impl.state_table_)) {
585 if (impl.out_dist_) {
586 FSTERROR() <<
"DeterminizeFsaImpl: Cannot copy with out_dist vector";
599 if ((mask &
kError) && (GetFst().Properties(kError,
false))) {
600 SetProperties(kError, kError);
606 const auto s = GetFst().Start();
608 auto tuple = fst::make_unique_for_overwrite<StateTuple>();
609 tuple->subset.emplace_front(s, Weight::One());
610 tuple->filter_state = filter_->Start();
611 return FindState(std::move(tuple));
615 const auto *tuple = state_table_->Tuple(s);
616 filter_->SetState(s, *tuple);
617 auto final_weight = Weight::Zero();
618 for (
const auto &element : tuple->subset) {
621 Times(element.weight, GetFst().
Final(element.state_id)));
622 final_weight = filter_->FilterFinal(final_weight, element);
623 if (!final_weight.Member()) SetProperties(
kError,
kError);
629 const auto &subset = tuple->subset;
630 const auto s = state_table_->FindState(std::move(tuple));
631 if (in_dist_ && out_dist_->size() <= s) {
632 out_dist_->push_back(ComputeDistance(subset));
640 auto outd = Weight::Zero();
641 for (
const auto &element : subset) {
643 (element.state_id < in_dist_->size() ? (*in_dist_)[element.state_id]
645 outd =
Plus(outd,
Times(element.weight, ind));
654 GetLabelMap(s, &label_map);
655 for (
auto &[unused_label, arc] : label_map) {
656 AddArc(s, std::move(arc));
667 const auto *src_tuple = state_table_->Tuple(s);
668 filter_->SetState(s, *src_tuple);
669 for (
const auto &src_element : src_tuple->subset) {
671 !aiter.Done(); aiter.Next()) {
672 const auto &arc = aiter.Value();
673 Element dest_element(arc.nextstate,
674 Times(src_element.weight, arc.weight));
675 filter_->FilterArc(arc, src_element, std::move(dest_element),
679 for (
auto &[unused_label, arc] : *label_map) {
686 void NormArc(
DetArc *det_arc) {
687 std::unique_ptr<StateTuple> &dest_tuple = det_arc->
dest_tuple;
688 dest_tuple->subset.sort();
689 auto piter = dest_tuple->subset.begin();
690 for (
auto diter = dest_tuple->subset.begin();
691 diter != dest_tuple->subset.end();) {
692 auto &dest_element = *diter;
693 auto &prev_element = *piter;
695 det_arc->
weight = common_divisor_(det_arc->
weight, dest_element.weight);
696 if (piter != diter && dest_element.state_id == prev_element.state_id) {
698 prev_element.weight =
Plus(prev_element.weight, dest_element.weight);
699 if (!prev_element.weight.Member()) SetProperties(
kError,
kError);
701 dest_tuple->subset.erase_after(piter);
709 for (
auto &dest_element : dest_tuple->subset) {
710 dest_element.weight =
712 dest_element.weight = dest_element.weight.Quantize(delta_);
720 std::move(det_arc.
weight),
725 const std::vector<Weight> *in_dist_;
726 std::vector<Weight> *out_dist_;
728 static const CommonDivisor common_divisor_;
729 std::unique_ptr<Filter> filter_;
730 std::unique_ptr<StateTable> state_table_;
733 template <
class Arc,
class CommonDivisor,
class Filter,
class StateTable>
735 StateTable>::common_divisor_{};
741 template <
class Arc,
GallicType G,
class CommonDivisor,
class Filter,
756 using ToFilter =
typename Filter::template rebind<ToArc>::Other;
759 typename StateTable::template rebind<ToArc, ToFilterState>::Other;
772 subsequential_label_(opts.subsequential_label),
773 increment_subsequential_label_(opts.increment_subsequential_label) {
776 <<
"A state table can not be passed with transducer input";
787 subsequential_label_(impl.subsequential_label_),
788 increment_subsequential_label_(impl.increment_subsequential_label_) {
789 Init(GetFst(),
nullptr);
800 if ((mask &
kError) && (GetFst().Properties(kError,
false) ||
801 from_fst_->Properties(kError,
false))) {
802 SetProperties(kError, kError);
822 void Init(
const Fst<Arc> &
fst, std::unique_ptr<Filter> filter);
825 Label subsequential_label_;
826 bool increment_subsequential_label_;
827 std::unique_ptr<FromFst> from_fst_;
871 template <
class B,
GallicType G,
class CommonDivisor,
class Filter,
873 friend class DeterminizeFstImpl;
878 template <
class CommonDivisor,
class Filter,
class StateTable>
889 template <
class CommonDivisor,
class Filter,
class StateTable>
891 const Fst<Arc> &fst,
const std::vector<Weight> *in_dist,
892 std::vector<Weight> *out_dist,
897 std::make_shared<internal::DeterminizeFsaImpl<
Arc, CommonDivisor,
898 Filter, StateTable>>(
899 fst, in_dist, out_dist, opts)) {
902 <<
"Distance to final states computed for acceptors only";
910 : fst.GetSharedImpl()) {}
920 GetMutableImpl()->InitArcIterator(s, data);
927 static std::shared_ptr<Impl> CreateImpl(
const Fst<Arc> &fst) {
932 return CreateImpl(fst, opts);
935 template <
class CommonDivisor,
class Filter,
class StateTable>
936 static std::shared_ptr<Impl> CreateImpl(
942 return std::make_shared<
944 fst,
nullptr,
nullptr, opts);
949 Arc,
GALLIC_MIN, CommonDivisor, Filter, StateTable>>(fst, opts);
951 FSTERROR() <<
"DeterminizeFst: Weight needs to have the path " 952 <<
"property to disambiguate output: " << Weight::Type();
956 Arc,
GALLIC, CommonDivisor, Filter, StateTable>>(empty_fst, opts);
967 Arc,
GALLIC, CommonDivisor, Filter, StateTable>>(fst, opts);
978 template <
class A, GallicType G,
class D,
class F,
class T>
979 void DeterminizeFstImpl<A, G, D, F, T>::Init(
const Fst<A> &fst,
980 std::unique_ptr<F> filter) {
982 const ToFst to_fst(fst);
983 auto *to_filter = filter ?
new ToFilter(to_fst, std::move(filter)) :
nullptr;
986 const CacheOptions copts(GetCacheGc(), GetCacheLimit());
994 subsequential_label_, increment_subsequential_label_,
995 increment_subsequential_label_);
997 from_fst_ = std::make_unique<FromFst>(factored_fst,
998 FromMapper(subsequential_label_));
1004 template <
class Arc>
1013 template <
class Arc>
1025 template <
class Arc>
1028 data->
base = std::make_unique<StateIterator<DeterminizeFst<Arc>>>(*this);
1034 template <
class Arc>
1050 Weight weight_threshold = Weight::Zero(),
1052 Label subsequential_label = 0,
1054 bool increment_subsequential_label =
false)
1056 weight_threshold(std::move(weight_threshold)),
1057 state_threshold(state_threshold),
1058 subsequential_label(subsequential_label),
1060 increment_subsequential_label(increment_subsequential_label) {}
1078 template <
class Arc>
1082 using Weight =
typename Arc::Weight;
1084 nopts.
delta = opts.delta;
1086 nopts.
type = opts.type;
1089 if (opts.weight_threshold != Weight::Zero() ||
1093 std::vector<Weight> idistance;
1094 std::vector<Weight> odistance;
1100 Prune(dfst, ofst, popts);
1103 Prune(ofst, opts.weight_threshold, opts.state_threshold);
1106 FSTERROR() <<
"Determinize: Weight needs to have the path " 1107 <<
"property to use pruning options: " << Weight::Type();
1117 #endif // FST_DETERMINIZE_H_
DefaultDeterminizeStateTable(const DefaultDeterminizeStateTable< Arc, FilterState > &table)
typename Arc::Label Label
DefaultDeterminizeFilter(const Fst< Arc > &fst)
uint64_t Properties(uint64_t mask) const override
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
uint64_t Properties(uint64_t mask) const override
size_t NumOutputEpsilons(StateId s)
FilterState Start() const
typename StateTuple::Arc Arc
size_t NumArcs(StateId s)
typename StateTuple::Element Element
DefaultDeterminizeStateTable(size_t table_size=0)
DeterminizeFsaImpl(const Fst< Arc > &fst, const std::vector< Weight > *in_dist, std::vector< Weight > *out_dist, const DeterminizeFstOptions< Arc, CommonDivisor, Filter, StateTable > &opts)
uint64_t Properties() const override
bool FilterArc(const Arc &arc, const Element &src_element, Element &&dest_element, LabelMap *label_map) const
virtual uint64_t Properties(uint64_t mask, bool test) const =0
typename StateTuple::Subset Subset
DeterminizeFst(const Fst< Arc > &fst, const std::vector< Weight > *in_dist, std::vector< Weight > *out_dist, const DeterminizeFstOptions< Arc, CommonDivisor, Filter, StateTable > &opts=DeterminizeFstOptions< Arc, CommonDivisor, Filter, StateTable >())
DeterminizeArc(const Arc &arc)
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
const Label & Value() const
bool increment_subsequential_label
bool operator==(const DeterminizeElement< Arc > &element) const
bool operator!=(const DeterminizeStateTuple< Arc, FilterState > &tuple) const
DefaultDeterminizeFilter(const DefaultDeterminizeFilter< Arc > &filter, const Fst< Arc > *fst=nullptr)
bool operator!=(const DeterminizeElement< Arc > &element) const
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
const SymbolTable * OutputSymbols() const
void Determinize(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, const DeterminizeOptions< Arc > &opts=DeterminizeOptions< Arc >())
std::map< Label, internal::DeterminizeArc< StateTuple >> LabelMap
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
constexpr uint64_t kError
DeterminizeFstOptions(const CacheOptions &opts, float delta=kDelta, Label subsequential_label=0, DeterminizeType type=DETERMINIZE_FUNCTIONAL, bool increment_subsequential_label=false, Filter *filter=nullptr, StateTable *state_table=nullptr)
Weight operator()(const Weight &w1, const Weight &w2) const
DeterminizeElement(StateId s, Weight weight)
typename Arc::Label Label
const W2 & Value2() const
typename DeterminizeFst< Arc >::Arc Arc
void InitStateIterator(StateIteratorData< Arc > *data) const override
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
typename Arc::Label Label
typename StateTuple::Subset Subset
std::unique_ptr< T > WrapUnique(T *ptr)
DeterminizeFstImplBase(const DeterminizeFstImplBase &impl)
static uint64_t Properties(uint64_t props)
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
typename ToMapper::ToArc ToArc
void Prune(MutableFst< Arc > *fst, const PruneOptions< Arc, ArcFilter > &opts=PruneOptions< Arc, ArcFilter >())
~DefaultDeterminizeStateTable()
Weight FilterFinal(Weight weight, const Element &element)
Weight ComputeDistance(const Subset &subset)
typename Store::State State
typename Arc::Weight Weight
ArcIterator(const DeterminizeFst< Arc > &fst, StateId s)
DeterminizeFstOptions(float delta=kDelta, Label subsequential_label=0, DeterminizeType type=DETERMINIZE_FUNCTIONAL, bool increment_subsequential_label=false, Filter *filter=nullptr, StateTable *state_table=nullptr)
size_t NumInputEpsilons(StateId s)
IntegerFilterState< signed char > CharFilterState
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
typename Filter::FilterState FilterState
virtual uint64_t Properties() const
DefaultDeterminizeFilter(const Fst< Arc > &fst, std::unique_ptr< Filter > filter)
Weight operator()(const Weight &w1, const Weight &w2) const
DeterminizeFst * Copy(bool safe=false) const override
constexpr uint64_t kCopyProperties
void SetState(StateId s, const StateTuple &tuple)
virtual void SetProperties(uint64_t props, uint64_t mask)=0
DeterminizeFstImpl(const DeterminizeFstImpl &impl)
typename Arc::Weight Weight
Label subsequential_label
const StateTuple * Tuple(StateId s)
std::forward_list< Element > Subset
typename ToFilter::FilterState ToFilterState
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
typename Arc::Weight Weight
std::unique_ptr< T > make_unique_for_overwrite()
typename Filter::template rebind< ToArc >::Other ToFilter
StateId FindState(std::unique_ptr< StateTuple > tuple)
Label subsequential_label
typename Arc::StateId StateId
bool operator==(const DeterminizeStateTuple< Arc, FilterState > &tuple) const
DeterminizeFst(const DeterminizeFst &fst, bool safe=false)
std::unique_ptr< StateIteratorBase< Arc > > base
const Fst< Arc > & GetFst() const
DeterminizeFsaImpl(const DeterminizeFsaImpl &impl)
typename Arc::Label Label
typename Arc::StateId StateId
typename Arc::Weight Weight
DeterminizeFst(const Fst< A > &fst)
void Expand(StateId s) override
typename Arc::Label Label
typename Arc::StateId StateId
typename Arc::StateId StateId
constexpr uint64_t kFstProperties
typename Arc::StateId StateId
constexpr uint8_t kFactorFinalWeights
StateIterator(const DeterminizeFst< Arc > &fst)
DeterminizeFstImpl * Copy() const override
Weight ComputeFinal(StateId s) override
Weight operator()(const Weight &w1, const Weight &w2) const
typename DeterminizeFst< Arc >::Arc Arc
virtual const SymbolTable * InputSymbols() const =0
const SymbolTable * InputSymbols() const
uint64_t DeterminizeProperties(uint64_t inprops, bool has_subsequential_label, bool distinct_psubsequential_labels)
typename Arc::Weight Weight
void Expand(StateId s) override
ErrorWeight Divide(const ErrorWeight &, const ErrorWeight &)
typename StateTuple::Element Element
DeterminizeOptions(float delta=kDelta, Weight weight_threshold=Weight::Zero(), StateId state_threshold=kNoStateId, Label subsequential_label=0, DeterminizeType type=DETERMINIZE_FUNCTIONAL, bool increment_subsequential_label=false)
typename Arc::Label Label
DeterminizeFst(const Fst< Arc > &fst, const DeterminizeFstOptions< Arc, CommonDivisor, Filter, StateTable > &opts=DeterminizeFstOptions< Arc, CommonDivisor, Filter, StateTable >())
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
typename Arc::Label Label
bool increment_subsequential_label
typename Arc::Weight Weight
DeterminizeFsaImpl * Copy() const override
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
std::unique_ptr< StateTuple > dest_tuple
typename Filter::LabelMap LabelMap
StateId FindState(std::unique_ptr< StateTuple > tuple)
typename CacheState< Arc >::Arc Arc
Impl * GetMutableImpl() const
typename Arc::StateId StateId
typename Arc::Weight Weight
DeterminizeFstImplBase(const Fst< Arc > &fst, const DeterminizeFstOptions< Arc, CommonDivisor, Filter, StateTable > &opts)
uint64_t Properties() const override
StateId ComputeStart() override
DeterminizeFstImpl(const Fst< Arc > &fst, const DeterminizeFstOptions< Arc, CommonDivisor, Filter, StateTable > &opts)
constexpr uint64_t kLeftSemiring
StateId ComputeStart() override
typename Arc::StateId StateId
std::bool_constant<(W::Properties()&kPath)!=0 > IsPath
Weight ComputeFinal(StateId s) override
typename Arc::Weight Weight
Weight operator()(const Weight &w1, const Weight &w2) const
const Impl * GetImpl() const
const W1 & Value1() const
constexpr uint64_t kAcceptor
typename StateTable::template rebind< ToArc, ToFilterState >::Other ToStateTable
virtual const SymbolTable * OutputSymbols() const =0