23 #include <type_traits> 37 template <
class StateId,
class Weight>
41 const std::vector<Weight> &fdistance)
42 : idistance_(idistance), fdistance_(fdistance) {}
44 bool operator()(
const StateId x,
const StateId y)
const {
45 const auto wx =
Times(IDistance(x), FDistance(x));
46 const auto wy =
Times(IDistance(y), FDistance(y));
51 Weight IDistance(
const StateId s)
const {
52 return s < idistance_.size() ? idistance_[s] : Weight::Zero();
55 Weight FDistance(
const StateId s)
const {
56 return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
59 const std::vector<Weight> &idistance_;
60 const std::vector<Weight> &fdistance_;
66 template <
class Arc,
class ArcFilter>
73 ArcFilter filter = ArcFilter(),
74 std::vector<Weight> *distance =
nullptr,
75 float delta =
kDelta,
bool threshold_initial =
false)
76 : weight_threshold(std::move(weight_threshold)),
77 state_threshold(state_threshold),
78 filter(std::move(filter)),
81 threshold_initial(threshold_initial) {}
110 template <
class Arc,
class ArcFilter>
113 using StateId =
typename Arc::StateId;
114 using Weight =
typename Arc::Weight;
119 std::vector<Weight> idistance(ns, Weight::Zero());
120 std::vector<Weight> tmp;
121 if (!opts.distance) {
125 const auto *fdistance = opts.distance ? opts.distance : &tmp;
126 if ((opts.state_threshold == 0) || (fdistance->size() <= fst->
Start()) ||
127 ((*fdistance)[fst->
Start()] == Weight::Zero())) {
132 StateHeap heap(compare);
133 std::vector<bool> visited(ns,
false);
134 std::vector<size_t> enqueued(ns, StateHeap::kNoKey);
135 std::vector<StateId> dead;
138 auto s = fst->
Start();
139 const auto limit = opts.threshold_initial
140 ?
Times(opts.weight_threshold, (*fdistance)[s])
141 :
Times((*fdistance)[s], opts.weight_threshold);
142 StateId num_visited = 0;
144 if (!less(limit, (*fdistance)[s])) {
145 idistance[s] = Weight::One();
146 enqueued[s] = heap.Insert(s);
149 while (!heap.Empty()) {
152 enqueued[s] = StateHeap::kNoKey;
154 if (less(limit,
Times(idistance[s], fst->
Final(s)))) {
159 auto arc = aiter.Value();
160 if (!opts.filter(arc))
continue;
163 arc.nextstate < fdistance->size() ? (*fdistance)[arc.nextstate]
165 if (less(limit, weight)) {
166 arc.nextstate = dead[0];
170 if (less(
Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
171 idistance[arc.nextstate] =
Times(idistance[s], arc.weight);
173 if (visited[arc.nextstate])
continue;
175 (num_visited >= opts.state_threshold)) {
178 if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
179 enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
182 heap.Update(enqueued[arc.nextstate], arc.nextstate);
186 for (StateId i = 0; i < visited.size(); ++i) {
187 if (!visited[i]) dead.push_back(i);
204 typename Arc::StateId state_threshold =
kNoStateId,
222 template <
class Arc,
class ArcFilter>
226 using StateId =
typename Arc::StateId;
227 using Weight =
typename Arc::Weight;
235 if (less(opts.weight_threshold, Weight::One()) ||
236 (opts.state_threshold == 0)) {
239 std::vector<Weight> idistance;
240 std::vector<Weight> tmp;
242 const auto *fdistance = opts.distance ? opts.distance : &tmp;
243 if ((fdistance->size() <= ifst.
Start()) ||
244 ((*fdistance)[ifst.
Start()] == Weight::Zero())) {
248 StateHeap heap(compare);
249 std::vector<StateId> copy;
250 std::vector<size_t> enqueued;
251 std::vector<bool> visited;
252 auto s = ifst.
Start();
253 const auto limit = opts.threshold_initial
254 ?
Times(opts.weight_threshold, (*fdistance)[s])
255 :
Times((*fdistance)[s], opts.weight_threshold);
256 while (copy.size() <= s) copy.push_back(
kNoStateId);
259 while (idistance.size() <= s) idistance.push_back(Weight::Zero());
260 idistance[s] = Weight::One();
261 while (enqueued.size() <= s) {
262 enqueued.push_back(StateHeap::kNoKey);
263 visited.push_back(
false);
265 enqueued[s] = heap.Insert(s);
266 while (!heap.Empty()) {
269 enqueued[s] = StateHeap::kNoKey;
271 if (!less(limit,
Times(idistance[s], ifst.
Final(s)))) {
275 const auto &arc = aiter.Value();
276 if (!opts.filter(arc))
continue;
279 arc.nextstate < fdistance->size() ? (*fdistance)[arc.nextstate]
281 if (less(limit, weight))
continue;
283 (ofst->
NumStates() >= opts.state_threshold)) {
286 while (idistance.size() <= arc.nextstate) {
287 idistance.push_back(Weight::Zero());
289 if (less(
Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
290 idistance[arc.nextstate] =
Times(idistance[s], arc.weight);
292 while (copy.size() <= arc.nextstate) copy.push_back(
kNoStateId);
294 copy[arc.nextstate] = ofst->
AddState();
296 ofst->
AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight,
297 copy[arc.nextstate]));
298 while (enqueued.size() <= arc.nextstate) {
299 enqueued.push_back(StateHeap::kNoKey);
300 visited.push_back(
false);
302 if (visited[arc.nextstate])
continue;
303 if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
304 enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
306 heap.Update(enqueued[arc.nextstate], arc.nextstate);
325 typename Arc::Weight weight_threshold,
326 typename Arc::StateId state_threshold =
kNoStateId,
330 Prune(ifst, ofst, opts);
335 #endif // FST_PRUNE_H_
bool operator()(const StateId x, const StateId y) const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
typename Arc::StateId StateId
virtual void SetInputSymbols(const SymbolTable *isyms)=0
virtual Weight Final(StateId) const =0
virtual void SetStart(StateId)=0
typename Arc::Weight Weight
void Prune(MutableFst< Arc > *fst, const PruneOptions< Arc, ArcFilter > &opts=PruneOptions< Arc, ArcFilter >())
virtual StateId Start() const =0
virtual void AddArc(StateId, const Arc &)=0
virtual const SymbolTable * InputSymbols() const =0
const std::vector< Weight > * distance
virtual StateId AddState()=0
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
virtual void DeleteStates(const std::vector< StateId > &)=0
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
virtual StateId NumStates() const =0
PruneCompare(const std::vector< Weight > &idistance, const std::vector< Weight > &fdistance)
PruneOptions(const Weight &weight_threshold=Weight::Zero(), StateId state_threshold=kNoStateId, ArcFilter filter=ArcFilter(), std::vector< Weight > *distance=nullptr, float delta=kDelta, bool threshold_initial=false)
std::bool_constant<(W::Properties()&kPath)!=0 > IsPath
virtual const SymbolTable * OutputSymbols() const =0