FST  openfst-1.8.2.post1
OpenFst Library
shortest-path.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Functions to find shortest paths in a PDT.
19 
20 #ifndef FST_EXTENSIONS_PDT_SHORTEST_PATH_H_
21 #define FST_EXTENSIONS_PDT_SHORTEST_PATH_H_
22 
23 #include <cstdint>
24 #include <stack>
25 #include <unordered_map>
26 #include <utility>
27 #include <vector>
28 
29 #include <fst/log.h>
31 #include <fst/extensions/pdt/pdt.h>
32 #include <fst/shortest-path.h>
33 #include <unordered_map>
34 
35 namespace fst {
36 
37 template <class Arc, class Queue>
40  bool path_gc;
41 
42  explicit PdtShortestPathOptions(bool keep_parentheses = false,
43  bool path_gc = true)
44  : keep_parentheses(keep_parentheses), path_gc(path_gc) {}
45 };
46 
47 namespace internal {
48 
// Bit flags kept in SearchData::flags by PdtShortestPathData.
//
// kPdtInited: set the first time a search-state entry is created.
// kPdtFinal:  identifies the entries GC() starts its mark phase from.
// kPdtMarked: set by GC() on entries it must keep.
inline constexpr uint8_t kPdtInited = 1u << 0;
inline constexpr uint8_t kPdtFinal = 1u << 1;
inline constexpr uint8_t kPdtMarked = 1u << 2;
54 
55 // Stores shortest path tree info Distance(), Parent(), and ArcParent()
56 // information keyed on two types:
57 //
58 // 1. SearchState: This is a usual node in a shortest path tree but:
59 // a. is w.r.t a PDT search state (a pair of a PDT state and a "start" state,
60 // either the PDT start state or the destination state of an open
61 // parenthesis).
62 // b. the Distance() is from this "start" state to the search state.
63 // c. Parent().state is kNoLabel for the "start" state.
64 //
65 // 2. ParenSpec: This connects shortest path trees depending on the
66 // parenthesis taken. Given the parenthesis spec:
67 // a. the Distance() is from the Parent() "start" state to the parenthesis
68 // destination state.
69 // b. The ArcParent() is the parenthesis arc.
70 template <class Arc>
72  public:
73  using Label = typename Arc::Label;
74  using StateId = typename Arc::StateId;
75  using Weight = typename Arc::Weight;
76 
77  struct SearchState {
78  StateId state; // PDT state.
79  StateId start; // PDT paren "start" state.
80 
82  : state(s), start(t) {}
83 
84  bool operator==(const SearchState &other) const {
85  if (&other == this) return true;
86  return other.state == state && other.start == start;
87  }
88  };
89 
90  // Specifies paren ID, source and dest "start" states of a paren. These are
91  // the "start" states of the respective sub-graphs.
92  struct ParenSpec {
93  explicit ParenSpec(Label paren_id = kNoLabel,
94  StateId src_start = kNoStateId,
95  StateId dest_start = kNoStateId)
96  : paren_id(paren_id), src_start(src_start), dest_start(dest_start) {}
97 
99  StateId src_start; // Sub-graph "start" state for paren source.
100  StateId dest_start; // Sub-graph "start" state for paren dest.
101 
102  bool operator==(const ParenSpec &other) const {
103  if (&other == this) return true;
104  return (other.paren_id == paren_id &&
105  other.src_start == other.src_start &&
106  other.dest_start == dest_start);
107  }
108  };
109 
110  struct SearchData {
112  : distance(Weight::Zero()),
113  parent(kNoStateId, kNoStateId),
114  paren_id(kNoLabel),
115  flags(0) {}
116 
117  Weight distance; // Distance to this state from PDT "start" state.
118  SearchState parent; // Parent state in shortest path tree.
119  int16_t paren_id; // If parent arc has paren, paren ID (or kNoLabel).
120  uint8_t flags; // First byte reserved for PdtShortestPathData use.
121  };
122 
123  explicit PdtShortestPathData(bool gc)
124  : gc_(gc), nstates_(0), ngc_(0), finished_(false) {}
125 
127  VLOG(1) << "opm size: " << paren_map_.size();
128  VLOG(1) << "# of search states: " << nstates_;
129  if (gc_) VLOG(1) << "# of GC'd search states: " << ngc_;
130  }
131 
132  void Clear() {
133  search_map_.clear();
134  search_multimap_.clear();
135  paren_map_.clear();
136  state_ = SearchState(kNoStateId, kNoStateId);
137  nstates_ = 0;
138  ngc_ = 0;
139  }
140 
141  // TODO(kbg): Currently copying SearchState and passing a const reference to
142  // ParenSpec. Benchmark to confirm this is the right thing to do.
143 
144  Weight Distance(SearchState s) const { return GetSearchData(s)->distance; }
145 
146  Weight Distance(const ParenSpec &paren) const {
147  return GetSearchData(paren)->distance;
148  }
149 
150  SearchState Parent(SearchState s) const { return GetSearchData(s)->parent; }
151 
152  SearchState Parent(const ParenSpec &paren) const {
153  return GetSearchData(paren)->parent;
154  }
155 
156  Label ParenId(SearchState s) const { return GetSearchData(s)->paren_id; }
157 
158  uint8_t Flags(SearchState s) const { return GetSearchData(s)->flags; }
159 
160  void SetDistance(SearchState s, Weight weight) {
161  GetSearchData(s)->distance = std::move(weight);
162  }
163 
164  void SetDistance(const ParenSpec &paren, Weight weight) {
165  GetSearchData(paren)->distance = std::move(weight);
166  }
167 
168  void SetParent(SearchState s, SearchState p) { GetSearchData(s)->parent = p; }
169 
170  void SetParent(const ParenSpec &paren, SearchState p) {
171  GetSearchData(paren)->parent = p;
172  }
173 
175  if (p >= 32768) {
176  FSTERROR() << "PdtShortestPathData: Paren ID does not fit in an int16_t";
177  }
178  GetSearchData(s)->paren_id = p;
179  }
180 
181  void SetFlags(SearchState s, uint8_t f, uint8_t mask) {
182  auto *data = GetSearchData(s);
183  data->flags &= ~mask;
184  data->flags |= f & mask;
185  }
186 
187  void GC(StateId s);
188 
189  void Finish() { finished_ = true; }
190 
191  private:
192  // Hash for search state.
193  struct SearchStateHash {
194  size_t operator()(const SearchState &s) const {
195  static constexpr auto prime = 7853;
196  return s.state + s.start * prime;
197  }
198  };
199 
200  // Hash for paren map.
201  struct ParenHash {
202  size_t operator()(const ParenSpec &paren) const {
203  static constexpr auto prime0 = 7853;
204  static constexpr auto prime1 = 7867;
205  return paren.paren_id + paren.src_start * prime0 +
206  paren.dest_start * prime1;
207  }
208  };
209 
210  using SearchMap =
211  std::unordered_map<SearchState, SearchData, SearchStateHash>;
212 
213  using SearchMultimap = std::unordered_multimap<StateId, StateId>;
214 
215  // Hash map from paren spec to open paren data.
216  using ParenMap = std::unordered_map<ParenSpec, SearchData, ParenHash>;
217 
218  SearchData *GetSearchData(SearchState s) const {
219  if (s == state_) return state_data_;
220  if (finished_) {
221  auto it = search_map_.find(s);
222  if (it == search_map_.end()) return &null_search_data_;
223  state_ = s;
224  return state_data_ = &(it->second);
225  } else {
226  state_ = s;
227  state_data_ = &search_map_[s];
228  if (!(state_data_->flags & kPdtInited)) {
229  ++nstates_;
230  if (gc_) search_multimap_.insert(std::make_pair(s.start, s.state));
231  state_data_->flags = kPdtInited;
232  }
233  return state_data_;
234  }
235  }
236 
237  SearchData *GetSearchData(ParenSpec paren) const {
238  if (paren == paren_) return paren_data_;
239  if (finished_) {
240  auto it = paren_map_.find(paren);
241  if (it == paren_map_.end()) return &null_search_data_;
242  paren_ = paren;
243  return state_data_ = &(it->second);
244  } else {
245  paren_ = paren;
246  return paren_data_ = &paren_map_[paren];
247  }
248  }
249 
250  mutable SearchMap search_map_; // Maps from search state to data.
251  mutable SearchMultimap search_multimap_; // Maps from "start" to subgraph.
252  mutable ParenMap paren_map_; // Maps paren spec to search data.
253  mutable SearchState state_; // Last state accessed.
254  mutable SearchData *state_data_; // Last state data accessed.
255  mutable ParenSpec paren_; // Last paren spec accessed.
256  mutable SearchData *paren_data_; // Last paren data accessed.
257  bool gc_; // Allow GC?
258  mutable size_t nstates_; // Total number of search states.
259  size_t ngc_; // Number of GC'd search states.
260  mutable SearchData null_search_data_; // Null search data.
261  bool finished_; // Read-only access when true.
262 
263  PdtShortestPathData(const PdtShortestPathData &) = delete;
264  PdtShortestPathData &operator=(const PdtShortestPathData &) = delete;
265 };
266 
267 // Deletes inaccessible search data from a given "start" (open paren dest)
268 // state. Assumes "final" (close paren source or PDT final) states have
269 // been flagged kPdtFinal.
270 template <class Arc>
272  if (!gc_) return;
273  std::vector<StateId> finals;
274  for (auto it = search_multimap_.find(start);
275  it != search_multimap_.end() && it->first == start; ++it) {
276  const SearchState s(it->second, start);
277  if (search_map_[s].flags & kPdtFinal) finals.push_back(s.state);
278  }
279  // Mark phase.
280  for (const auto state : finals) {
281  SearchState ss(state, start);
282  while (ss.state != kNoLabel) {
283  auto &sdata = search_map_[ss];
284  if (sdata.flags & kPdtMarked) break;
285  sdata.flags |= kPdtMarked;
286  const auto p = sdata.parent;
287  if (p.start != start && p.start != kNoLabel) { // Entering sub-subgraph.
288  const ParenSpec paren(sdata.paren_id, ss.start, p.start);
289  ss = paren_map_[paren].parent;
290  } else {
291  ss = p;
292  }
293  }
294  }
295  // Sweep phase.
296  auto it = search_multimap_.find(start);
297  while (it != search_multimap_.end() && it->first == start) {
298  const SearchState s(it->second, start);
299  auto mit = search_map_.find(s);
300  const SearchData &data = mit->second;
301  if (!(data.flags & kPdtMarked)) {
302  search_map_.erase(mit);
303  ++ngc_;
304  }
305  search_multimap_.erase(it++);
306  }
307 }
308 
309 } // namespace internal
310 
311 // This computes the single source shortest (balanced) path (SSSP) through a
312 // weighted PDT that has a bounded stack (i.e., is expandable as an FST). It is
313 // a generalization of the classic SSSP graph algorithm that removes a state s
314 // from a queue (defined by a user-provided queue type) and relaxes the
315 // destination states of transitions leaving s. In this PDT version, states that
316 // have entering open parentheses are treated as source states for a sub-graph
317 // SSSP problem with the shortest path up to the open parenthesis being first
318 // saved. When a close parenthesis is then encountered any balancing open
319 // parenthesis is examined for this saved information and multiplied back. In
320 // this way, each sub-graph is entered only once rather than repeatedly. If
321 // every state in the input PDT has the property that there is a unique "start"
322 // state for it with entering open parentheses, then this algorithm is quite
323 // straightforward. In general, this will not be the case, so the algorithm
324 // (implicitly) creates a new graph where each state is a pair of an original
325 // state and a possible parenthesis "start" state for that state.
326 template <class Arc, class Queue>
328  public:
329  using Label = typename Arc::Label;
330  using StateId = typename Arc::StateId;
331  using Weight = typename Arc::Weight;
332 
335  using ParenSpec = typename SpData::ParenSpec;
336  using CloseSourceIterator =
338 
340  const std::vector<std::pair<Label, Label>> &parens,
342  : ifst_(ifst.Copy()),
343  parens_(parens),
344  keep_parens_(opts.keep_parentheses),
345  start_(ifst.Start()),
346  sp_data_(opts.path_gc),
347  error_(false) {
348  // TODO(kbg): Make this a compile-time static_assert once:
349  // 1) All weight properties are made constexpr for all weight types.
350  // 2) We have a pleasant way to "deregister" this operation for non-path
351  // semirings so an informative error message is produced. The best
352  // solution will probably involve some kind of SFINAE magic.
353  if ((Weight::Properties() & (kPath | kRightSemiring)) !=
354  (kPath | kRightSemiring)) {
355  FSTERROR() << "PdtShortestPath: Weight needs to have the path"
356  << " property and be right distributive: " << Weight::Type();
357  error_ = true;
358  }
359  for (Label i = 0; i < parens.size(); ++i) {
360  const auto &pair = parens[i];
361  paren_map_[pair.first] = i;
362  paren_map_[pair.second] = i;
363  }
364  }
365 
367  VLOG(1) << "# of input states: " << CountStates(*ifst_);
368  VLOG(1) << "# of enqueued: " << nenqueued_;
369  VLOG(1) << "cpmm size: " << close_paren_multimap_.size();
370  }
371 
373  Init(ofst);
374  GetDistance(start_);
375  GetPath();
376  sp_data_.Finish();
377  if (error_) ofst->SetProperties(kError, kError);
378  }
379 
381  return sp_data_;
382  }
383 
384  internal::PdtBalanceData<Arc> *GetBalanceData() { return &balance_data_; }
385 
386  public:
387  // Hash multimap from close paren label to a paren arc.
388  using CloseParenMultimap =
389  std::unordered_multimap<internal::ParenState<Arc>, Arc,
391 
393  return close_paren_multimap_;
394  }
395 
396  private:
397  void Init(MutableFst<Arc> *ofst);
398 
399  void GetDistance(StateId start);
400 
401  void ProcFinal(SearchState s);
402 
403  void ProcArcs(SearchState s);
404 
405  void ProcOpenParen(Label paren_id, SearchState s, StateId nexstate,
406  const Weight &weight);
407 
408  void ProcCloseParen(Label paren_id, SearchState s, const Weight &weight);
409 
410  void ProcNonParen(SearchState s, StateId nextstate, const Weight &weight);
411 
412  void Relax(SearchState s, SearchState t, StateId nextstate,
413  const Weight &weight, Label paren_id);
414 
415  void Enqueue(SearchState d);
416 
417  void GetPath();
418 
419  Arc GetPathArc(SearchState s, SearchState p, Label paren_id, bool open);
420 
421  std::unique_ptr<Fst<Arc>> ifst_;
422  MutableFst<Arc> *ofst_;
423  const std::vector<std::pair<Label, Label>> &parens_;
424  bool keep_parens_;
425  Queue *state_queue_;
426  StateId start_;
427  Weight fdistance_;
428  SearchState f_parent_;
429  SpData sp_data_;
430  std::unordered_map<Label, Label> paren_map_;
431  CloseParenMultimap close_paren_multimap_;
432  internal::PdtBalanceData<Arc> balance_data_;
433  ssize_t nenqueued_;
434  bool error_;
435 
436  static constexpr uint8_t kEnqueued = 0x10;
437  static constexpr uint8_t kExpanded = 0x20;
438  static constexpr uint8_t kFinished = 0x40;
439 
440  static const Arc kNoArc;
441 };
442 
443 template <class Arc, class Queue>
445  ofst_ = ofst;
446  ofst->DeleteStates();
447  ofst->SetInputSymbols(ifst_->InputSymbols());
448  ofst->SetOutputSymbols(ifst_->OutputSymbols());
449  if (ifst_->Start() == kNoStateId) return;
450  fdistance_ = Weight::Zero();
451  f_parent_ = SearchState(kNoStateId, kNoStateId);
452  sp_data_.Clear();
453  close_paren_multimap_.clear();
454  balance_data_.Clear();
455  nenqueued_ = 0;
456  // Finds open parens per destination state and close parens per source state.
457  for (StateIterator<Fst<Arc>> siter(*ifst_); !siter.Done(); siter.Next()) {
458  const auto s = siter.Value();
459  for (ArcIterator<Fst<Arc>> aiter(*ifst_, s); !aiter.Done(); aiter.Next()) {
460  const auto &arc = aiter.Value();
461  const auto it = paren_map_.find(arc.ilabel);
462  if (it != paren_map_.end()) { // Is a paren?
463  const auto paren_id = it->second;
464  if (arc.ilabel == parens_[paren_id].first) { // Open paren.
465  balance_data_.OpenInsert(paren_id, arc.nextstate);
466  } else { // Close paren.
467  const internal::ParenState<Arc> paren_state(paren_id, s);
468  close_paren_multimap_.emplace(paren_state, arc);
469  }
470  }
471  }
472  }
473 }
474 
475 // Computes the shortest distance stored in a recursive way. Each sub-graph
476 // (i.e., different paren "start" state) begins with weight One().
477 template <class Arc, class Queue>
479  if (start == kNoStateId) return;
480  Queue state_queue;
481  state_queue_ = &state_queue;
482  const SearchState q(start, start);
483  Enqueue(q);
484  sp_data_.SetDistance(q, Weight::One());
485  while (!state_queue_->Empty()) {
486  const auto state = state_queue_->Head();
487  state_queue_->Dequeue();
488  const SearchState s(state, start);
489  sp_data_.SetFlags(s, 0, kEnqueued);
490  ProcFinal(s);
491  ProcArcs(s);
492  sp_data_.SetFlags(s, kExpanded, kExpanded);
493  }
494  sp_data_.SetFlags(q, kFinished, kFinished);
495  balance_data_.FinishInsert(start);
496  sp_data_.GC(start);
497 }
498 
499 // Updates best complete path.
500 template <class Arc, class Queue>
502  if (ifst_->Final(s.state) != Weight::Zero() && s.start == start_) {
503  const auto weight = Times(sp_data_.Distance(s), ifst_->Final(s.state));
504  if (fdistance_ != Plus(fdistance_, weight)) {
505  if (f_parent_.state != kNoStateId) {
506  sp_data_.SetFlags(f_parent_, 0, internal::kPdtFinal);
507  }
508  sp_data_.SetFlags(s, internal::kPdtFinal, internal::kPdtFinal);
509  fdistance_ = Plus(fdistance_, weight);
510  f_parent_ = s;
511  }
512  }
513 }
514 
515 // Processes all arcs leaving the state s.
516 template <class Arc, class Queue>
518  for (ArcIterator<Fst<Arc>> aiter(*ifst_, s.state); !aiter.Done();
519  aiter.Next()) {
520  const auto &arc = aiter.Value();
521  const auto weight = Times(sp_data_.Distance(s), arc.weight);
522  const auto it = paren_map_.find(arc.ilabel);
523  if (it != paren_map_.end()) { // Is a paren?
524  const auto paren_id = it->second;
525  if (arc.ilabel == parens_[paren_id].first) {
526  ProcOpenParen(paren_id, s, arc.nextstate, weight);
527  } else {
528  ProcCloseParen(paren_id, s, weight);
529  }
530  } else {
531  ProcNonParen(s, arc.nextstate, weight);
532  }
533  }
534 }
535 
536 // Saves the shortest path info for reaching this parenthesis and starts a new
537 // SSSP in the sub-graph pointed to by the parenthesis if previously unvisited.
538 // Otherwise it finds any previously encountered closing parentheses and relaxes
539 // them using the recursively stored shortest distance to them.
540 template <class Arc, class Queue>
542  SearchState s,
543  StateId nextstate,
544  const Weight &weight) {
545  const SearchState d(nextstate, nextstate);
546  const ParenSpec paren(paren_id, s.start, d.start);
547  const auto pdist = sp_data_.Distance(paren);
548  if (pdist != Plus(pdist, weight)) {
549  sp_data_.SetDistance(paren, weight);
550  sp_data_.SetParent(paren, s);
551  const auto dist = sp_data_.Distance(d);
552  if (dist == Weight::Zero()) {
553  auto *state_queue = state_queue_;
554  GetDistance(d.start);
555  state_queue_ = state_queue;
556  } else if (!(sp_data_.Flags(d) & kFinished)) {
557  FSTERROR()
558  << "PdtShortestPath: open parenthesis recursion: not bounded stack";
559  error_ = true;
560  }
561  for (auto set_iter = balance_data_.Find(paren_id, nextstate);
562  !set_iter.Done(); set_iter.Next()) {
563  const SearchState cpstate(set_iter.Element(), d.start);
564  const internal::ParenState<Arc> paren_state(paren_id, cpstate.state);
565  for (auto cpit = close_paren_multimap_.find(paren_state);
566  cpit != close_paren_multimap_.end() && paren_state == cpit->first;
567  ++cpit) {
568  const auto &cparc = cpit->second;
569  const auto cpw =
570  Times(weight, Times(sp_data_.Distance(cpstate), cparc.weight));
571  Relax(cpstate, s, cparc.nextstate, cpw, paren_id);
572  }
573  }
574  }
575 }
576 
577 // Saves the correspondence between each closing parenthesis and its balancing
578 // open parenthesis info. Relaxes any close parenthesis destination state that
579 // has a balancing previously encountered open parenthesis.
580 template <class Arc, class Queue>
582  SearchState s,
583  const Weight &weight) {
584  const internal::ParenState<Arc> paren_state(paren_id, s.start);
585  if (!(sp_data_.Flags(s) & kExpanded)) {
586  balance_data_.CloseInsert(paren_id, s.start, s.state);
587  sp_data_.SetFlags(s, internal::kPdtFinal, internal::kPdtFinal);
588  }
589 }
590 
591 // Classical relaxation for non-parentheses.
592 template <class Arc, class Queue>
594  StateId nextstate,
595  const Weight &weight) {
596  Relax(s, s, nextstate, weight, kNoLabel);
597 }
598 
599 // Classical relaxation on the search graph for an arc with destination state
600 // nextstate from state s. State t is in the same sub-graph as nextstate (i.e.,
601 // has the same paren "start").
602 template <class Arc, class Queue>
604  StateId nextstate,
605  const Weight &weight,
606  Label paren_id) {
607  const SearchState d(nextstate, t.start);
608  Weight dist = sp_data_.Distance(d);
609  if (dist != Plus(dist, weight)) {
610  sp_data_.SetParent(d, s);
611  sp_data_.SetParenId(d, paren_id);
612  sp_data_.SetDistance(d, Plus(dist, weight));
613  Enqueue(d);
614  }
615 }
616 
617 template <class Arc, class Queue>
619  if (!(sp_data_.Flags(s) & kEnqueued)) {
620  state_queue_->Enqueue(s.state);
621  sp_data_.SetFlags(s, kEnqueued, kEnqueued);
622  ++nenqueued_;
623  } else {
624  state_queue_->Update(s.state);
625  }
626 }
627 
628 // Follows parent pointers to find the shortest path. A stack is used since the
629 // shortest distance is stored recursively.
630 template <class Arc, class Queue>
632  SearchState s = f_parent_;
634  StateId s_p = kNoStateId;
635  StateId d_p = kNoStateId;
636  auto arc = kNoArc;
637  Label paren_id = kNoLabel;
638  std::stack<ParenSpec> paren_stack;
639  while (s.state != kNoStateId) {
640  d_p = s_p;
641  s_p = ofst_->AddState();
642  if (d.state == kNoStateId) {
643  ofst_->SetFinal(s_p, ifst_->Final(f_parent_.state));
644  } else {
645  if (paren_id != kNoLabel) { // Paren?
646  if (arc.ilabel == parens_[paren_id].first) { // Open paren?
647  paren_stack.pop();
648  } else { // Close paren?
649  const ParenSpec paren(paren_id, d.start, s.start);
650  paren_stack.push(paren);
651  }
652  if (!keep_parens_) arc.ilabel = arc.olabel = 0;
653  }
654  arc.nextstate = d_p;
655  ofst_->AddArc(s_p, arc);
656  }
657  d = s;
658  s = sp_data_.Parent(d);
659  paren_id = sp_data_.ParenId(d);
660  if (s.state != kNoStateId) {
661  arc = GetPathArc(s, d, paren_id, false);
662  } else if (!paren_stack.empty()) {
663  const ParenSpec paren = paren_stack.top();
664  s = sp_data_.Parent(paren);
665  paren_id = paren.paren_id;
666  arc = GetPathArc(s, d, paren_id, true);
667  }
668  }
669  ofst_->SetStart(s_p);
670  ofst_->SetProperties(
671  ShortestPathProperties(ofst_->Properties(kFstProperties, false)),
673 }
674 
675 // Finds transition with least weight between two states with label matching
676 // paren_id and open/close paren type or a non-paren if kNoLabel.
677 template <class Arc, class Queue>
679  Label paren_id, bool open_paren) {
680  auto path_arc = kNoArc;
681  for (ArcIterator<Fst<Arc>> aiter(*ifst_, s.state); !aiter.Done();
682  aiter.Next()) {
683  const auto &arc = aiter.Value();
684  if (arc.nextstate != d.state) continue;
685  Label arc_paren_id = kNoLabel;
686  const auto it = paren_map_.find(arc.ilabel);
687  if (it != paren_map_.end()) {
688  arc_paren_id = it->second;
689  bool arc_open_paren = (arc.ilabel == parens_[arc_paren_id].first);
690  if (arc_open_paren != open_paren) continue;
691  }
692  if (arc_paren_id != paren_id) continue;
693  if (arc.weight == Plus(arc.weight, path_arc.weight)) path_arc = arc;
694  }
695  if (path_arc.nextstate == kNoStateId) {
696  FSTERROR() << "PdtShortestPath::GetPathArc: Failed to find arc";
697  error_ = true;
698  }
699  return path_arc;
700 }
701 
702 template <class Arc, class Queue>
704  Weight::Zero(), kNoStateId);
705 
706 // Functional variants.
707 
708 template <class Arc, class Queue>
710  const Fst<Arc> &ifst,
711  const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
712  &parens,
714  PdtShortestPath<Arc, Queue> psp(ifst, parens, opts);
715  psp.ShortestPath(ofst);
716 }
717 
718 template <class Arc>
720  const Fst<Arc> &ifst,
721  const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
722  &parens,
723  MutableFst<Arc> *ofst) {
726  PdtShortestPath<Arc, Q> psp(ifst, parens, opts);
727  psp.ShortestPath(ofst);
728 }
729 
730 } // namespace fst
731 
732 #endif // FST_EXTENSIONS_PDT_SHORTEST_PATH_H_
bool operator==(const ParenSpec &other) const
internal::PdtBalanceData< Arc > * GetBalanceData()
constexpr int kNoLabel
Definition: fst.h:201
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:60
void SetFlags(SearchState s, uint8_t f, uint8_t mask)
uint8_t Flags(SearchState s) const
SearchState Parent(SearchState s) const
uint64_t ShortestPathProperties(uint64_t props, bool tree=false)
Definition: properties.cc:374
typename Arc::StateId StateId
Definition: shortest-path.h:74
void SetDistance(SearchState s, Weight weight)
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:63
SearchState Parent(const ParenSpec &paren) const
constexpr uint64_t kError
Definition: properties.h:51
constexpr uint8_t kPdtFinal
Definition: shortest-path.h:52
virtual void SetInputSymbols(const SymbolTable *isyms)=0
ParenSpec(Label paren_id=kNoLabel, StateId src_start=kNoStateId, StateId dest_start=kNoStateId)
Definition: shortest-path.h:93
std::unordered_multimap< internal::ParenState< Arc >, Arc, typename internal::ParenState< Arc >::Hash > CloseParenMultimap
constexpr int kNoStateId
Definition: fst.h:202
void SetParent(SearchState s, SearchState p)
constexpr uint64_t kRightSemiring
Definition: weight.h:136
#define FSTERROR()
Definition: util.h:53
SearchState(StateId s=kNoStateId, StateId t=kNoStateId)
Definition: shortest-path.h:81
typename Collection< ssize_t, StateId >::SetIterator SetIterator
Definition: paren.h:354
virtual void SetProperties(uint64_t props, uint64_t mask)=0
PdtShortestPath(const Fst< Arc > &ifst, const std::vector< std::pair< Label, Label >> &parens, const PdtShortestPathOptions< Arc, Queue > &opts)
#define VLOG(level)
Definition: log.h:50
const CloseParenMultimap & GetCloseParenMultimap() const
Weight Distance(const ParenSpec &paren) const
void SetParenId(SearchState s, Label p)
PdtShortestPathOptions(bool keep_parentheses=false, bool path_gc=true)
Definition: shortest-path.h:42
Label ParenId(SearchState s) const
constexpr uint8_t kPdtInited
Definition: shortest-path.h:51
void SetDistance(const ParenSpec &paren, Weight weight)
void ShortestPath(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, MutableFst< Arc > *ofst, const PdtShortestPathOptions< Arc, Queue > &opts)
constexpr size_t kNoArc
Definition: shortest-path.h:78
constexpr uint64_t kPath
Definition: weight.h:147
constexpr uint64_t kFstProperties
Definition: properties.h:325
void ShortestPath(MutableFst< Arc > *ofst)
Arc::StateId CountStates(const Fst< Arc > &fst)
Definition: expanded-fst.h:169
bool operator==(const SearchState &other) const
Definition: shortest-path.h:84
virtual void DeleteStates(const std::vector< StateId > &)=0
typename internal::PdtBalanceData< Arc >::SetIterator CloseSourceIterator
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
Weight Distance(SearchState s) const
const internal::PdtShortestPathData< Arc > & GetShortestPathData() const
constexpr uint64_t kExpanded
Definition: properties.h:45
void SetParent(const ParenSpec &paren, SearchState p)
constexpr uint8_t kPdtMarked
Definition: shortest-path.h:53