FST  openfst-1.7.1
OpenFst Library
expand.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Expands a PDT to an FST.
5 
6 #ifndef FST_EXTENSIONS_PDT_EXPAND_H_
7 #define FST_EXTENSIONS_PDT_EXPAND_H_
8 
9 #include <forward_list>
10 #include <vector>
11 
12 #include <fst/log.h>
13 
15 #include <fst/extensions/pdt/pdt.h>
18 #include <fst/cache.h>
19 #include <fst/mutable-fst.h>
20 #include <fst/queue.h>
21 #include <fst/state-table.h>
22 #include <fst/test-properties.h>
23 
24 namespace fst {
25 
26 template <class Arc>
31 
33  const CacheOptions &opts = CacheOptions(), bool keep_parentheses = false,
36  nullptr)
37  : CacheOptions(opts),
38  keep_parentheses(keep_parentheses),
39  stack(stack),
40  state_table(state_table) {}
41 };
42 
43 namespace internal {
44 
45 // Implementation class for PdtExpandFst.
46 template <class Arc>
47 class PdtExpandFstImpl : public CacheImpl<Arc> {
48  public:
49  using Label = typename Arc::Label;
50  using StateId = typename Arc::StateId;
51  using Weight = typename Arc::Weight;
52 
53  using StackId = StateId;
55 
61 
62  using CacheBaseImpl<CacheState<Arc>>::PushArc;
63  using CacheBaseImpl<CacheState<Arc>>::HasArcs;
64  using CacheBaseImpl<CacheState<Arc>>::HasFinal;
65  using CacheBaseImpl<CacheState<Arc>>::HasStart;
66  using CacheBaseImpl<CacheState<Arc>>::SetArcs;
67  using CacheBaseImpl<CacheState<Arc>>::SetFinal;
68  using CacheBaseImpl<CacheState<Arc>>::SetStart;
69 
71  const std::vector<std::pair<Label, Label>> &parens,
72  const PdtExpandFstOptions<Arc> &opts)
73  : CacheImpl<Arc>(opts),
74  fst_(fst.Copy()),
75  stack_(opts.stack ? opts.stack : new PdtStack<StateId, Label>(parens)),
76  state_table_(opts.state_table ? opts.state_table
77  : new PdtStateTable<StateId, StackId>()),
78  own_stack_(opts.stack == 0),
79  own_state_table_(opts.state_table == 0),
80  keep_parentheses_(opts.keep_parentheses) {
81  SetType("expand");
82  const auto props = fst.Properties(kFstProperties, false);
83  SetProperties(PdtExpandProperties(props), kCopyProperties);
84  SetInputSymbols(fst.InputSymbols());
85  SetOutputSymbols(fst.OutputSymbols());
86  }
87 
89  : CacheImpl<Arc>(impl),
90  fst_(impl.fst_->Copy(true)),
91  stack_(new PdtStack<StateId, Label>(*impl.stack_)),
92  state_table_(new PdtStateTable<StateId, StackId>()),
93  own_stack_(true),
94  own_state_table_(true),
95  keep_parentheses_(impl.keep_parentheses_) {
96  SetType("expand");
97  SetProperties(impl.Properties(), kCopyProperties);
98  SetInputSymbols(impl.InputSymbols());
99  SetOutputSymbols(impl.OutputSymbols());
100  }
101 
102  ~PdtExpandFstImpl() override {
103  if (own_stack_) delete stack_;
104  if (own_state_table_) delete state_table_;
105  }
106 
108  if (!HasStart()) {
109  const auto s = fst_->Start();
110  if (s == kNoStateId) return kNoStateId;
111  StateTuple tuple(s, 0);
112  const auto start = state_table_->FindState(tuple);
113  SetStart(start);
114  }
115  return CacheImpl<Arc>::Start();
116  }
117 
119  if (!HasFinal(s)) {
120  const auto &tuple = state_table_->Tuple(s);
121  const auto weight = fst_->Final(tuple.state_id);
122  if (weight != Weight::Zero() && tuple.stack_id == 0)
123  SetFinal(s, weight);
124  else
125  SetFinal(s, Weight::Zero());
126  }
127  return CacheImpl<Arc>::Final(s);
128  }
129 
130  size_t NumArcs(StateId s) {
131  if (!HasArcs(s)) ExpandState(s);
132  return CacheImpl<Arc>::NumArcs(s);
133  }
134 
136  if (!HasArcs(s)) ExpandState(s);
138  }
139 
141  if (!HasArcs(s)) ExpandState(s);
143  }
144 
146  if (!HasArcs(s)) ExpandState(s);
148  }
149 
150  // Computes the outgoing transitions from a state, creating new destination
151  // states as needed.
153  StateTuple tuple = state_table_->Tuple(s);
154  for (ArcIterator<Fst<Arc>> aiter(*fst_, tuple.state_id); !aiter.Done();
155  aiter.Next()) {
156  auto arc = aiter.Value();
157  const auto stack_id = stack_->Find(tuple.stack_id, arc.ilabel);
158  if (stack_id == -1) { // Non-matching close parenthesis.
159  continue;
160  } else if ((stack_id != tuple.stack_id) && !keep_parentheses_) {
161  // Stack push/pop.
162  arc.ilabel = 0;
163  arc.olabel = 0;
164  }
165  StateTuple ntuple(arc.nextstate, stack_id);
166  arc.nextstate = state_table_->FindState(ntuple);
167  PushArc(s, arc);
168  }
169  SetArcs(s);
170  }
171 
172  const PdtStack<StackId, Label> &GetStack() const { return *stack_; }
173 
175  return *state_table_;
176  }
177 
178  private:
179  // Properties for an expanded PDT.
180  inline uint64 PdtExpandProperties(uint64 inprops) {
181  return inprops & (kAcceptor | kAcyclic | kInitialAcyclic | kUnweighted);
182  }
183 
184  std::unique_ptr<const Fst<Arc>> fst_;
185  PdtStack<StackId, Label> *stack_;
186  PdtStateTable<StateId, StackId> *state_table_;
187  bool own_stack_;
188  bool own_state_table_;
189  bool keep_parentheses_;
190 };
191 
192 } // namespace internal
193 
194 // Expands a pushdown transducer (PDT) encoded as an FST into an FST. This
195 // version is a delayed FST. In the PDT, some transitions are labeled with open
196 // or close parentheses. To be interpreted as a PDT, the parens must balance on
197 // a path. The open-close parenthesis label pairs are passed using the parens
198 // argument. The expansion enforces the parenthesis constraints. The PDT must be
199 // expandable as an FST.
200 //
201 // This class attaches interface to implementation and handles reference
202 // counting, delegating most methods to ImplToFst.
203 template <class A>
204 class PdtExpandFst : public ImplToFst<internal::PdtExpandFstImpl<A>> {
205  public:
206  using Arc = A;
207 
208  using Label = typename Arc::Label;
209  using StateId = typename Arc::StateId;
210  using Weight = typename Arc::Weight;
211 
212  using StackId = StateId;
214  using State = typename Store::State;
216 
217  friend class ArcIterator<PdtExpandFst<Arc>>;
219 
221  const std::vector<std::pair<Label, Label>> &parens)
222  : ImplToFst<Impl>(
223  std::make_shared<Impl>(fst, parens, PdtExpandFstOptions<A>())) {}
224 
226  const std::vector<std::pair<Label, Label>> &parens,
227  const PdtExpandFstOptions<Arc> &opts)
228  : ImplToFst<Impl>(std::make_shared<Impl>(fst, parens, opts)) {}
229 
230  // See Fst<>::Copy() for doc.
231  PdtExpandFst(const PdtExpandFst<Arc> &fst, bool safe = false)
232  : ImplToFst<Impl>(fst, safe) {}
233 
234  // Gets a copy of this ExpandFst. See Fst<>::Copy() for further doc.
235  PdtExpandFst<Arc> *Copy(bool safe = false) const override {
236  return new PdtExpandFst<Arc>(*this, safe);
237  }
238 
239  inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
240 
241  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
242  GetMutableImpl()->InitArcIterator(s, data);
243  }
244 
246  return GetImpl()->GetStack();
247  }
248 
250  return GetImpl()->GetStateTable();
251  }
252 
253  private:
256 
257  void operator=(const PdtExpandFst &) = delete;
258 };
259 
260 // Specialization for PdtExpandFst.
261 template <class Arc>
263  : public CacheStateIterator<PdtExpandFst<Arc>> {
264  public:
266  : CacheStateIterator<PdtExpandFst<Arc>>(fst, fst.GetMutableImpl()) {}
267 };
268 
269 // Specialization for PdtExpandFst.
270 template <class Arc>
272  : public CacheArcIterator<PdtExpandFst<Arc>> {
273  public:
274  using StateId = typename Arc::StateId;
275 
277  : CacheArcIterator<PdtExpandFst<Arc>>(fst.GetMutableImpl(), s) {
278  if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->ExpandState(s);
279  }
280 };
281 
282 template <class Arc>
284  StateIteratorData<Arc> *data) const {
285  data->base = new StateIterator<PdtExpandFst<Arc>>(*this);
286 }
287 
288 // PrunedExpand prunes the delayed expansion of a pushdown transducer (PDT)
289 // encoded as an FST into an FST. In the PDT, some transitions are labeled with
290 // open or close parentheses. To be interpreted as a PDT, the parens must
291 // balance on a path. The open-close parenthesis label pairs are passed
292 // using the parens argument. The expansion enforces the parenthesis
293 // constraints.
294 //
295 // The algorithm works by visiting the delayed ExpandFst using a shortest-stack
296 // first queue discipline and relies on the shortest-distance information
297 // computed using a reverse shortest-path call to perform the pruning.
298 //
299 // The algorithm maintains the same state ordering between the ExpandFst being
300 // visited (efst_) and the result of pruning written into the MutableFst (ofst_)
301 // to improve readability.
302 template <class Arc>
304  public:
305  using Label = typename Arc::Label;
306  using StateId = typename Arc::StateId;
307  using Weight = typename Arc::Weight;
308 
309  using StackId = StateId;
313 
314  // Constructor taking as input a PDT specified by an input FST and a vector
315  // of parentheses. The keep_parentheses argument specifies whether parentheses
316  // are replaced by epsilons or not during the expansion. The cache options are
317  // passed to the underlying ExpandFst.
319  const std::vector<std::pair<Label, Label>> &parens,
320  bool keep_parentheses = false,
321  const CacheOptions &opts = CacheOptions())
322  : ifst_(ifst.Copy()),
323  keep_parentheses_(keep_parentheses),
324  stack_(parens),
325  efst_(ifst, parens,
326  PdtExpandFstOptions<Arc>(opts, true, &stack_, &state_table_)),
327  queue_(state_table_, stack_, stack_length_, distance_, fdistance_),
328  error_(false) {
329  Reverse(*ifst_, parens, &rfst_);
330  VectorFst<Arc> path;
331  reverse_shortest_path_.reset(new PdtShortestPath<Arc, FifoQueue<StateId>>(
332  rfst_, parens,
333  PdtShortestPathOptions<Arc, FifoQueue<StateId>>(true, false)));
334  reverse_shortest_path_->ShortestPath(&path);
335  error_ = (path.Properties(kError, true) == kError);
336  balance_data_.reset(reverse_shortest_path_->GetBalanceData()->Reverse(
337  rfst_.NumStates(), 10, -1));
338  InitCloseParenMultimap(parens);
339  }
340 
341  bool Error() const { return error_; }
342 
343  // Expands and prunes the input PDT according to the provided weight
344  // threshold, writing the result into an output mutable FST.
345  void Expand(MutableFst<Arc> *ofst, const Weight &threshold);
346 
347  private:
348  static constexpr uint8 kEnqueued = 0x01;
349  static constexpr uint8 kExpanded = 0x02;
350  static constexpr uint8 kSourceState = 0x04;
351 
352  // Comparison functor used by the queue:
353  //
354  // 1. States corresponding to shortest stack first, and
355  // 2. for stacks of matching length, reverse lexicographic order is used, and
356  // 3. for states with the same stack, shortest-first order is used.
357  class StackCompare {
358  public:
359  StackCompare(const StateTable &state_table, const Stack &stack,
360  const std::vector<StackId> &stack_length,
361  const std::vector<Weight> &distance,
362  const std::vector<Weight> &fdistance)
363  : state_table_(state_table),
364  stack_(stack),
365  stack_length_(stack_length),
366  distance_(distance),
367  fdistance_(fdistance) {}
368 
369  bool operator()(StateId s1, StateId s2) const {
370  auto si1 = state_table_.Tuple(s1).stack_id;
371  auto si2 = state_table_.Tuple(s2).stack_id;
372  if (stack_length_[si1] < stack_length_[si2]) return true;
373  if (stack_length_[si1] > stack_length_[si2]) return false;
374  // If stack IDs are equal, use A*.
375  if (si1 == si2) {
376  return less_(Distance(s1), Distance(s2));
377  }
378  // If lengths are equal, uses reverse lexicographic order.
379  for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) {
380  if (stack_.Top(si1) < stack_.Top(si2)) return true;
381  if (stack_.Top(si1) > stack_.Top(si2)) return false;
382  }
383  return false;
384  }
385 
386  private:
387  Weight Distance(StateId s) const {
388  return (s < distance_.size()) && (s < fdistance_.size())
389  ? Times(distance_[s], fdistance_[s])
390  : Weight::Zero();
391  }
392 
393  const StateTable &state_table_;
394  const Stack &stack_;
395  const std::vector<StackId> &stack_length_;
396  const std::vector<Weight> &distance_;
397  const std::vector<Weight> &fdistance_;
398  const NaturalLess<Weight> less_;
399  };
400 
401  class ShortestStackFirstQueue
402  : public ShortestFirstQueue<StateId, StackCompare> {
403  public:
404  ShortestStackFirstQueue(const PdtStateTable<StateId, StackId> &state_table,
405  const Stack &stack,
406  const std::vector<StackId> &stack_length,
407  const std::vector<Weight> &distance,
408  const std::vector<Weight> &fdistance)
410  state_table, stack, stack_length, distance, fdistance)) {}
411  };
412 
413  void InitCloseParenMultimap(
414  const std::vector<std::pair<Label, Label>> &parens);
415 
416  Weight DistanceToDest(StateId source, StateId dest) const;
417 
418  uint8 Flags(StateId s) const;
419 
420  void SetFlags(StateId s, uint8 flags, uint8 mask);
421 
422  Weight Distance(StateId s) const;
423 
424  void SetDistance(StateId s, Weight weight);
425 
426  Weight FinalDistance(StateId s) const;
427 
428  void SetFinalDistance(StateId s, Weight weight);
429 
430  StateId SourceState(StateId s) const;
431 
432  void SetSourceState(StateId s, StateId p);
433 
434  void AddStateAndEnqueue(StateId s);
435 
436  void Relax(StateId s, const Arc &arc, Weight weight);
437 
438  bool PruneArc(StateId s, const Arc &arc);
439 
440  void ProcStart();
441 
442  void ProcFinal(StateId s);
443 
444  bool ProcNonParen(StateId s, const Arc &arc, bool add_arc);
445 
446  bool ProcOpenParen(StateId s, const Arc &arc, StackId si, StackId nsi);
447 
448  bool ProcCloseParen(StateId s, const Arc &arc);
449 
450  void ProcDestStates(StateId s, StackId si);
451 
452  // Input PDT.
453  std::unique_ptr<Fst<Arc>> ifst_;
454  // Reversed PDT.
455  VectorFst<Arc> rfst_;
456  // Keep parentheses in ofst?
457  const bool keep_parentheses_;
458  // State table for efst_.
459  StateTable state_table_;
460  // Stack trie.
461  Stack stack_;
462  // Expanded PDT.
463  PdtExpandFst<Arc> efst_;
464  // Length of stack for given stack ID.
465  std::vector<StackId> stack_length_;
466  // Distance from initial state in efst_/ofst.
467  std::vector<Weight> distance_;
468  // Distance to final states in efst_/ofst.
469  std::vector<Weight> fdistance_;
470  // Queue used to visit efst_.
471  ShortestStackFirstQueue queue_;
472  // Construction time failure?
473  bool error_;
474  // Status flags for states in efst_/ofst.
475  std::vector<uint8> flags_;
476  // PDT source state for each expanded state.
477  std::vector<StateId> sources_;
478  // Shortest path for rfst_.
479  std::unique_ptr<PdtShortestPath<Arc, FifoQueue<StateId>>>
480  reverse_shortest_path_;
481  std::unique_ptr<internal::PdtBalanceData<Arc>> balance_data_;
482  // Maps open paren arcs to balancing close paren arcs.
483  typename PdtShortestPath<Arc, FifoQueue<StateId>>::CloseParenMultimap
484  close_paren_multimap_;
485  MutableFst<Arc> *ofst_; // Output FST.
486  Weight limit_; // Weight limit.
487 
488  // Maps a state s in ifst (i.e., the source of a closed parenthesis matching
489  // the top of current_stack_id_) to final states in efst_.
490  std::unordered_map<StateId, Weight> dest_map_;
491  // Stack ID of the states currently at the top of the queue, i.e., the states
492  // currently being popped and processed.
493  StackId current_stack_id_;
494  ssize_t current_paren_id_; // Paren ID at top of current stack.
495  ssize_t cached_stack_id_;
496  StateId cached_source_;
497  // The set of pairs of destination states and weights to final states for the
498  // source state cached_source_ and the paren ID cached_paren_id_; i.e., the
499  // set of source states of a closed parenthesis with paren ID cached_paren_id
500  // balancing an incoming open parenthesis with paren ID cached_paren_id_ in
501  // state cached_source_.
502  std::forward_list<std::pair<StateId, Weight>> cached_dest_list_;
503  NaturalLess<Weight> less_;
504 };
505 
506 // Initializes close paren multimap, mapping pairs (s, paren_id) to all the arcs
507 // out of s labeled with close parentheses for paren_id.
508 template <class Arc>
510  const std::vector<std::pair<Label, Label>> &parens) {
511  std::unordered_map<Label, Label> paren_map;
512  for (size_t i = 0; i < parens.size(); ++i) {
513  const auto &pair = parens[i];
514  paren_map[pair.first] = i;
515  paren_map[pair.second] = i;
516  }
517  for (StateIterator<Fst<Arc>> siter(*ifst_); !siter.Done(); siter.Next()) {
518  const auto s = siter.Value();
519  for (ArcIterator<Fst<Arc>> aiter(*ifst_, s); !aiter.Done(); aiter.Next()) {
520  const auto &arc = aiter.Value();
521  const auto it = paren_map.find(arc.ilabel);
522  if (it == paren_map.end()) continue;
523  if (arc.ilabel == parens[it->second].second) { // Close paren.
524  const internal::ParenState<Arc> key(it->second, s);
525  close_paren_multimap_.emplace(key, arc);
526  }
527  }
528  }
529 }
530 
531 // Returns the weight of the shortest balanced path from source to dest
532 // in ifst_; dest must be the source state of a close paren arc.
533 template <class Arc>
534 typename Arc::Weight PdtPrunedExpand<Arc>::DistanceToDest(StateId source,
535  StateId dest) const {
536  using SearchState =
537  typename PdtShortestPath<Arc, FifoQueue<StateId>>::SearchState;
538  const SearchState ss(source + 1, dest + 1);
539  const auto distance =
540  reverse_shortest_path_->GetShortestPathData().Distance(ss);
541  VLOG(2) << "D(" << source << ", " << dest << ") =" << distance;
542  return distance;
543 }
544 
545 // Returns the flags for state s in ofst_.
546 template <class Arc>
548  return s < flags_.size() ? flags_[s] : 0;
549 }
550 
551 // Modifies the flags for state s in ofst_.
552 template <class Arc>
553 void PdtPrunedExpand<Arc>::SetFlags(StateId s, uint8 flags, uint8 mask) {
554  while (flags_.size() <= s) flags_.push_back(0);
555  flags_[s] &= ~mask;
556  flags_[s] |= flags & mask;
557 }
558 
559 // Returns the shortest distance from the initial state to s in ofst_.
560 template <class Arc>
561 typename Arc::Weight PdtPrunedExpand<Arc>::Distance(StateId s) const {
562  return s < distance_.size() ? distance_[s] : Weight::Zero();
563 }
564 
565 // Sets the shortest distance from the initial state to s in ofst_.
566 template <class Arc>
568  while (distance_.size() <= s) distance_.push_back(Weight::Zero());
569  distance_[s] = std::move(weight);
570 }
571 
572 // Returns the shortest distance from s to the final states in ofst_.
573 template <class Arc>
574 typename Arc::Weight PdtPrunedExpand<Arc>::FinalDistance(StateId s) const {
575  return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
576 }
577 
578 // Sets the shortest distance from s to the final states in ofst_.
579 template <class Arc>
581  while (fdistance_.size() <= s) fdistance_.push_back(Weight::Zero());
582  fdistance_[s] = std::move(weight);
583 }
584 
585 // Returns the PDT source state of state s in ofst_.
586 template <class Arc>
587 typename Arc::StateId PdtPrunedExpand<Arc>::SourceState(StateId s) const {
588  return s < sources_.size() ? sources_[s] : kNoStateId;
589 }
590 
591 // Sets the PDT source state of state s in ofst_ to state p in ifst_.
592 template <class Arc>
594  while (sources_.size() <= s) sources_.push_back(kNoStateId);
595  sources_[s] = p;
596 }
597 
598 // Adds state s of efst_ to ofst_ and inserts it in the queue, modifying the
599 // flags for s accordingly.
600 template <class Arc>
602  if (!(Flags(s) & (kEnqueued | kExpanded))) {
603  while (ofst_->NumStates() <= s) ofst_->AddState();
604  queue_.Enqueue(s);
605  SetFlags(s, kEnqueued, kEnqueued);
606  } else if (Flags(s) & kEnqueued) {
607  queue_.Update(s);
608  }
609  // TODO(allauzen): Check everything is fine when kExpanded?
610 }
611 
612 // Relaxes arc out of state s in ofst_ as follows:
613 //
614 // 1. If the distance to s times the weight of arc is smaller than
615 // the currently stored distance for arc.nextstate, updates
616 // Distance(arc.nextstate) with a new estimate
617 // 2. If fd is less than the currently stored distance from arc.nextstate to the
618 // final state, updates with new estimate.
619 template <class Arc>
620 void PdtPrunedExpand<Arc>::Relax(StateId s, const Arc &arc, Weight fd) {
621  const auto nd = Times(Distance(s), arc.weight);
622  if (less_(nd, Distance(arc.nextstate))) {
623  SetDistance(arc.nextstate, nd);
624  SetSourceState(arc.nextstate, SourceState(s));
625  }
626  if (less_(fd, FinalDistance(arc.nextstate))) {
627  SetFinalDistance(arc.nextstate, fd);
628  }
629  VLOG(2) << "Relax: " << s << ", d[s] = " << Distance(s) << ", to "
630  << arc.nextstate << ", d[ns] = " << Distance(arc.nextstate)
631  << ", nd = " << nd;
632 }
633 
634 // Returns whether the arc out of state s in efst needs pruned.
635 template <class Arc>
636 bool PdtPrunedExpand<Arc>::PruneArc(StateId s, const Arc &arc) {
637  VLOG(2) << "Prune ?";
638  auto fd = Weight::Zero();
639  if ((cached_source_ != SourceState(s)) ||
640  (cached_stack_id_ != current_stack_id_)) {
641  cached_source_ = SourceState(s);
642  cached_stack_id_ = current_stack_id_;
643  cached_dest_list_.clear();
644  if (cached_source_ != ifst_->Start()) {
645  for (auto set_iter =
646  balance_data_->Find(current_paren_id_, cached_source_);
647  !set_iter.Done(); set_iter.Next()) {
648  auto dest = set_iter.Element();
649  const auto it = dest_map_.find(dest);
650  cached_dest_list_.push_front(*it);
651  }
652  } else {
653  // TODO(allauzen): queue discipline should prevent this from ever
654  // happening.
655  // Replace by a check.
656  cached_dest_list_.push_front(
657  std::make_pair(rfst_.Start() - 1, Weight::One()));
658  }
659  }
660  for (auto it = cached_dest_list_.begin(); it != cached_dest_list_.end();
661  ++it) {
662  const auto d =
663  DistanceToDest(state_table_.Tuple(arc.nextstate).state_id, it->first);
664  fd = Plus(fd, Times(d, it->second));
665  }
666  Relax(s, arc, fd);
667  return less_(limit_, Times(Distance(s), Times(arc.weight, fd)));
668 }
669 
670 // Adds start state of efst_ to ofst_, enqueues it, and initializes the distance
671 // data structures.
672 template <class Arc>
674  const auto s = efst_.Start();
675  AddStateAndEnqueue(s);
676  ofst_->SetStart(s);
677  SetSourceState(s, ifst_->Start());
678  current_stack_id_ = 0;
679  current_paren_id_ = -1;
680  stack_length_.push_back(0);
681  const auto r = rfst_.Start() - 1;
682  cached_source_ = ifst_->Start();
683  cached_stack_id_ = 0;
684  cached_dest_list_.push_front(std::make_pair(r, Weight::One()));
685  const PdtStateTuple<StateId, StackId> tuple(r, 0);
686  SetFinalDistance(state_table_.FindState(tuple), Weight::One());
687  SetDistance(s, Weight::One());
688  const auto d = DistanceToDest(ifst_->Start(), r);
689  SetFinalDistance(s, d);
690  VLOG(2) << d;
691 }
692 
693 // Makes s final in ofst_ if shortest accepting path ending in s is below
694 // threshold.
695 template <class Arc>
697  const auto weight = efst_.Final(s);
698  if (weight == Weight::Zero()) return;
699  if (less_(limit_, Times(Distance(s), weight))) return;
700  ofst_->SetFinal(s, weight);
701 }
702 
703 // Returns true when an arc (or meta-arc) leaving state s in efst_ is below the
704 // threshold. When add_arc is true, arc is added to ofst_.
705 template <class Arc>
706 bool PdtPrunedExpand<Arc>::ProcNonParen(StateId s, const Arc &arc,
707  bool add_arc) {
708  VLOG(2) << "ProcNonParen: " << s << " to " << arc.nextstate << ", "
709  << arc.ilabel << ":" << arc.olabel << " / " << arc.weight
710  << ", add_arc = " << (add_arc ? "true" : "false");
711  if (PruneArc(s, arc)) return false;
712  if (add_arc) ofst_->AddArc(s, arc);
713  AddStateAndEnqueue(arc.nextstate);
714  return true;
715 }
716 
717 // Processes an open paren arc leaving state s in ofst_. When the arc is labeled
718 // with an open paren,
719 //
720 // 1. Considers each (shortest) balanced path starting in s by taking the arc
721 // and ending by a close paren balancing the open paren of as a meta-arc,
722 // processing and pruning each meta-arc as a non-paren arc, inserting its
723 // destination to the queue;
724 // 2. if at least one of these meta-arcs has not been pruned, adds the
725 // destination of arc to ofst_ as a new source state for the stack ID nsi, and
726 // inserts it in the queue.
727 template <class Arc>
728 bool PdtPrunedExpand<Arc>::ProcOpenParen(StateId s, const Arc &arc, StackId si,
729  StackId nsi) {
730  // Updates the stack length when needed.
731  while (stack_length_.size() <= nsi) stack_length_.push_back(-1);
732  if (stack_length_[nsi] == -1) stack_length_[nsi] = stack_length_[si] + 1;
733  const auto ns = arc.nextstate;
734  VLOG(2) << "Open paren: " << s << "(" << state_table_.Tuple(s).state_id
735  << ") to " << ns << "(" << state_table_.Tuple(ns).state_id << ")";
736  bool proc_arc = false;
737  auto fd = Weight::Zero();
738  const auto paren_id = stack_.ParenId(arc.ilabel);
739  std::forward_list<StateId> sources;
740  for (auto set_iter =
741  balance_data_->Find(paren_id, state_table_.Tuple(ns).state_id);
742  !set_iter.Done(); set_iter.Next()) {
743  sources.push_front(set_iter.Element());
744  }
745  for (const auto source : sources) {
746  VLOG(2) << "Close paren source: " << source;
747  const internal::ParenState<Arc> paren_state(paren_id, source);
748  for (auto it = close_paren_multimap_.find(paren_state);
749  it != close_paren_multimap_.end() && paren_state == it->first; ++it) {
750  auto meta_arc = it->second;
751  const PdtStateTuple<StateId, StackId> tuple(meta_arc.nextstate, si);
752  meta_arc.nextstate = state_table_.FindState(tuple);
753  const auto state_id = state_table_.Tuple(ns).state_id;
754  const auto d = DistanceToDest(state_id, source);
755  VLOG(2) << state_id << ", " << source;
756  VLOG(2) << "Meta arc weight = " << arc.weight << " Times " << d
757  << " Times " << meta_arc.weight;
758  meta_arc.weight = Times(arc.weight, Times(d, meta_arc.weight));
759  proc_arc |= ProcNonParen(s, meta_arc, false);
760  fd = Plus(
761  fd,
762  Times(Times(DistanceToDest(state_table_.Tuple(ns).state_id, source),
763  it->second.weight),
764  FinalDistance(meta_arc.nextstate)));
765  }
766  }
767  if (proc_arc) {
768  VLOG(2) << "Proc open paren " << s << " to " << arc.nextstate;
769  ofst_->AddArc(
770  s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
771  AddStateAndEnqueue(arc.nextstate);
772  const auto nd = Times(Distance(s), arc.weight);
773  if (less_(nd, Distance(arc.nextstate))) SetDistance(arc.nextstate, nd);
774  // FinalDistance not necessary for source state since pruning decided using
775  // meta-arcs above. But this is a problem with A*, hence the following.
776  if (less_(fd, FinalDistance(arc.nextstate)))
777  SetFinalDistance(arc.nextstate, fd);
778  SetFlags(arc.nextstate, kSourceState, kSourceState);
779  }
780  return proc_arc;
781 }
782 
783 // Checks that shortest path through close paren arc in efst_ is below
784 // threshold, and if so, adds it to ofst_.
785 template <class Arc>
786 bool PdtPrunedExpand<Arc>::ProcCloseParen(StateId s, const Arc &arc) {
787  const auto weight =
788  Times(Distance(s), Times(arc.weight, FinalDistance(arc.nextstate)));
789  if (less_(limit_, weight)) return false;
790  ofst_->AddArc(s,
791  keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
792  return true;
793 }
794 
795 // When state s in ofst_ is a source state for stack ID si, identifies all the
796 // corresponding possible destination states, that is, all the states in ifst_
797 // that have an outgoing close paren arc balancing the incoming open paren taken
798 // to get to s. For each such state t, computes the shortest distance from (t,
799 // si) to the final states in ofst_. Stores this information in dest_map_.
800 template <class Arc>
802  if (!(Flags(s) & kSourceState)) return;
803  if (si != current_stack_id_) {
804  dest_map_.clear();
805  current_stack_id_ = si;
806  current_paren_id_ = stack_.Top(current_stack_id_);
807  VLOG(2) << "StackID " << si << " dequeued for first time";
808  }
809  // TODO(allauzen): clean up source state business; rename current function to
810  // ProcSourceState.
811  SetSourceState(s, state_table_.Tuple(s).state_id);
812  const auto paren_id = stack_.Top(si);
813  for (auto set_iter =
814  balance_data_->Find(paren_id, state_table_.Tuple(s).state_id);
815  !set_iter.Done(); set_iter.Next()) {
816  const auto dest_state = set_iter.Element();
817  if (dest_map_.find(dest_state) != dest_map_.end()) continue;
818  auto dest_weight = Weight::Zero();
819  internal::ParenState<Arc> paren_state(paren_id, dest_state);
820  for (auto it = close_paren_multimap_.find(paren_state);
821  it != close_paren_multimap_.end() && paren_state == it->first; ++it) {
822  const auto &arc = it->second;
823  const PdtStateTuple<StateId, StackId> tuple(arc.nextstate,
824  stack_.Pop(si));
825  dest_weight =
826  Plus(dest_weight,
827  Times(arc.weight, FinalDistance(state_table_.FindState(tuple))));
828  }
829  dest_map_[dest_state] = dest_weight;
830  VLOG(2) << "State " << dest_state << " is a dest state for stack ID " << si
831  << " with weight " << dest_weight;
832  }
833 }
834 
835 // Expands and prunes the input PDT, writing the result in ofst.
836 template <class Arc>
838  const typename Arc::Weight &threshold) {
839  ofst_ = ofst;
840  if (error_) {
841  ofst_->SetProperties(kError, kError);
842  return;
843  }
844  ofst_->DeleteStates();
845  ofst_->SetInputSymbols(ifst_->InputSymbols());
846  ofst_->SetOutputSymbols(ifst_->OutputSymbols());
847  limit_ = Times(DistanceToDest(ifst_->Start(), rfst_.Start() - 1), threshold);
848  flags_.clear();
849  ProcStart();
850  while (!queue_.Empty()) {
851  const auto s = queue_.Head();
852  queue_.Dequeue();
853  SetFlags(s, kExpanded, kExpanded | kEnqueued);
854  VLOG(2) << s << " dequeued!";
855  ProcFinal(s);
856  StackId stack_id = state_table_.Tuple(s).stack_id;
857  ProcDestStates(s, stack_id);
858  for (ArcIterator<PdtExpandFst<Arc>> aiter(efst_, s); !aiter.Done();
859  aiter.Next()) {
860  const auto &arc = aiter.Value();
861  const auto nextstack_id = state_table_.Tuple(arc.nextstate).stack_id;
862  if (stack_id == nextstack_id) {
863  ProcNonParen(s, arc, true);
864  } else if (stack_id == stack_.Pop(nextstack_id)) {
865  ProcOpenParen(s, arc, stack_id, nextstack_id);
866  } else {
867  ProcCloseParen(s, arc);
868  }
869  }
870  VLOG(2) << "d[" << s << "] = " << Distance(s) << ", fd[" << s
871  << "] = " << FinalDistance(s);
872  }
873 }
874 
875 // Expand functions.
876 
877 template <class Arc>
879  using Weight = typename Arc::Weight;
880 
881  bool connect;
884 
885  PdtExpandOptions(bool connect = true, bool keep_parentheses = false,
886  Weight weight_threshold = Weight::Zero())
887  : connect(connect),
888  keep_parentheses(keep_parentheses),
889  weight_threshold(std::move(weight_threshold)) {}
890 };
891 
892 // Expands a pushdown transducer (PDT) encoded as an FST into an FST. This
893 // version writes the expanded PDT to a mutable FST. In the PDT, some
894 // transitions are labeled with open or close parentheses. To be interpreted as
895 // a PDT, the parens must balance on a path. The open-close parenthesis label
896 // pairs are passed using the parens argument. Expansion enforces the
897 // parenthesis constraints. The PDT must be expandable as an FST.
898 template <class Arc>
899 void Expand(
900  const Fst<Arc> &ifst,
901  const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
902  &parens,
903  MutableFst<Arc> *ofst, const PdtExpandOptions<Arc> &opts) {
905  eopts.gc_limit = 0;
906  if (opts.weight_threshold == Arc::Weight::Zero()) {
907  eopts.keep_parentheses = opts.keep_parentheses;
908  *ofst = PdtExpandFst<Arc>(ifst, parens, eopts);
909  } else {
910  PdtPrunedExpand<Arc> pruned_expand(ifst, parens, opts.keep_parentheses);
911  pruned_expand.Expand(ofst, opts.weight_threshold);
912  }
913  if (opts.connect) Connect(ofst);
914 }
915 
916 // Expands a pushdown transducer (PDT) encoded as an FST into an FST. This
917 // version writes the expanded PDT result to a mutable FST. In the PDT, some
918 // transitions are labeled with open or close parentheses. To be interpreted as
919 // a PDT, the parens must balance on a path. The open-close parenthesis label
920 // pairs are passed using the parents argument. Expansion enforces the
921 // parenthesis constraints. The PDT must be expandable as an FST.
922 template <class Arc>
923 void Expand(const Fst<Arc> &ifst,
924  const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
925  &parens, MutableFst<Arc> *ofst, bool connect = true,
926  bool keep_parentheses = false) {
927  const PdtExpandOptions<Arc> opts(connect, keep_parentheses);
928  Expand(ifst, parens, ofst, opts);
929 }
930 
931 } // namespace fst
932 
933 #endif // FST_EXTENSIONS_PDT_EXPAND_H_
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:99
typename Arc::StateId StateId
Definition: expand.h:209
const PdtStateTable< StateId, StackId > & GetStateTable() const
Definition: expand.h:174
void InitStateIterator(StateIteratorData< Arc > *data) const override
Definition: expand.h:283
typename Arc::Label Label
Definition: expand.h:49
constexpr uint64 kInitialAcyclic
Definition: properties.h:97
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
Definition: expand.h:145
ArcIterator(const PdtExpandFst< Arc > &fst, StateId s)
Definition: expand.h:276
StackId stack_id
Definition: pdt.h:125
uint64_t uint64
Definition: types.h:32
typename Store::State State
Definition: expand.h:214
Weight Final(StateId s)
Definition: expand.h:118
size_t NumOutputEpsilons(StateId s)
Definition: expand.h:140
size_t NumInputEpsilons(StateId s)
Definition: expand.h:135
PdtExpandFst(const PdtExpandFst< Arc > &fst, bool safe=false)
Definition: expand.h:231
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:82
const SymbolTable * OutputSymbols() const
Definition: fst.h:690
void SetFlags(const char *usage, int *argc, char ***argv, bool remove_flags, const char *src="")
Definition: flags.cc:46
PdtExpandFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens)
Definition: expand.h:220
typename Arc::StateId StateId
Definition: expand.h:274
typename Arc::StateId StateId
Definition: expand.h:50
PdtStack< typename Arc::StateId, typename Arc::Label > * stack
Definition: expand.h:29
ExpectationWeight< X1, X2 > Times(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
SetType
Definition: set-weight.h:37
typename PdtExpandFst< Arc >::Arc Arc
Definition: cache.h:1143
void Connect(MutableFst< Arc > *fst)
Definition: connect.h:268
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:88
typename internal::PdtBalanceData< Arc >::SetIterator SetIterator
Definition: expand.h:312
constexpr uint64 kFstProperties
Definition: properties.h:301
constexpr uint64 kCopyProperties
Definition: properties.h:138
virtual uint64 Properties() const
Definition: fst.h:666
constexpr int kNoStateId
Definition: fst.h:180
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
Definition: expand.h:241
typename Arc::Weight Weight
Definition: expand.h:879
constexpr uint64 kExpanded
Definition: properties.h:27
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
Definition: reverse.h:20
virtual uint64 Properties(uint64 mask, bool test) const =0
void ExpandState(StateId s)
Definition: expand.h:152
const PdtStack< StackId, Label > & GetStack() const
Definition: expand.h:172
PdtExpandFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, const PdtExpandFstOptions< Arc > &opts)
Definition: expand.h:225
bool Error() const
Definition: expand.h:341
typename Arc::Weight Weight
Definition: expand.h:210
constexpr uint64 kUnweighted
Definition: properties.h:87
typename Collection< ssize_t, StateId >::SetIterator SetIterator
Definition: paren.h:317
size_t NumArcs(StateId s)
Definition: expand.h:130
StateIteratorBase< Arc > * base
Definition: fst.h:351
PdtStateTable< typename Arc::StateId, typename Arc::StateId > * state_table
Definition: expand.h:30
CacheOptions(bool gc=FLAGS_fst_default_cache_gc, size_t gc_limit=FLAGS_fst_default_cache_gc_limit)
Definition: cache.h:31
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:93
PdtExpandFstOptions(const CacheOptions &opts=CacheOptions(), bool keep_parentheses=false, PdtStack< typename Arc::StateId, typename Arc::Label > *stack=nullptr, PdtStateTable< typename Arc::StateId, typename Arc::StateId > *state_table=nullptr)
Definition: expand.h:32
typename Arc::StateId StateId
Definition: expand.h:306
uint8_t uint8
Definition: types.h:29
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
Definition: cache.h:644
PdtExpandOptions(bool connect=true, bool keep_parentheses=false, Weight weight_threshold=Weight::Zero())
Definition: expand.h:885
void Expand(MutableFst< Arc > *ofst, const Weight &threshold)
Definition: expand.h:837
#define VLOG(level)
Definition: log.h:49
ExpectationWeight< X1, X2 > Plus(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
constexpr uint64 kAcceptor
Definition: properties.h:45
typename Arc::Weight Weight
Definition: expand.h:51
PdtExpandFst< Arc > * Copy(bool safe=false) const override
Definition: expand.h:235
PdtExpandFstImpl(const PdtExpandFstImpl &impl)
Definition: expand.h:88
StateIterator(const PdtExpandFst< Arc > &fst)
Definition: expand.h:265
Weight weight_threshold
Definition: expand.h:883
const PdtStateTable< StateId, StackId > & GetStateTable() const
Definition: expand.h:249
typename PdtExpandFst< Arc >::Arc Arc
Definition: cache.h:1189
virtual const SymbolTable * InputSymbols() const =0
const SymbolTable * InputSymbols() const
Definition: fst.h:688
constexpr uint64 kAcyclic
Definition: properties.h:92
constexpr uint64 kError
Definition: properties.h:33
StateId state_id
Definition: pdt.h:124
PdtExpandFstImpl(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, const PdtExpandFstOptions< Arc > &opts)
Definition: expand.h:70
const PdtStack< StackId, Label > & GetStack() const
Definition: expand.h:245
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
Definition: expand.h:302
typename Arc::Weight Weight
Definition: expand.h:307
typename CacheState< Arc >::Arc Arc
Definition: cache.h:833
Impl * GetMutableImpl() const
Definition: fst.h:947
typename Arc::Label Label
Definition: expand.h:208
uint64 Properties(uint64 mask, bool test) const override
Definition: fst.h:889
typename Arc::Label Label
Definition: expand.h:305
size_t gc_limit
Definition: cache.h:29
const Impl * GetImpl() const
Definition: fst.h:945
PdtPrunedExpand(const Fst< Arc > &ifst, const std::vector< std::pair< Label, Label >> &parens, bool keep_parentheses=false, const CacheOptions &opts=CacheOptions())
Definition: expand.h:318
virtual void SetProperties(uint64 props, uint64 mask)=0
virtual const SymbolTable * OutputSymbols() const =0