FST  openfst-1.8.3
OpenFst Library
shortest-path.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Functions to find shortest paths in a PDT.
19 
20 #ifndef FST_EXTENSIONS_PDT_SHORTEST_PATH_H_
21 #define FST_EXTENSIONS_PDT_SHORTEST_PATH_H_
22 
23 #include <sys/types.h>
24 
25 #include <cstddef>
26 #include <cstdint>
27 #include <memory>
28 #include <stack>
29 #include <unordered_map>
30 #include <utility>
31 #include <vector>
32 
33 #include <fst/log.h>
35 #include <fst/extensions/pdt/pdt.h>
36 #include <fst/arc.h>
37 #include <fst/expanded-fst.h>
38 #include <fst/float-weight.h>
39 #include <fst/fst.h>
40 #include <fst/mutable-fst.h>
41 #include <fst/properties.h>
42 #include <fst/queue.h>
43 #include <fst/shortest-path.h>
44 #include <fst/util.h>
45 #include <fst/weight.h>
46 #include <unordered_map>
47 
48 namespace fst {
49 
50 template <class Arc, class Queue>
53  bool path_gc;
54 
55  explicit PdtShortestPathOptions(bool keep_parentheses = false,
56  bool path_gc = true)
57  : keep_parentheses(keep_parentheses), path_gc(path_gc) {}
58 };
59 
60 namespace internal {
61 
62 // Flags for shortest path data.
63 
64 inline constexpr uint8_t kPdtInited = 0x01;
65 inline constexpr uint8_t kPdtFinal = 0x02;
66 inline constexpr uint8_t kPdtMarked = 0x04;
67 
68 // Stores shortest path tree info Distance(), Parent(), and ArcParent()
69 // information keyed on two types:
70 //
71 // 1. SearchState: This is a usual node in a shortest path tree but:
72 // a. is w.r.t a PDT search state (a pair of a PDT state and a "start" state,
73 // either the PDT start state or the destination state of an open
74 // parenthesis).
75 // b. the Distance() is from this "start" state to the search state.
76 // c. Parent().state is kNoLabel for the "start" state.
77 //
78 // 2. ParenSpec: This connects shortest path trees depending on the
79 // parenthesis taken. Given the parenthesis spec:
80 // a. the Distance() is from the Parent() "start" state to the parenthesis
81 // destination state.
82 // b. The ArcParent() is the parenthesis arc.
83 template <class Arc>
85  public:
86  using Label = typename Arc::Label;
87  using StateId = typename Arc::StateId;
88  using Weight = typename Arc::Weight;
89 
90  struct SearchState {
91  StateId state; // PDT state.
92  StateId start; // PDT paren "start" state.
93 
95  : state(s), start(t) {}
96 
97  bool operator==(const SearchState &other) const {
98  if (&other == this) return true;
99  return other.state == state && other.start == start;
100  }
101  };
102 
103  // Specifies paren ID, source and dest "start" states of a paren. These are
104  // the "start" states of the respective sub-graphs.
105  struct ParenSpec {
106  explicit ParenSpec(Label paren_id = kNoLabel,
107  StateId src_start = kNoStateId,
108  StateId dest_start = kNoStateId)
109  : paren_id(paren_id), src_start(src_start), dest_start(dest_start) {}
110 
112  StateId src_start; // Sub-graph "start" state for paren source.
113  StateId dest_start; // Sub-graph "start" state for paren dest.
114 
115  bool operator==(const ParenSpec &other) const {
116  if (&other == this) return true;
117  return (other.paren_id == paren_id &&
118  other.src_start == other.src_start &&
119  other.dest_start == dest_start);
120  }
121  };
122 
123  struct SearchData {
125  : distance(Weight::Zero()),
126  parent(kNoStateId, kNoStateId),
127  paren_id(kNoLabel),
128  flags(0) {}
129 
130  Weight distance; // Distance to this state from PDT "start" state.
131  SearchState parent; // Parent state in shortest path tree.
132  int16_t paren_id; // If parent arc has paren, paren ID (or kNoLabel).
133  uint8_t flags; // First byte reserved for PdtShortestPathData use.
134  };
135 
136  explicit PdtShortestPathData(bool gc)
137  : gc_(gc), nstates_(0), ngc_(0), finished_(false) {}
138 
140  VLOG(1) << "opm size: " << paren_map_.size();
141  VLOG(1) << "# of search states: " << nstates_;
142  if (gc_) VLOG(1) << "# of GC'd search states: " << ngc_;
143  }
144 
145  void Clear() {
146  search_map_.clear();
147  search_multimap_.clear();
148  paren_map_.clear();
149  state_ = SearchState(kNoStateId, kNoStateId);
150  nstates_ = 0;
151  ngc_ = 0;
152  }
153 
154  // TODO(kbg): Currently copying SearchState and passing a const reference to
155  // ParenSpec. Benchmark to confirm this is the right thing to do.
156 
157  Weight Distance(SearchState s) const { return GetSearchData(s)->distance; }
158 
159  Weight Distance(const ParenSpec &paren) const {
160  return GetSearchData(paren)->distance;
161  }
162 
163  SearchState Parent(SearchState s) const { return GetSearchData(s)->parent; }
164 
165  SearchState Parent(const ParenSpec &paren) const {
166  return GetSearchData(paren)->parent;
167  }
168 
169  Label ParenId(SearchState s) const { return GetSearchData(s)->paren_id; }
170 
171  uint8_t Flags(SearchState s) const { return GetSearchData(s)->flags; }
172 
173  void SetDistance(SearchState s, Weight weight) {
174  GetSearchData(s)->distance = std::move(weight);
175  }
176 
177  void SetDistance(const ParenSpec &paren, Weight weight) {
178  GetSearchData(paren)->distance = std::move(weight);
179  }
180 
181  void SetParent(SearchState s, SearchState p) { GetSearchData(s)->parent = p; }
182 
183  void SetParent(const ParenSpec &paren, SearchState p) {
184  GetSearchData(paren)->parent = p;
185  }
186 
188  if (p >= 32768) {
189  FSTERROR() << "PdtShortestPathData: Paren ID does not fit in an int16_t";
190  }
191  GetSearchData(s)->paren_id = p;
192  }
193 
194  void SetFlags(SearchState s, uint8_t f, uint8_t mask) {
195  auto *data = GetSearchData(s);
196  data->flags &= ~mask;
197  data->flags |= f & mask;
198  }
199 
200  void GC(StateId s);
201 
202  void Finish() { finished_ = true; }
203 
204  private:
205  // Hash for search state.
206  struct SearchStateHash {
207  size_t operator()(const SearchState &s) const {
208  static constexpr auto prime = 7853;
209  return s.state + s.start * prime;
210  }
211  };
212 
213  // Hash for paren map.
214  struct ParenHash {
215  size_t operator()(const ParenSpec &paren) const {
216  static constexpr auto prime0 = 7853;
217  static constexpr auto prime1 = 7867;
218  return paren.paren_id + paren.src_start * prime0 +
219  paren.dest_start * prime1;
220  }
221  };
222 
223  using SearchMap =
224  std::unordered_map<SearchState, SearchData, SearchStateHash>;
225 
226  using SearchMultimap = std::unordered_multimap<StateId, StateId>;
227 
228  // Hash map from paren spec to open paren data.
229  using ParenMap = std::unordered_map<ParenSpec, SearchData, ParenHash>;
230 
231  SearchData *GetSearchData(SearchState s) const {
232  if (s == state_) return state_data_;
233  if (finished_) {
234  auto it = search_map_.find(s);
235  if (it == search_map_.end()) return &null_search_data_;
236  state_ = s;
237  return state_data_ = &(it->second);
238  } else {
239  state_ = s;
240  state_data_ = &search_map_[s];
241  if (!(state_data_->flags & kPdtInited)) {
242  ++nstates_;
243  if (gc_) search_multimap_.insert(std::make_pair(s.start, s.state));
244  state_data_->flags = kPdtInited;
245  }
246  return state_data_;
247  }
248  }
249 
250  SearchData *GetSearchData(ParenSpec paren) const {
251  if (paren == paren_) return paren_data_;
252  if (finished_) {
253  auto it = paren_map_.find(paren);
254  if (it == paren_map_.end()) return &null_search_data_;
255  paren_ = paren;
256  return state_data_ = &(it->second);
257  } else {
258  paren_ = paren;
259  return paren_data_ = &paren_map_[paren];
260  }
261  }
262 
263  mutable SearchMap search_map_; // Maps from search state to data.
264  mutable SearchMultimap search_multimap_; // Maps from "start" to subgraph.
265  mutable ParenMap paren_map_; // Maps paren spec to search data.
266  mutable SearchState state_; // Last state accessed.
267  mutable SearchData *state_data_; // Last state data accessed.
268  mutable ParenSpec paren_; // Last paren spec accessed.
269  mutable SearchData *paren_data_; // Last paren data accessed.
270  bool gc_; // Allow GC?
271  mutable size_t nstates_; // Total number of search states.
272  size_t ngc_; // Number of GC'd search states.
273  mutable SearchData null_search_data_; // Null search data.
274  bool finished_; // Read-only access when true.
275 
276  PdtShortestPathData(const PdtShortestPathData &) = delete;
277  PdtShortestPathData &operator=(const PdtShortestPathData &) = delete;
278 };
279 
280 // Deletes inaccessible search data from a given "start" (open paren dest)
281 // state. Assumes "final" (close paren source or PDT final) states have
282 // been flagged kPdtFinal.
283 template <class Arc>
285  if (!gc_) return;
286  std::vector<StateId> finals;
287  for (auto it = search_multimap_.find(start);
288  it != search_multimap_.end() && it->first == start; ++it) {
289  const SearchState s(it->second, start);
290  if (search_map_[s].flags & kPdtFinal) finals.push_back(s.state);
291  }
292  // Mark phase.
293  for (const auto state : finals) {
294  SearchState ss(state, start);
295  while (ss.state != kNoLabel) {
296  auto &sdata = search_map_[ss];
297  if (sdata.flags & kPdtMarked) break;
298  sdata.flags |= kPdtMarked;
299  const auto p = sdata.parent;
300  if (p.start != start && p.start != kNoLabel) { // Entering sub-subgraph.
301  const ParenSpec paren(sdata.paren_id, ss.start, p.start);
302  ss = paren_map_[paren].parent;
303  } else {
304  ss = p;
305  }
306  }
307  }
308  // Sweep phase.
309  auto it = search_multimap_.find(start);
310  while (it != search_multimap_.end() && it->first == start) {
311  const SearchState s(it->second, start);
312  auto mit = search_map_.find(s);
313  const SearchData &data = mit->second;
314  if (!(data.flags & kPdtMarked)) {
315  search_map_.erase(mit);
316  ++ngc_;
317  }
318  search_multimap_.erase(it++);
319  }
320 }
321 
322 } // namespace internal
323 
324 // This computes the single source shortest (balanced) path (SSSP) through a
325 // weighted PDT that has a bounded stack (i.e., is expandable as an FST). It is
326 // a generalization of the classic SSSP graph algorithm that removes a state s
327 // from a queue (defined by a user-provided queue type) and relaxes the
328 // destination states of transitions leaving s. In this PDT version, states that
329 // have entering open parentheses are treated as source states for a sub-graph
330 // SSSP problem with the shortest path up to the open parenthesis being first
331 // saved. When a close parenthesis is then encountered any balancing open
332 // parenthesis is examined for this saved information and multiplied back. In
333 // this way, each sub-graph is entered only once rather than repeatedly. If
334 // every state in the input PDT has the property that there is a unique "start"
335 // state for it with entering open parentheses, then this algorithm is quite
336 // straightforward. In general, this will not be the case, so the algorithm
337 // (implicitly) creates a new graph where each state is a pair of an original
338 // state and a possible parenthesis "start" state for that state.
339 template <class Arc, class Queue>
341  public:
342  using Label = typename Arc::Label;
343  using StateId = typename Arc::StateId;
344  using Weight = typename Arc::Weight;
345 
348  using ParenSpec = typename SpData::ParenSpec;
349  using CloseSourceIterator =
351 
353  const std::vector<std::pair<Label, Label>> &parens,
355  : ifst_(ifst.Copy()),
356  parens_(parens),
357  keep_parens_(opts.keep_parentheses),
358  start_(ifst.Start()),
359  sp_data_(opts.path_gc),
360  error_(false) {
361  // TODO(kbg): Make this a compile-time static_assert once:
362  // 1) All weight properties are made constexpr for all weight types.
363  // 2) We have a pleasant way to "deregister" this operation for non-path
364  // semirings so an informative error message is produced. The best
365  // solution will probably involve some kind of SFINAE magic.
366  if ((Weight::Properties() & (kPath | kRightSemiring)) !=
367  (kPath | kRightSemiring)) {
368  FSTERROR() << "PdtShortestPath: Weight needs to have the path"
369  << " property and be right distributive: " << Weight::Type();
370  error_ = true;
371  }
372  for (Label i = 0; i < parens.size(); ++i) {
373  const auto &pair = parens[i];
374  paren_map_[pair.first] = i;
375  paren_map_[pair.second] = i;
376  }
377  }
378 
380  VLOG(1) << "# of input states: " << CountStates(*ifst_);
381  VLOG(1) << "# of enqueued: " << nenqueued_;
382  VLOG(1) << "cpmm size: " << close_paren_multimap_.size();
383  }
384 
386  Init(ofst);
387  GetDistance(start_);
388  GetPath();
389  sp_data_.Finish();
390  if (error_) ofst->SetProperties(kError, kError);
391  }
392 
394  return sp_data_;
395  }
396 
397  internal::PdtBalanceData<Arc> *GetBalanceData() { return &balance_data_; }
398 
399  public:
400  // Hash multimap from close paren label to a paren arc.
401  using CloseParenMultimap =
402  std::unordered_multimap<internal::ParenState<Arc>, Arc,
404 
406  return close_paren_multimap_;
407  }
408 
409  private:
410  void Init(MutableFst<Arc> *ofst);
411 
412  void GetDistance(StateId start);
413 
414  void ProcFinal(SearchState s);
415 
416  void ProcArcs(SearchState s);
417 
418  void ProcOpenParen(Label paren_id, SearchState s, StateId nexstate,
419  const Weight &weight);
420 
421  void ProcCloseParen(Label paren_id, SearchState s, const Weight &weight);
422 
423  void ProcNonParen(SearchState s, StateId nextstate, const Weight &weight);
424 
425  void Relax(SearchState s, SearchState t, StateId nextstate,
426  const Weight &weight, Label paren_id);
427 
428  void Enqueue(SearchState d);
429 
430  void GetPath();
431 
432  Arc GetPathArc(SearchState s, SearchState p, Label paren_id, bool open);
433 
434  std::unique_ptr<Fst<Arc>> ifst_;
435  MutableFst<Arc> *ofst_;
436  const std::vector<std::pair<Label, Label>> &parens_;
437  bool keep_parens_;
438  Queue *state_queue_;
439  StateId start_;
440  Weight fdistance_;
441  SearchState f_parent_;
442  SpData sp_data_;
443  std::unordered_map<Label, Label> paren_map_;
444  CloseParenMultimap close_paren_multimap_;
445  internal::PdtBalanceData<Arc> balance_data_;
446  ssize_t nenqueued_;
447  bool error_;
448 
449  static constexpr uint8_t kEnqueued = 0x10;
450  static constexpr uint8_t kExpanded = 0x20;
451  static constexpr uint8_t kFinished = 0x40;
452 
453  static const Arc kNoArc;
454 };
455 
456 template <class Arc, class Queue>
458  ofst_ = ofst;
459  ofst->DeleteStates();
460  ofst->SetInputSymbols(ifst_->InputSymbols());
461  ofst->SetOutputSymbols(ifst_->OutputSymbols());
462  if (ifst_->Start() == kNoStateId) return;
463  fdistance_ = Weight::Zero();
464  f_parent_ = SearchState(kNoStateId, kNoStateId);
465  sp_data_.Clear();
466  close_paren_multimap_.clear();
467  balance_data_.Clear();
468  nenqueued_ = 0;
469  // Finds open parens per destination state and close parens per source state.
470  for (StateIterator<Fst<Arc>> siter(*ifst_); !siter.Done(); siter.Next()) {
471  const auto s = siter.Value();
472  for (ArcIterator<Fst<Arc>> aiter(*ifst_, s); !aiter.Done(); aiter.Next()) {
473  const auto &arc = aiter.Value();
474  const auto it = paren_map_.find(arc.ilabel);
475  if (it != paren_map_.end()) { // Is a paren?
476  const auto paren_id = it->second;
477  if (arc.ilabel == parens_[paren_id].first) { // Open paren.
478  balance_data_.OpenInsert(paren_id, arc.nextstate);
479  } else { // Close paren.
480  const internal::ParenState<Arc> paren_state(paren_id, s);
481  close_paren_multimap_.emplace(paren_state, arc);
482  }
483  }
484  }
485  }
486 }
487 
488 // Computes the shortest distance stored in a recursive way. Each sub-graph
489 // (i.e., different paren "start" state) begins with weight One().
490 template <class Arc, class Queue>
492  if (start == kNoStateId) return;
493  Queue state_queue;
494  state_queue_ = &state_queue;
495  const SearchState q(start, start);
496  Enqueue(q);
497  sp_data_.SetDistance(q, Weight::One());
498  while (!state_queue_->Empty()) {
499  const auto state = state_queue_->Head();
500  state_queue_->Dequeue();
501  const SearchState s(state, start);
502  sp_data_.SetFlags(s, 0, kEnqueued);
503  ProcFinal(s);
504  ProcArcs(s);
505  sp_data_.SetFlags(s, kExpanded, kExpanded);
506  }
507  sp_data_.SetFlags(q, kFinished, kFinished);
508  balance_data_.FinishInsert(start);
509  sp_data_.GC(start);
510 }
511 
512 // Updates best complete path.
513 template <class Arc, class Queue>
515  if (ifst_->Final(s.state) != Weight::Zero() && s.start == start_) {
516  const auto weight = Times(sp_data_.Distance(s), ifst_->Final(s.state));
517  if (fdistance_ != Plus(fdistance_, weight)) {
518  if (f_parent_.state != kNoStateId) {
519  sp_data_.SetFlags(f_parent_, 0, internal::kPdtFinal);
520  }
521  sp_data_.SetFlags(s, internal::kPdtFinal, internal::kPdtFinal);
522  fdistance_ = Plus(fdistance_, weight);
523  f_parent_ = s;
524  }
525  }
526 }
527 
528 // Processes all arcs leaving the state s.
529 template <class Arc, class Queue>
531  for (ArcIterator<Fst<Arc>> aiter(*ifst_, s.state); !aiter.Done();
532  aiter.Next()) {
533  const auto &arc = aiter.Value();
534  const auto weight = Times(sp_data_.Distance(s), arc.weight);
535  const auto it = paren_map_.find(arc.ilabel);
536  if (it != paren_map_.end()) { // Is a paren?
537  const auto paren_id = it->second;
538  if (arc.ilabel == parens_[paren_id].first) {
539  ProcOpenParen(paren_id, s, arc.nextstate, weight);
540  } else {
541  ProcCloseParen(paren_id, s, weight);
542  }
543  } else {
544  ProcNonParen(s, arc.nextstate, weight);
545  }
546  }
547 }
548 
549 // Saves the shortest path info for reaching this parenthesis and starts a new
550 // SSSP in the sub-graph pointed to by the parenthesis if previously unvisited.
551 // Otherwise it finds any previously encountered closing parentheses and relaxes
552 // them using the recursively stored shortest distance to them.
553 template <class Arc, class Queue>
555  SearchState s,
556  StateId nextstate,
557  const Weight &weight) {
558  const SearchState d(nextstate, nextstate);
559  const ParenSpec paren(paren_id, s.start, d.start);
560  const auto pdist = sp_data_.Distance(paren);
561  if (pdist != Plus(pdist, weight)) {
562  sp_data_.SetDistance(paren, weight);
563  sp_data_.SetParent(paren, s);
564  const auto dist = sp_data_.Distance(d);
565  if (dist == Weight::Zero()) {
566  auto *state_queue = state_queue_;
567  GetDistance(d.start);
568  state_queue_ = state_queue;
569  } else if (!(sp_data_.Flags(d) & kFinished)) {
570  FSTERROR()
571  << "PdtShortestPath: open parenthesis recursion: not bounded stack";
572  error_ = true;
573  }
574  for (auto set_iter = balance_data_.Find(paren_id, nextstate);
575  !set_iter.Done(); set_iter.Next()) {
576  const SearchState cpstate(set_iter.Element(), d.start);
577  const internal::ParenState<Arc> paren_state(paren_id, cpstate.state);
578  for (auto cpit = close_paren_multimap_.find(paren_state);
579  cpit != close_paren_multimap_.end() && paren_state == cpit->first;
580  ++cpit) {
581  const auto &cparc = cpit->second;
582  const auto cpw =
583  Times(weight, Times(sp_data_.Distance(cpstate), cparc.weight));
584  Relax(cpstate, s, cparc.nextstate, cpw, paren_id);
585  }
586  }
587  }
588 }
589 
590 // Saves the correspondence between each closing parenthesis and its balancing
591 // open parenthesis info. Relaxes any close parenthesis destination state that
592 // has a balancing previously encountered open parenthesis.
593 template <class Arc, class Queue>
595  SearchState s,
596  const Weight &weight) {
597  const internal::ParenState<Arc> paren_state(paren_id, s.start);
598  if (!(sp_data_.Flags(s) & kExpanded)) {
599  balance_data_.CloseInsert(paren_id, s.start, s.state);
600  sp_data_.SetFlags(s, internal::kPdtFinal, internal::kPdtFinal);
601  }
602 }
603 
604 // Classical relaxation for non-parentheses.
605 template <class Arc, class Queue>
607  StateId nextstate,
608  const Weight &weight) {
609  Relax(s, s, nextstate, weight, kNoLabel);
610 }
611 
612 // Classical relaxation on the search graph for an arc with destination state
613 // nextstate from state s. State t is in the same sub-graph as nextstate (i.e.,
614 // has the same paren "start").
615 template <class Arc, class Queue>
617  StateId nextstate,
618  const Weight &weight,
619  Label paren_id) {
620  const SearchState d(nextstate, t.start);
621  Weight dist = sp_data_.Distance(d);
622  if (dist != Plus(dist, weight)) {
623  sp_data_.SetParent(d, s);
624  sp_data_.SetParenId(d, paren_id);
625  sp_data_.SetDistance(d, Plus(dist, weight));
626  Enqueue(d);
627  }
628 }
629 
630 template <class Arc, class Queue>
632  if (!(sp_data_.Flags(s) & kEnqueued)) {
633  state_queue_->Enqueue(s.state);
634  sp_data_.SetFlags(s, kEnqueued, kEnqueued);
635  ++nenqueued_;
636  } else {
637  state_queue_->Update(s.state);
638  }
639 }
640 
641 // Follows parent pointers to find the shortest path. A stack is used since the
642 // shortest distance is stored recursively.
643 template <class Arc, class Queue>
645  SearchState s = f_parent_;
647  StateId s_p = kNoStateId;
648  StateId d_p = kNoStateId;
649  auto arc = kNoArc;
650  Label paren_id = kNoLabel;
651  std::stack<ParenSpec> paren_stack;
652  while (s.state != kNoStateId) {
653  d_p = s_p;
654  s_p = ofst_->AddState();
655  if (d.state == kNoStateId) {
656  ofst_->SetFinal(s_p, ifst_->Final(f_parent_.state));
657  } else {
658  if (paren_id != kNoLabel) { // Paren?
659  if (arc.ilabel == parens_[paren_id].first) { // Open paren?
660  paren_stack.pop();
661  } else { // Close paren?
662  const ParenSpec paren(paren_id, d.start, s.start);
663  paren_stack.push(paren);
664  }
665  if (!keep_parens_) arc.ilabel = arc.olabel = 0;
666  }
667  arc.nextstate = d_p;
668  ofst_->AddArc(s_p, arc);
669  }
670  d = s;
671  s = sp_data_.Parent(d);
672  paren_id = sp_data_.ParenId(d);
673  if (s.state != kNoStateId) {
674  arc = GetPathArc(s, d, paren_id, false);
675  } else if (!paren_stack.empty()) {
676  const ParenSpec paren = paren_stack.top();
677  s = sp_data_.Parent(paren);
678  paren_id = paren.paren_id;
679  arc = GetPathArc(s, d, paren_id, true);
680  }
681  }
682  ofst_->SetStart(s_p);
683  ofst_->SetProperties(
684  ShortestPathProperties(ofst_->Properties(kFstProperties, false)),
686 }
687 
688 // Finds transition with least weight between two states with label matching
689 // paren_id and open/close paren type or a non-paren if kNoLabel.
690 template <class Arc, class Queue>
692  Label paren_id, bool open_paren) {
693  auto path_arc = kNoArc;
694  for (ArcIterator<Fst<Arc>> aiter(*ifst_, s.state); !aiter.Done();
695  aiter.Next()) {
696  const auto &arc = aiter.Value();
697  if (arc.nextstate != d.state) continue;
698  Label arc_paren_id = kNoLabel;
699  const auto it = paren_map_.find(arc.ilabel);
700  if (it != paren_map_.end()) {
701  arc_paren_id = it->second;
702  bool arc_open_paren = (arc.ilabel == parens_[arc_paren_id].first);
703  if (arc_open_paren != open_paren) continue;
704  }
705  if (arc_paren_id != paren_id) continue;
706  if (arc.weight == Plus(arc.weight, path_arc.weight)) path_arc = arc;
707  }
708  if (path_arc.nextstate == kNoStateId) {
709  FSTERROR() << "PdtShortestPath::GetPathArc: Failed to find arc";
710  error_ = true;
711  }
712  return path_arc;
713 }
714 
715 template <class Arc, class Queue>
717  Weight::Zero(), kNoStateId);
718 
719 // Functional variants.
720 
721 template <class Arc, class Queue>
723  const Fst<Arc> &ifst,
724  const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
725  &parens,
727  PdtShortestPath<Arc, Queue> psp(ifst, parens, opts);
728  psp.ShortestPath(ofst);
729 }
730 
731 template <class Arc>
733  const Fst<Arc> &ifst,
734  const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
735  &parens,
736  MutableFst<Arc> *ofst) {
739  PdtShortestPath<Arc, Q> psp(ifst, parens, opts);
740  psp.ShortestPath(ofst);
741 }
742 
743 } // namespace fst
744 
745 #endif // FST_EXTENSIONS_PDT_SHORTEST_PATH_H_
bool operator==(const ParenSpec &other) const
internal::PdtBalanceData< Arc > * GetBalanceData()
constexpr int kNoLabel
Definition: fst.h:195
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:61
void SetFlags(SearchState s, uint8_t f, uint8_t mask)
uint8_t Flags(SearchState s) const
SearchState Parent(SearchState s) const
uint64_t ShortestPathProperties(uint64_t props, bool tree=false)
Definition: properties.cc:373
typename Arc::StateId StateId
Definition: shortest-path.h:87
void SetDistance(SearchState s, Weight weight)
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:64
SearchState Parent(const ParenSpec &paren) const
constexpr uint64_t kError
Definition: properties.h:52
constexpr uint8_t kPdtFinal
Definition: shortest-path.h:65
virtual void SetInputSymbols(const SymbolTable *isyms)=0
ParenSpec(Label paren_id=kNoLabel, StateId src_start=kNoStateId, StateId dest_start=kNoStateId)
std::unordered_multimap< internal::ParenState< Arc >, Arc, typename internal::ParenState< Arc >::Hash > CloseParenMultimap
constexpr int kNoStateId
Definition: fst.h:196
void SetParent(SearchState s, SearchState p)
constexpr uint64_t kRightSemiring
Definition: weight.h:139
#define FSTERROR()
Definition: util.h:56
SearchState(StateId s=kNoStateId, StateId t=kNoStateId)
Definition: shortest-path.h:94
typename Collection< ssize_t, StateId >::SetIterator SetIterator
Definition: paren.h:359
virtual void SetProperties(uint64_t props, uint64_t mask)=0
PdtShortestPath(const Fst< Arc > &ifst, const std::vector< std::pair< Label, Label >> &parens, const PdtShortestPathOptions< Arc, Queue > &opts)
#define VLOG(level)
Definition: log.h:54
const CloseParenMultimap & GetCloseParenMultimap() const
Weight Distance(const ParenSpec &paren) const
void SetParenId(SearchState s, Label p)
PdtShortestPathOptions(bool keep_parentheses=false, bool path_gc=true)
Definition: shortest-path.h:55
Label ParenId(SearchState s) const
constexpr uint8_t kPdtInited
Definition: shortest-path.h:64
void SetDistance(const ParenSpec &paren, Weight weight)
void ShortestPath(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, MutableFst< Arc > *ofst, const PdtShortestPathOptions< Arc, Queue > &opts)
constexpr size_t kNoArc
Definition: shortest-path.h:86
constexpr uint64_t kPath
Definition: weight.h:150
constexpr uint64_t kFstProperties
Definition: properties.h:326
void ShortestPath(MutableFst< Arc > *ofst)
Arc::StateId CountStates(const Fst< Arc > &fst)
Definition: expanded-fst.h:179
bool operator==(const SearchState &other) const
Definition: shortest-path.h:97
virtual void DeleteStates(const std::vector< StateId > &)=0
typename internal::PdtBalanceData< Arc >::SetIterator CloseSourceIterator
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
Weight Distance(SearchState s) const
const internal::PdtShortestPathData< Arc > & GetShortestPathData() const
constexpr uint64_t kExpanded
Definition: properties.h:46
void SetParent(const ParenSpec &paren, SearchState p)
constexpr uint8_t kPdtMarked
Definition: shortest-path.h:66