20 #ifndef FST_RMEPSILON_H_ 21 #define FST_RMEPSILON_H_ 41 #include <unordered_map> 45 template <
class Arc,
class Queue>
57 Weight weight_threshold = Weight::Zero(),
62 weight_threshold(std::move(weight_threshold)),
63 state_threshold(state_threshold) {}
69 template <
class Arc,
class Queue>
72 using Label =
typename Arc::Label;
80 sd_state_(fst_, distance, opts, true),
85 std::vector<Arc> &
Arcs() {
return arcs_; }
89 bool Error()
const {
return sd_state_.Error(); }
100 : ilabel(ilabel), olabel(olabel), nextstate(nexstate) {}
105 size_t operator()(
const Element &element)
const {
106 static constexpr
size_t prime0 = 7853;
107 static constexpr
size_t prime1 = 7867;
108 return static_cast<size_t>(element.nextstate) +
109 static_cast<size_t>(element.ilabel) * prime0 +
110 static_cast<size_t>(element.olabel) * prime1;
116 bool operator()(
const Element &e1,
const Element &e2)
const {
117 return (e1.ilabel == e2.ilabel) && (e1.olabel == e2.olabel) &&
118 (e1.nextstate == e2.nextstate);
122 using ElementMap = std::unordered_map<Element, std::pair<StateId, size_t>,
123 ElementHash, ElementEqual>;
127 std::vector<Weight> *distance_;
133 ElementMap element_map_;
135 std::stack<StateId, std::vector<StateId>>
137 std::vector<bool> visited_;
138 std::vector<StateId> visited_states_;
139 std::vector<Arc> arcs_;
147 template <
class Arc,
class Queue>
149 final_weight_ = Weight::Zero();
151 sd_state_.ShortestDistance(source);
152 if (sd_state_.Error())
return;
153 eps_queue_.push(source);
154 while (!eps_queue_.empty()) {
155 const auto state = eps_queue_.top();
157 if (
static_cast<decltype(state)
>(visited_.size()) <= state) {
158 visited_.resize(state + 1,
false);
160 if (visited_[state])
continue;
161 visited_[state] =
true;
162 visited_states_.push_back(state);
165 auto arc = aiter.Value();
166 arc.weight =
Times((*distance_)[state], arc.weight);
167 if (eps_filter_(arc)) {
168 if (
static_cast<decltype(arc.nextstate)
>(visited_.size()) <=
170 visited_.resize(arc.nextstate + 1,
false);
172 if (!visited_[arc.nextstate]) eps_queue_.push(arc.nextstate);
173 }
else if (
auto [insert_it, success] = element_map_.emplace(
174 Element(arc.ilabel, arc.olabel, arc.nextstate),
175 std::make_pair(expand_id_, arcs_.size()));
177 arcs_.push_back(std::move(arc));
178 }
else if (
auto &[xid, arc_idx] = insert_it->second; xid == expand_id_) {
179 auto &weight = arcs_[arc_idx].weight;
180 weight =
Plus(weight, arc.weight);
183 arc_idx = arcs_.size();
184 arcs_.push_back(std::move(arc));
188 Plus(final_weight_,
Times((*distance_)[state], fst_.Final(state)));
191 for (
const auto state_id : visited_states_) visited_[state_id] =
false;
192 visited_states_.clear();
206 template <
class Arc,
class Queue>
208 std::vector<typename Arc::Weight> *distance,
210 using StateId =
typename Arc::StateId;
211 using Weight =
typename Arc::Weight;
215 std::vector<bool> noneps_in(fst->
NumStates(),
false);
216 noneps_in[fst->
Start()] =
true;
217 for (
size_t i = 0; i < fst->
NumStates(); ++i) {
219 const auto &arc = aiter.Value();
220 if (arc.ilabel != 0 || arc.olabel != 0) {
221 noneps_in[arc.nextstate] =
true;
227 std::vector<StateId> states;
230 for (
size_t i = 0; i < fst->
NumStates(); i++) states.push_back(i);
232 std::vector<StateId> order;
238 FSTERROR() <<
"RmEpsilon: Inconsistent acyclic property bit";
242 states.resize(order.size());
243 for (
StateId i = 0; i < order.size(); i++) states[order[i]] = i;
246 std::vector<StateId> scc;
249 std::vector<StateId> first(scc.size(),
kNoStateId);
250 std::vector<StateId> next(scc.size(),
kNoStateId);
251 for (
StateId i = 0; i < scc.size(); i++) {
252 if (first[scc[i]] !=
kNoStateId) next[i] = first[scc[i]];
255 for (
StateId i = 0; i < first.size(); i++) {
256 for (
auto j = first[i]; j !=
kNoStateId; j = next[j]) {
262 while (!states.empty()) {
263 const auto state = states.back();
265 if (!noneps_in[state] &&
270 rmeps_state.
Expand(state);
273 auto &arcs = rmeps_state.
Arcs();
275 while (!arcs.empty()) {
276 fst->
AddArc(state, arcs.back());
282 for (
size_t s = 0; s < fst->
NumStates(); ++s) {
295 FSTERROR() <<
"RmEpsilon: Weight must have path property: " 333 typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
334 typename Arc::StateId state_threshold =
kNoStateId,
336 using StateId =
typename Arc::StateId;
337 using Weight =
typename Arc::Weight;
338 std::vector<Weight> distance;
341 &state_queue,
delta, connect, weight_threshold, state_threshold);
398 fst_(impl.fst_->Copy(true)),
410 if (!HasStart()) SetStart(fst_->Start());
415 if (!HasFinal(s))
Expand(s);
420 if (!HasArcs(s))
Expand(s);
425 if (!HasArcs(s))
Expand(s);
430 if (!HasArcs(s))
Expand(s);
439 (fst_->Properties(kError,
false) || rmeps_state_.Error())) {
440 SetProperties(kError, kError);
446 if (!HasArcs(s))
Expand(s);
451 rmeps_state_.Expand(s);
452 SetFinal(s, rmeps_state_.Final());
453 auto &arcs = rmeps_state_.Arcs();
454 while (!arcs.empty()) {
455 PushArc(s, std::move(arcs.back()));
462 std::unique_ptr<const Fst<Arc>> fst_;
464 std::vector<Weight> distance_;
527 GetMutableImpl()->InitArcIterator(s, data);
562 data->
base = std::make_unique<StateIterator<RmEpsilonFst<Arc>>>(*this);
570 #endif // FST_RMEPSILON_H_ RmEpsilonState(const Fst< Arc > &fst, std::vector< Weight > *distance, const RmEpsilonOptions< Arc, Queue > &opts)
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
uint64_t Properties() const override
void RmEpsilon(MutableFst< Arc > *fst, std::vector< typename Arc::Weight > *distance, const RmEpsilonOptions< Arc, Queue > &opts)
typename Arc::Weight Weight
virtual uint64_t Properties(uint64_t mask, bool test) const =0
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
typename Arc::Label Label
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
size_t NumOutputEpsilons(StateId s)
uint64_t RmEpsilonProperties(uint64_t inprops, bool delayed=false)
RmEpsilonFst(const RmEpsilonFst &fst, bool safe=false)
const Weight & Final() const
void InitStateIterator(StateIteratorData< Arc > *data) const override
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
typename Arc::Weight Weight
const SymbolTable * OutputSymbols() const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
typename Arc::StateId StateId
constexpr uint64_t kError
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
typename Store::State State
typename RmEpsilonFst< Arc >::Arc Arc
void Connect(MutableFst< Arc > *fst)
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
constexpr uint64_t kTopSorted
void Prune(MutableFst< Arc > *fst, const PruneOptions< Arc, ArcFilter > &opts=PruneOptions< Arc, ArcFilter >())
virtual void ReserveArcs(StateId, size_t)
size_t NumInputEpsilons(StateId s)
RmEpsilonFst(const Fst< A > &fst, const RmEpsilonFstOptions &opts)
ArcIterator(const RmEpsilonFst< Arc > &fst, StateId s)
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
RmEpsilonFstOptions(const CacheOptions &opts, float delta=kShortestDelta)
RmEpsilonFst * Copy(bool safe=false) const override
std::vector< Arc > & Arcs()
constexpr uint64_t kCopyProperties
constexpr uint64_t kAcyclic
virtual void SetProperties(uint64_t props, uint64_t mask)=0
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
virtual void DeleteArcs(StateId, size_t)=0
virtual StateId Start() const =0
size_t NumArcs(StateId s)
RmEpsilonOptions(Queue *queue, float delta=kShortestDelta, bool connect=true, Weight weight_threshold=Weight::Zero(), StateId state_threshold=kNoStateId)
std::unique_ptr< StateIteratorBase< Arc > > base
constexpr float kShortestDelta
void Expand(const FstClass &ifst, const std::vector< std::pair< int64_t, int64_t >> &parens, const std::vector< int64_t > &assignments, MutableFstClass *ofst, const MPdtExpandOptions &opts)
typename Arc::StateId StateId
RmEpsilonFstImpl(const Fst< Arc > &fst, const RmEpsilonFstOptions &opts)
constexpr uint64_t kFstProperties
virtual void AddArc(StateId, const Arc &)=0
typename Arc::Weight Weight
RmEpsilonFstImpl(const RmEpsilonFstImpl &impl)
typename internal::RmEpsilonFstImpl< Arc >::Arc Arc
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
RmEpsilonFst(const Fst< Arc > &fst)
typename RmEpsilonFst< Arc >::Arc Arc
virtual const SymbolTable * InputSymbols() const =0
const SymbolTable * InputSymbols() const
typename Arc::StateId StateId
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
typename Arc::StateId StateId
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
uint64_t Properties(uint64_t mask) const override
StateIterator(const RmEpsilonFst< Arc > &fst)
typename Arc::StateId StateId
typename CacheState< Arc >::Arc Arc
Impl * GetMutableImpl() const
RmEpsilonFstOptions(float delta=kShortestDelta)
virtual StateId NumStates() const =0
std::bool_constant<(W::Properties()&kPath)!=0 > IsPath
const Impl * GetImpl() const
virtual const SymbolTable * OutputSymbols() const =0