20 #ifndef FST_RMEPSILON_H_ 21 #define FST_RMEPSILON_H_ 51 #include <unordered_map> 55 template <
class Arc,
class Queue>
67 Weight weight_threshold = Weight::Zero(),
72 weight_threshold(std::move(weight_threshold)),
73 state_threshold(state_threshold) {}
79 template <
class Arc,
class Queue>
82 using Label =
typename Arc::Label;
90 sd_state_(fst_, distance, opts, true),
// Accessor for the arc buffer: hands back a non-const reference to arcs_, so
// the caller reads and modifies the stored vector in place (no copy is made).
95 std::vector<Arc> &
Arcs() {
return arcs_; }
// Reports the error status by forwarding the result of sd_state_.Error().
// Const: does not modify this object.
99 bool Error()
const {
return sd_state_.Error(); }
110 : ilabel(ilabel), olabel(olabel), nextstate(nexstate) {}
115 size_t operator()(
const Element &element)
const {
116 static constexpr
size_t prime0 = 7853;
117 static constexpr
size_t prime1 = 7867;
118 return static_cast<size_t>(element.nextstate) +
119 static_cast<size_t>(element.ilabel) * prime0 +
120 static_cast<size_t>(element.olabel) * prime1;
126 bool operator()(
const Element &e1,
const Element &e2)
const {
127 return (e1.ilabel == e2.ilabel) && (e1.olabel == e2.olabel) &&
128 (e1.nextstate == e2.nextstate);
// Hash map keyed by Element, hashed with ElementHash and compared with
// ElementEqual. The mapped value is a (StateId, size_t) pair; where this file
// fills the map, the pair is built as (expand_id_, arcs_.size()), and the
// second component is later used as an index into arcs_.
132 using ElementMap = std::unordered_map<Element, std::pair<StateId, size_t>,
133 ElementHash, ElementEqual>;
137 std::vector<Weight> *distance_;
143 ElementMap element_map_;
145 std::stack<StateId, std::vector<StateId>>
147 std::vector<bool> visited_;
148 std::vector<StateId> visited_states_;
149 std::vector<Arc> arcs_;
157 template <
class Arc,
class Queue>
159 final_weight_ = Weight::Zero();
161 sd_state_.ShortestDistance(source);
162 if (sd_state_.Error())
return;
163 eps_queue_.push(source);
164 while (!eps_queue_.empty()) {
165 const auto state = eps_queue_.top();
167 if (
static_cast<decltype(state)
>(visited_.size()) <= state) {
168 visited_.resize(state + 1,
false);
170 if (visited_[state])
continue;
171 visited_[state] =
true;
172 visited_states_.push_back(state);
175 auto arc = aiter.Value();
176 arc.weight =
Times((*distance_)[state], arc.weight);
177 if (eps_filter_(arc)) {
178 if (
static_cast<decltype(arc.nextstate)
>(visited_.size()) <=
180 visited_.resize(arc.nextstate + 1,
false);
182 if (!visited_[arc.nextstate]) eps_queue_.push(arc.nextstate);
183 }
else if (
auto [insert_it, success] = element_map_.emplace(
184 Element(arc.ilabel, arc.olabel, arc.nextstate),
185 std::make_pair(expand_id_, arcs_.size()));
187 arcs_.push_back(std::move(arc));
188 }
else if (
auto &[xid, arc_idx] = insert_it->second; xid == expand_id_) {
189 auto &weight = arcs_[arc_idx].weight;
190 weight =
Plus(weight, arc.weight);
193 arc_idx = arcs_.size();
194 arcs_.push_back(std::move(arc));
198 Plus(final_weight_,
Times((*distance_)[state], fst_.Final(state)));
201 for (
const auto state_id : visited_states_) visited_[state_id] =
false;
202 visited_states_.clear();
216 template <
class Arc,
class Queue>
218 std::vector<typename Arc::Weight> *distance,
220 using StateId =
typename Arc::StateId;
221 using Weight =
typename Arc::Weight;
225 std::vector<bool> noneps_in(fst->
NumStates(),
false);
226 noneps_in[fst->
Start()] =
true;
227 for (
size_t i = 0; i < fst->
NumStates(); ++i) {
229 const auto &arc = aiter.Value();
230 if (arc.ilabel != 0 || arc.olabel != 0) {
231 noneps_in[arc.nextstate] =
true;
237 std::vector<StateId> states;
240 for (
size_t i = 0; i < fst->
NumStates(); i++) states.push_back(i);
242 std::vector<StateId> order;
248 FSTERROR() <<
"RmEpsilon: Inconsistent acyclic property bit";
252 states.resize(order.size());
253 for (
StateId i = 0; i < order.size(); i++) states[order[i]] = i;
256 std::vector<StateId> scc;
259 std::vector<StateId> first(scc.size(),
kNoStateId);
260 std::vector<StateId> next(scc.size(),
kNoStateId);
261 for (
StateId i = 0; i < scc.size(); i++) {
262 if (first[scc[i]] !=
kNoStateId) next[i] = first[scc[i]];
265 for (
StateId i = 0; i < first.size(); i++) {
266 for (
auto j = first[i]; j !=
kNoStateId; j = next[j]) {
272 while (!states.empty()) {
273 const auto state = states.back();
275 if (!noneps_in[state] &&
280 rmeps_state.
Expand(state);
283 auto &arcs = rmeps_state.
Arcs();
285 while (!arcs.empty()) {
286 fst->
AddArc(state, arcs.back());
292 for (
size_t s = 0; s < fst->
NumStates(); ++s) {
305 FSTERROR() <<
"RmEpsilon: Weight must have path property: " 343 typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
344 typename Arc::StateId state_threshold =
kNoStateId,
346 using StateId =
typename Arc::StateId;
347 using Weight =
typename Arc::Weight;
348 std::vector<Weight> distance;
351 &state_queue,
delta, connect, weight_threshold, state_threshold);
408 fst_(impl.fst_->Copy(true)),
420 if (!HasStart()) SetStart(fst_->Start());
425 if (!HasFinal(s))
Expand(s);
430 if (!HasArcs(s))
Expand(s);
435 if (!HasArcs(s))
Expand(s);
440 if (!HasArcs(s))
Expand(s);
449 (fst_->Properties(kError,
false) || rmeps_state_.Error())) {
450 SetProperties(kError, kError);
456 if (!HasArcs(s))
Expand(s);
461 rmeps_state_.Expand(s);
462 SetFinal(s, rmeps_state_.Final());
463 auto &arcs = rmeps_state_.Arcs();
464 while (!arcs.empty()) {
465 PushArc(s, std::move(arcs.back()));
472 std::unique_ptr<const Fst<Arc>> fst_;
474 std::vector<Weight> distance_;
525 :
Base(std::make_shared<
Impl>(fst, opts)) {}
538 GetMutableImpl()->InitArcIterator(s, data);
543 using Base::GetMutableImpl;
573 data->
base = std::make_unique<StateIterator<RmEpsilonFst<Arc>>>(*this);
581 #endif // FST_RMEPSILON_H_ RmEpsilonState(const Fst< Arc > &fst, std::vector< Weight > *distance, const RmEpsilonOptions< Arc, Queue > &opts)
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
uint64_t Properties() const override
void RmEpsilon(MutableFst< Arc > *fst, std::vector< typename Arc::Weight > *distance, const RmEpsilonOptions< Arc, Queue > &opts)
typename Arc::Weight Weight
virtual uint64_t Properties(uint64_t mask, bool test) const =0
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
typename Arc::Label Label
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
size_t NumOutputEpsilons(StateId s)
uint64_t RmEpsilonProperties(uint64_t inprops, bool delayed=false)
RmEpsilonFst(const RmEpsilonFst &fst, bool safe=false)
const Weight & Final() const
void InitStateIterator(StateIteratorData< Arc > *data) const override
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
typename Arc::Weight Weight
const SymbolTable * OutputSymbols() const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
typename Arc::StateId StateId
constexpr uint64_t kError
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
typename Store::State State
typename RmEpsilonFst< Arc >::Arc Arc
void Connect(MutableFst< Arc > *fst)
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
constexpr uint64_t kTopSorted
void Prune(MutableFst< Arc > *fst, const PruneOptions< Arc, ArcFilter > &opts=PruneOptions< Arc, ArcFilter >())
virtual void ReserveArcs(StateId, size_t)
size_t NumInputEpsilons(StateId s)
RmEpsilonFst(const Fst< A > &fst, const RmEpsilonFstOptions &opts)
ArcIterator(const RmEpsilonFst< Arc > &fst, StateId s)
const Impl * GetImpl() const
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
RmEpsilonFstOptions(const CacheOptions &opts, float delta=kShortestDelta)
RmEpsilonFst * Copy(bool safe=false) const override
std::vector< Arc > & Arcs()
constexpr uint64_t kCopyProperties
constexpr uint64_t kAcyclic
virtual void SetProperties(uint64_t props, uint64_t mask)=0
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
virtual void DeleteArcs(StateId, size_t)=0
virtual StateId Start() const =0
size_t NumArcs(StateId s)
RmEpsilonOptions(Queue *queue, float delta=kShortestDelta, bool connect=true, Weight weight_threshold=Weight::Zero(), StateId state_threshold=kNoStateId)
std::unique_ptr< StateIteratorBase< Arc > > base
constexpr float kShortestDelta
void Expand(const FstClass &ifst, const std::vector< std::pair< int64_t, int64_t >> &parens, const std::vector< int64_t > &assignments, MutableFstClass *ofst, const MPdtExpandOptions &opts)
bool HasArcs(StateId s) const
typename Arc::StateId StateId
RmEpsilonFstImpl(const Fst< Arc > &fst, const RmEpsilonFstOptions &opts)
constexpr uint64_t kFstProperties
virtual void AddArc(StateId, const Arc &)=0
typename Arc::Weight Weight
RmEpsilonFstImpl(const RmEpsilonFstImpl &impl)
typename internal::RmEpsilonFstImpl< Arc >::Arc Arc
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
RmEpsilonFst(const Fst< Arc > &fst)
typename RmEpsilonFst< Arc >::Arc Arc
virtual const SymbolTable * InputSymbols() const =0
const SymbolTable * InputSymbols() const
typename Arc::StateId StateId
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
Impl * GetMutableImpl() const
typename Arc::StateId StateId
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
uint64_t Properties(uint64_t mask) const override
StateIterator(const RmEpsilonFst< Arc > &fst)
typename Arc::StateId StateId
typename CacheState< Arc >::Arc Arc
RmEpsilonFstOptions(float delta=kShortestDelta)
virtual StateId NumStates() const =0
std::bool_constant<(W::Properties()&kPath)!=0 > IsPath
virtual const SymbolTable * OutputSymbols() const =0