20 #ifndef FST_SHORTEST_DISTANCE_H_ 21 #define FST_SHORTEST_DISTANCE_H_ 45 template <
class Arc,
class Queue,
class ArcFilter>
65 float delta = kShortestDelta,
bool first_path =
false)
66 : state_queue(state_queue),
67 arc_filter(arc_filter),
70 first_path(first_path) {}
82 template <
class Arc,
class Queue,
class ArcFilter,
96 weight_equal_(opts.
delta),
103 distance_->reserve(*num_states);
104 adder_.reserve(*num_states);
105 radder_.reserve(*num_states);
106 enqueued_.reserve(*num_states);
112 bool Error()
const {
return error_; }
115 void EnsureDistanceIndexIsValid(std::size_t index) {
116 while (distance_->size() <= index) {
117 distance_->push_back(Weight::Zero());
120 enqueued_.push_back(
false);
125 void EnsureSourcesIndexIsValid(std::size_t index) {
126 while (sources_.size() <= index) {
133 std::vector<Weight> *distance_;
135 ArcFilter arc_filter_;
136 WeightEqual weight_equal_;
137 const bool first_path_;
140 std::vector<Adder<Weight>> adder_;
141 std::vector<Adder<Weight>> radder_;
142 std::vector<bool> enqueued_;
143 std::vector<StateId> sources_;
151 template <
class Arc,
class Queue,
class ArcFilter,
class WeightEqual>
155 if (fst_.Properties(
kError,
false)) error_ =
true;
159 FSTERROR() <<
"ShortestDistance: Weight needs to be right distributive: " 164 if (first_path_ && !(Weight::Properties() &
kPath)) {
165 FSTERROR() <<
"ShortestDistance: The first_path option is disallowed when " 166 <<
"Weight does not have the path property: " << Weight::Type();
170 state_queue_->Clear();
177 if (source ==
kNoStateId) source = fst_.Start();
178 EnsureDistanceIndexIsValid(source);
180 EnsureSourcesIndexIsValid(source);
181 sources_[
source] = source_id_;
183 (*distance_)[
source] = Weight::One();
184 adder_[
source].Reset(Weight::One());
185 radder_[
source].Reset(Weight::One());
187 state_queue_->Enqueue(source);
188 while (!state_queue_->Empty()) {
189 const auto state = state_queue_->Head();
190 state_queue_->Dequeue();
191 EnsureDistanceIndexIsValid(state);
192 if (first_path_ && (fst_.Final(state) != Weight::Zero()))
break;
193 enqueued_[state] =
false;
194 const auto r = radder_[state].Sum();
195 radder_[state].Reset();
198 const auto &arc = aiter.Value();
199 const auto nextstate = arc.nextstate;
200 if (!arc_filter_(arc))
continue;
201 EnsureDistanceIndexIsValid(nextstate);
203 EnsureSourcesIndexIsValid(nextstate);
204 if (sources_[nextstate] != source_id_) {
205 (*distance_)[nextstate] = Weight::Zero();
206 adder_[nextstate].Reset();
207 radder_[nextstate].Reset();
208 enqueued_[nextstate] =
false;
209 sources_[nextstate] = source_id_;
212 auto &nd = (*distance_)[nextstate];
213 auto &na = adder_[nextstate];
214 auto &nr = radder_[nextstate];
215 auto weight =
Times(r, arc.weight);
216 if (!weight_equal_(nd,
Plus(nd, weight))) {
219 if (!nd.Member() || !nr.Sum().Member()) {
223 if (!enqueued_[nextstate]) {
224 state_queue_->Enqueue(nextstate);
225 enqueued_[nextstate] =
true;
227 state_queue_->Update(nextstate);
233 if (fst_.Properties(
kError,
false)) error_ =
true;
261 template <
class Arc,
class Queue,
class ArcFilter>
263 const Fst<Arc> &
fst, std::vector<typename Arc::Weight> *distance,
268 if (sd_state.
Error()) {
269 distance->assign(1, Arc::Weight::NoWeight());
303 std::vector<typename Arc::Weight> *distance,
304 bool reverse =
false,
float delta = kShortestDelta) {
305 using StateId =
typename Arc::StateId;
318 std::vector<ReverseWeight> rdistance;
325 if (rdistance.size() == 1 && !rdistance[0].Member()) {
326 distance->assign(1, Arc::Weight::NoWeight());
330 distance->reserve(rdistance.size() - 1);
331 while (distance->size() < rdistance.size() - 1) {
332 distance->push_back(rdistance[distance->size() + 1].Reverse());
342 float delta = kShortestDelta) {
343 using StateId =
typename Arc::StateId;
344 using Weight =
typename Arc::Weight;
345 std::vector<Weight> distance;
348 if (distance.size() == 1 && !distance[0].Member()) {
349 return Arc::Weight::NoWeight();
352 for (
StateId state = 0; state < distance.size(); ++state) {
358 const auto state = fst.
Start();
359 if (distance.size() == 1 && !distance[0].Member()) {
360 return Arc::Weight::NoWeight();
362 return state !=
kNoStateId && state < distance.size() ? distance[state]
369 #endif // FST_SHORTEST_DISTANCE_H_
void ShortestDistance(StateId source)
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
constexpr uint64_t kError
virtual Weight Final(StateId) const =0
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
constexpr uint64_t kRightSemiring
Weight Add(const Weight &w)
typename Arc::StateId StateId
virtual StateId Start() const =0
constexpr float kShortestDelta
virtual std::optional< StateId > NumStatesIfKnown() const
typename Arc::StateId StateId
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
ShortestDistanceState(const Fst< Arc > &fst, std::vector< Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts, bool retain)
ShortestDistanceOptions(Queue *state_queue, ArcFilter arc_filter, StateId source=kNoStateId, float delta=kShortestDelta, bool first_path=false)
A::Weight::ReverseWeight Weight
typename Arc::Weight Weight