FST  openfst-1.8.2
OpenFst Library
expand.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Expands a PDT to an FST.
19 
20 #ifndef FST_EXTENSIONS_PDT_EXPAND_H_
21 #define FST_EXTENSIONS_PDT_EXPAND_H_
22 
23 #include <cstdint>
24 #include <forward_list>
25 #include <vector>
26 
27 #include <fst/log.h>
29 #include <fst/extensions/pdt/pdt.h>
32 #include <fst/cache.h>
33 #include <fst/mutable-fst.h>
34 #include <fst/queue.h>
35 #include <fst/state-table.h>
36 #include <fst/test-properties.h>
37 #include <unordered_map>
38 
39 namespace fst {
40 
41 template <class Arc>
46 
48  const CacheOptions &opts = CacheOptions(), bool keep_parentheses = false,
51  nullptr)
52  : CacheOptions(opts),
53  keep_parentheses(keep_parentheses),
54  stack(stack),
55  state_table(state_table) {}
56 };
57 
58 namespace internal {
59 
60 // Implementation class for PdtExpandFst.
61 template <class Arc>
62 class PdtExpandFstImpl : public CacheImpl<Arc> {
63  public:
64  using Label = typename Arc::Label;
65  using StateId = typename Arc::StateId;
66  using Weight = typename Arc::Weight;
67 
68  using StackId = StateId;
70 
76 
77  using CacheBaseImpl<CacheState<Arc>>::PushArc;
78  using CacheBaseImpl<CacheState<Arc>>::HasArcs;
79  using CacheBaseImpl<CacheState<Arc>>::HasFinal;
80  using CacheBaseImpl<CacheState<Arc>>::HasStart;
81  using CacheBaseImpl<CacheState<Arc>>::SetArcs;
82  using CacheBaseImpl<CacheState<Arc>>::SetFinal;
83  using CacheBaseImpl<CacheState<Arc>>::SetStart;
84 
86  const std::vector<std::pair<Label, Label>> &parens,
87  const PdtExpandFstOptions<Arc> &opts)
88  : CacheImpl<Arc>(opts),
89  fst_(fst.Copy()),
90  stack_(opts.stack ? opts.stack : new PdtStack<StateId, Label>(parens)),
91  state_table_(opts.state_table ? opts.state_table
92  : new PdtStateTable<StateId, StackId>()),
93  own_stack_(opts.stack == nullptr),
94  own_state_table_(opts.state_table == nullptr),
95  keep_parentheses_(opts.keep_parentheses) {
96  SetType("expand");
97  const auto props = fst.Properties(kFstProperties, false);
98  SetProperties(PdtExpandProperties(props), kCopyProperties);
99  SetInputSymbols(fst.InputSymbols());
100  SetOutputSymbols(fst.OutputSymbols());
101  }
102 
104  : CacheImpl<Arc>(impl),
105  fst_(impl.fst_->Copy(true)),
106  stack_(new PdtStack<StateId, Label>(*impl.stack_)),
107  state_table_(new PdtStateTable<StateId, StackId>()),
108  own_stack_(true),
109  own_state_table_(true),
110  keep_parentheses_(impl.keep_parentheses_) {
111  SetType("expand");
112  SetProperties(impl.Properties(), kCopyProperties);
113  SetInputSymbols(impl.InputSymbols());
114  SetOutputSymbols(impl.OutputSymbols());
115  }
116 
117  ~PdtExpandFstImpl() override {
118  if (own_stack_) delete stack_;
119  if (own_state_table_) delete state_table_;
120  }
121 
123  if (!HasStart()) {
124  const auto s = fst_->Start();
125  if (s == kNoStateId) return kNoStateId;
126  StateTuple tuple(s, 0);
127  const auto start = state_table_->FindState(tuple);
128  SetStart(start);
129  }
130  return CacheImpl<Arc>::Start();
131  }
132 
134  if (!HasFinal(s)) {
135  const auto &tuple = state_table_->Tuple(s);
136  const auto weight = fst_->Final(tuple.state_id);
137  if (weight != Weight::Zero() && tuple.stack_id == 0)
138  SetFinal(s, weight);
139  else
140  SetFinal(s, Weight::Zero());
141  }
142  return CacheImpl<Arc>::Final(s);
143  }
144 
145  size_t NumArcs(StateId s) {
146  if (!HasArcs(s)) ExpandState(s);
147  return CacheImpl<Arc>::NumArcs(s);
148  }
149 
151  if (!HasArcs(s)) ExpandState(s);
153  }
154 
156  if (!HasArcs(s)) ExpandState(s);
158  }
159 
161  if (!HasArcs(s)) ExpandState(s);
163  }
164 
165  // Computes the outgoing transitions from a state, creating new destination
166  // states as needed.
168  StateTuple tuple = state_table_->Tuple(s);
169  for (ArcIterator<Fst<Arc>> aiter(*fst_, tuple.state_id); !aiter.Done();
170  aiter.Next()) {
171  auto arc = aiter.Value();
172  const auto stack_id = stack_->Find(tuple.stack_id, arc.ilabel);
173  if (stack_id == -1) { // Non-matching close parenthesis.
174  continue;
175  } else if ((stack_id != tuple.stack_id) && !keep_parentheses_) {
176  // Stack push/pop.
177  arc.ilabel = 0;
178  arc.olabel = 0;
179  }
180  StateTuple ntuple(arc.nextstate, stack_id);
181  arc.nextstate = state_table_->FindState(ntuple);
182  PushArc(s, arc);
183  }
184  SetArcs(s);
185  }
186 
 // Returns a const reference to the parenthesis stack object that stack_
 // points to.
187  const PdtStack<StackId, Label> &GetStack() const { return *stack_; }
188 
190  return *state_table_;
191  }
192 
193  private:
194  // Properties for an expanded PDT.
195  inline uint64_t PdtExpandProperties(uint64_t inprops) {
196  return inprops & (kAcceptor | kAcyclic | kInitialAcyclic | kUnweighted);
197  }
198 
199  std::unique_ptr<const Fst<Arc>> fst_;
200  PdtStack<StackId, Label> *stack_;
201  PdtStateTable<StateId, StackId> *state_table_;
202  bool own_stack_;
203  bool own_state_table_;
204  bool keep_parentheses_;
205 };
206 
207 } // namespace internal
208 
209 // Expands a pushdown transducer (PDT) encoded as an FST into an FST. This
210 // version is a delayed FST. In the PDT, some transitions are labeled with open
211 // or close parentheses. To be interpreted as a PDT, the parens must balance on
212 // a path. The open-close parenthesis label pairs are passed using the parens
213 // argument. The expansion enforces the parenthesis constraints. The PDT must be
214 // expandable as an FST.
215 //
216 // This class attaches interface to implementation and handles reference
217 // counting, delegating most methods to ImplToFst.
218 template <class A>
219 class PdtExpandFst : public ImplToFst<internal::PdtExpandFstImpl<A>> {
220  public:
221  using Arc = A;
222 
223  using Label = typename Arc::Label;
224  using StateId = typename Arc::StateId;
225  using Weight = typename Arc::Weight;
226 
227  using StackId = StateId;
229  using State = typename Store::State;
231 
232  friend class ArcIterator<PdtExpandFst<Arc>>;
234 
236  const std::vector<std::pair<Label, Label>> &parens)
237  : ImplToFst<Impl>(
238  std::make_shared<Impl>(fst, parens, PdtExpandFstOptions<A>())) {}
239 
241  const std::vector<std::pair<Label, Label>> &parens,
242  const PdtExpandFstOptions<Arc> &opts)
243  : ImplToFst<Impl>(std::make_shared<Impl>(fst, parens, opts)) {}
244 
245  // See Fst<>::Copy() for doc.
 // Copy constructor: forwards fst and the safe flag to the ImplToFst<Impl>
 // base class constructor; safe defaults to false.
246  PdtExpandFst(const PdtExpandFst<Arc> &fst, bool safe = false)
247  : ImplToFst<Impl>(fst, safe) {}
248 
249  // Gets a copy of this ExpandFst. See Fst<>::Copy() for further doc.
 // The copy is allocated with new and built through the copy constructor with
 // the given safe flag. NOTE(review): presumably the caller takes ownership of
 // the returned pointer -- confirm against Fst<>::Copy().
250  PdtExpandFst<Arc> *Copy(bool safe = false) const override {
251  return new PdtExpandFst<Arc>(*this, safe);
252  }
253 
254  inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
255 
 // Initializes the arc iterator data for state s by forwarding the request to
 // the implementation object obtained from GetMutableImpl().
256  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
257  GetMutableImpl()->InitArcIterator(s, data);
258  }
259 
261  return GetImpl()->GetStack();
262  }
263 
265  return GetImpl()->GetStateTable();
266  }
267 
268  private:
271 
272  void operator=(const PdtExpandFst &) = delete;
273 };
274 
275 // Specialization for PdtExpandFst.
276 template <class Arc>
278  : public CacheStateIterator<PdtExpandFst<Arc>> {
279  public:
281  : CacheStateIterator<PdtExpandFst<Arc>>(fst, fst.GetMutableImpl()) {}
282 };
283 
284 // Specialization for PdtExpandFst.
285 template <class Arc>
287  : public CacheArcIterator<PdtExpandFst<Arc>> {
288  public:
289  using StateId = typename Arc::StateId;
290 
292  : CacheArcIterator<PdtExpandFst<Arc>>(fst.GetMutableImpl(), s) {
293  if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->ExpandState(s);
294  }
295 };
296 
297 template <class Arc>
299  StateIteratorData<Arc> *data) const {
300  data->base = std::make_unique<StateIterator<PdtExpandFst<Arc>>>(*this);
301 }
302 
303 // PrunedExpand prunes the delayed expansion of a pushdown transducer (PDT)
304 // encoded as an FST into an FST. In the PDT, some transitions are labeled with
305 // open or close parentheses. To be interpreted as a PDT, the parens must
306 // balance on a path. The open-close parenthesis label pairs are passed
307 // using the parens argument. The expansion enforces the parenthesis
308 // constraints.
309 //
310 // The algorithm works by visiting the delayed ExpandFst using a shortest-stack
311 // first queue discipline and relies on the shortest-distance information
312 // computed using a reverse shortest-path call to perform the pruning.
313 // Requires Arc::Weight is idempotent.
314 //
315 // The algorithm maintains the same state ordering between the ExpandFst being
316 // visited (efst_) and the result of pruning written into the MutableFst (ofst_)
317 // to improve readability.
318 template <class Arc>
320  public:
321  using Label = typename Arc::Label;
322  using StateId = typename Arc::StateId;
323  using Weight = typename Arc::Weight;
324  static_assert(IsIdempotent<Weight>::value, "Weight must be idempotent.");
325 
326  using StackId = StateId;
330 
331  // Constructor taking as input a PDT specified by an input FST and a vector
332  // of parentheses. The keep_parentheses argument specifies whether parentheses
333  // are replaced by epsilons or not during the expansion. The cache options are
334  // passed to the underlying ExpandFst.
336  const std::vector<std::pair<Label, Label>> &parens,
337  bool keep_parentheses = false,
338  const CacheOptions &opts = CacheOptions())
339  : ifst_(ifst.Copy()),
340  keep_parentheses_(keep_parentheses),
341  stack_(parens),
342  efst_(ifst, parens,
343  PdtExpandFstOptions<Arc>(opts, true, &stack_, &state_table_)),
344  queue_(state_table_, stack_, stack_length_, distance_, fdistance_),
345  error_(false) {
346  Reverse(*ifst_, parens, &rfst_);
347  VectorFst<Arc> path;
348  reverse_shortest_path_.reset(new PdtShortestPath<Arc, FifoQueue<StateId>>(
349  rfst_, parens,
350  PdtShortestPathOptions<Arc, FifoQueue<StateId>>(true, false)));
351  reverse_shortest_path_->ShortestPath(&path);
352  error_ = (path.Properties(kError, true) == kError);
353  balance_data_.reset(reverse_shortest_path_->GetBalanceData()->Reverse(
354  rfst_.NumStates(), 10, -1));
355  InitCloseParenMultimap(parens);
356  }
357 
 // Returns the error flag. The constructor sets it when the path produced by
 // the reverse shortest-path computation has the kError property.
358  bool Error() const { return error_; }
359 
360  // Expands and prunes the input PDT according to the provided weight
361  // threshold, writing the result into an output mutable FST.
362  void Expand(MutableFst<Arc> *ofst, const Weight &threshold);
363 
364  private:
365  static constexpr uint8_t kEnqueued = 0x01;
366  static constexpr uint8_t kExpanded = 0x02;
367  static constexpr uint8_t kSourceState = 0x04;
368 
369  // Comparison functor used by the queue:
370  //
371  // 1. States corresponding to shortest stack first, and
372  // 2. for stacks of matching length, reverse lexicographic order is used, and
373  // 3. for states with the same stack, shortest-first order is used.
374  class StackCompare {
375  public:
376  StackCompare(const StateTable &state_table, const Stack &stack,
377  const std::vector<StackId> &stack_length,
378  const std::vector<Weight> &distance,
379  const std::vector<Weight> &fdistance)
380  : state_table_(state_table),
381  stack_(stack),
382  stack_length_(stack_length),
383  distance_(distance),
384  fdistance_(fdistance) {}
385 
386  bool operator()(StateId s1, StateId s2) const {
387  auto si1 = state_table_.Tuple(s1).stack_id;
388  auto si2 = state_table_.Tuple(s2).stack_id;
389  if (stack_length_[si1] < stack_length_[si2]) return true;
390  if (stack_length_[si1] > stack_length_[si2]) return false;
391  // If stack IDs are equal, use A*.
392  if (si1 == si2) {
393  return less_(Distance(s1), Distance(s2));
394  }
395  // If lengths are equal, uses reverse lexicographic order.
396  for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) {
397  if (stack_.Top(si1) < stack_.Top(si2)) return true;
398  if (stack_.Top(si1) > stack_.Top(si2)) return false;
399  }
400  return false;
401  }
402 
403  private:
404  Weight Distance(StateId s) const {
405  return (s < distance_.size()) && (s < fdistance_.size())
406  ? Times(distance_[s], fdistance_[s])
407  : Weight::Zero();
408  }
409 
410  const StateTable &state_table_;
411  const Stack &stack_;
412  const std::vector<StackId> &stack_length_;
413  const std::vector<Weight> &distance_;
414  const std::vector<Weight> &fdistance_;
415  const NaturalLess<Weight> less_;
416  };
417 
418  // Requires Weight is idempotent.
419  class ShortestStackFirstQueue
420  : public ShortestFirstQueue<StateId, StackCompare> {
421  public:
422  ShortestStackFirstQueue(const PdtStateTable<StateId, StackId> &state_table,
423  const Stack &stack,
424  const std::vector<StackId> &stack_length,
425  const std::vector<Weight> &distance,
426  const std::vector<Weight> &fdistance)
428  state_table, stack, stack_length, distance, fdistance)) {}
429  };
430 
431  void InitCloseParenMultimap(
432  const std::vector<std::pair<Label, Label>> &parens);
433 
434  Weight DistanceToDest(StateId source, StateId dest) const;
435 
436  uint8_t Flags(StateId s) const;
437 
438  void SetFlags(StateId s, uint8_t flags, uint8_t mask);
439 
440  Weight Distance(StateId s) const;
441 
442  void SetDistance(StateId s, Weight weight);
443 
444  Weight FinalDistance(StateId s) const;
445 
446  void SetFinalDistance(StateId s, Weight weight);
447 
448  StateId SourceState(StateId s) const;
449 
450  void SetSourceState(StateId s, StateId p);
451 
452  void AddStateAndEnqueue(StateId s);
453 
454  void Relax(StateId s, const Arc &arc, Weight weight);
455 
456  bool PruneArc(StateId s, const Arc &arc);
457 
458  void ProcStart();
459 
460  void ProcFinal(StateId s);
461 
462  bool ProcNonParen(StateId s, const Arc &arc, bool add_arc);
463 
464  bool ProcOpenParen(StateId s, const Arc &arc, StackId si, StackId nsi);
465 
466  bool ProcCloseParen(StateId s, const Arc &arc);
467 
468  void ProcDestStates(StateId s, StackId si);
469 
470  // Input PDT.
471  std::unique_ptr<Fst<Arc>> ifst_;
472  // Reversed PDT.
473  VectorFst<Arc> rfst_;
474  // Keep parentheses in ofst?
475  const bool keep_parentheses_;
476  // State table for efst_.
477  StateTable state_table_;
478  // Stack trie.
479  Stack stack_;
480  // Expanded PDT.
481  PdtExpandFst<Arc> efst_;
482  // Length of stack for given stack ID.
483  std::vector<StackId> stack_length_;
484  // Distance from initial state in efst_/ofst.
485  std::vector<Weight> distance_;
486  // Distance to final states in efst_/ofst.
487  std::vector<Weight> fdistance_;
488  // Queue used to visit efst_.
489  ShortestStackFirstQueue queue_;
490  // Construction time failure?
491  bool error_;
492  // Status flags for states in efst_/ofst.
493  std::vector<uint8_t> flags_;
494  // PDT source state for each expanded state.
495  std::vector<StateId> sources_;
496  // Shortest path for rfst_.
497  std::unique_ptr<PdtShortestPath<Arc, FifoQueue<StateId>>>
498  reverse_shortest_path_;
499  std::unique_ptr<internal::PdtBalanceData<Arc>> balance_data_;
500  // Maps open paren arcs to balancing close paren arcs.
501  typename PdtShortestPath<Arc, FifoQueue<StateId>>::CloseParenMultimap
502  close_paren_multimap_;
503  MutableFst<Arc> *ofst_; // Output FST.
504  Weight limit_; // Weight limit.
505 
506  // Maps a state s in ifst (i.e., the source of a close parenthesis matching
507  // the top of current_stack_id_) to final states in efst_.
508  std::unordered_map<StateId, Weight> dest_map_;
509  // Stack ID of the states currently at the top of the queue, i.e., the states
510  // currently being popped and processed.
511  StackId current_stack_id_;
512  ssize_t current_paren_id_; // Paren ID at top of current stack.
513  ssize_t cached_stack_id_;
514  StateId cached_source_;
515  // The set of pairs of destination states and weights to final states for the
516  // source state cached_source_ and the paren ID cached_paren_id_; i.e., the
517  // set of source states of a closed parenthesis with paren ID cached_paren_id
518  // balancing an incoming open parenthesis with paren ID cached_paren_id_ in
519  // state cached_source_.
520  std::forward_list<std::pair<StateId, Weight>> cached_dest_list_;
521  NaturalLess<Weight> less_;
522 };
523 
524 // Initializes close paren multimap, mapping pairs (s, paren_id) to all the arcs
525 // out of s labeled with close parentheses for paren_id.
526 template <class Arc>
528  const std::vector<std::pair<Label, Label>> &parens) {
529  std::unordered_map<Label, Label> paren_map;
530  for (size_t i = 0; i < parens.size(); ++i) {
531  const auto &pair = parens[i];
532  paren_map[pair.first] = i;
533  paren_map[pair.second] = i;
534  }
535  for (StateIterator<Fst<Arc>> siter(*ifst_); !siter.Done(); siter.Next()) {
536  const auto s = siter.Value();
537  for (ArcIterator<Fst<Arc>> aiter(*ifst_, s); !aiter.Done(); aiter.Next()) {
538  const auto &arc = aiter.Value();
539  const auto it = paren_map.find(arc.ilabel);
540  if (it == paren_map.end()) continue;
541  if (arc.ilabel == parens[it->second].second) { // Close paren.
542  const internal::ParenState<Arc> key(it->second, s);
543  close_paren_multimap_.emplace(key, arc);
544  }
545  }
546  }
547 }
548 
549 // Returns the weight of the shortest balanced path from source to dest
550 // in ifst_; dest must be the source state of a close paren arc.
551 template <class Arc>
552 typename Arc::Weight PdtPrunedExpand<Arc>::DistanceToDest(StateId source,
553  StateId dest) const {
554  using SearchState =
555  typename PdtShortestPath<Arc, FifoQueue<StateId>>::SearchState;
556  const SearchState ss(source + 1, dest + 1);
557  const auto distance =
558  reverse_shortest_path_->GetShortestPathData().Distance(ss);
559  VLOG(2) << "D(" << source << ", " << dest << ") =" << distance;
560  return distance;
561 }
562 
563 // Returns the flags for state s in ofst_.
564 template <class Arc>
565 uint8_t PdtPrunedExpand<Arc>::Flags(StateId s) const {
566  return s < flags_.size() ? flags_[s] : 0;
567 }
568 
569 // Modifies the flags for state s in ofst_.
570 template <class Arc>
571 void PdtPrunedExpand<Arc>::SetFlags(StateId s, uint8_t flags, uint8_t mask) {
572  while (flags_.size() <= s) flags_.push_back(0);
573  flags_[s] &= ~mask;
574  flags_[s] |= flags & mask;
575 }
576 
577 // Returns the shortest distance from the initial state to s in ofst_.
578 template <class Arc>
579 typename Arc::Weight PdtPrunedExpand<Arc>::Distance(StateId s) const {
580  return s < distance_.size() ? distance_[s] : Weight::Zero();
581 }
582 
583 // Sets the shortest distance from the initial state to s in ofst_.
584 template <class Arc>
586  while (distance_.size() <= s) distance_.push_back(Weight::Zero());
587  distance_[s] = std::move(weight);
588 }
589 
590 // Returns the shortest distance from s to the final states in ofst_.
591 template <class Arc>
592 typename Arc::Weight PdtPrunedExpand<Arc>::FinalDistance(StateId s) const {
593  return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
594 }
595 
596 // Sets the shortest distance from s to the final states in ofst_.
597 template <class Arc>
599  while (fdistance_.size() <= s) fdistance_.push_back(Weight::Zero());
600  fdistance_[s] = std::move(weight);
601 }
602 
603 // Returns the PDT source state of state s in ofst_.
604 template <class Arc>
605 typename Arc::StateId PdtPrunedExpand<Arc>::SourceState(StateId s) const {
606  return s < sources_.size() ? sources_[s] : kNoStateId;
607 }
608 
609 // Sets the PDT source state of state s in ofst_ to state p in ifst_.
610 template <class Arc>
612  while (sources_.size() <= s) sources_.push_back(kNoStateId);
613  sources_[s] = p;
614 }
615 
616 // Adds state s of efst_ to ofst_ and inserts it in the queue, modifying the
617 // flags for s accordingly.
618 template <class Arc>
620  if (!(Flags(s) & (kEnqueued | kExpanded))) {
621  while (ofst_->NumStates() <= s) ofst_->AddState();
622  queue_.Enqueue(s);
623  SetFlags(s, kEnqueued, kEnqueued);
624  } else if (Flags(s) & kEnqueued) {
625  queue_.Update(s);
626  }
627  // TODO(allauzen): Check everything is fine when kExpanded?
628 }
629 
630 // Relaxes arc out of state s in ofst_ as follows:
631 //
632 // 1. If the distance to s times the weight of arc is smaller than
633 // the currently stored distance for arc.nextstate, updates
634 // Distance(arc.nextstate) with a new estimate
635 // 2. If fd is less than the currently stored distance from arc.nextstate to the
636 // final state, updates with new estimate.
637 template <class Arc>
638 void PdtPrunedExpand<Arc>::Relax(StateId s, const Arc &arc, Weight fd) {
639  const auto nd = Times(Distance(s), arc.weight);
640  if (less_(nd, Distance(arc.nextstate))) {
641  SetDistance(arc.nextstate, nd);
642  SetSourceState(arc.nextstate, SourceState(s));
643  }
644  if (less_(fd, FinalDistance(arc.nextstate))) {
645  SetFinalDistance(arc.nextstate, fd);
646  }
647  VLOG(2) << "Relax: " << s << ", d[s] = " << Distance(s) << ", to "
648  << arc.nextstate << ", d[ns] = " << Distance(arc.nextstate)
649  << ", nd = " << nd;
650 }
651 
652 // Returns whether the arc out of state s in efst needs pruned.
653 template <class Arc>
654 bool PdtPrunedExpand<Arc>::PruneArc(StateId s, const Arc &arc) {
655  VLOG(2) << "Prune ?";
656  auto fd = Weight::Zero();
657  if ((cached_source_ != SourceState(s)) ||
658  (cached_stack_id_ != current_stack_id_)) {
659  cached_source_ = SourceState(s);
660  cached_stack_id_ = current_stack_id_;
661  cached_dest_list_.clear();
662  if (cached_source_ != ifst_->Start()) {
663  for (auto set_iter =
664  balance_data_->Find(current_paren_id_, cached_source_);
665  !set_iter.Done(); set_iter.Next()) {
666  auto dest = set_iter.Element();
667  const auto it = dest_map_.find(dest);
668  cached_dest_list_.push_front(*it);
669  }
670  } else {
671  // TODO(allauzen): queue discipline should prevent this from ever
672  // happening.
673  // Replace by a check.
674  cached_dest_list_.push_front(
675  std::make_pair(rfst_.Start() - 1, Weight::One()));
676  }
677  }
678  for (auto it = cached_dest_list_.begin(); it != cached_dest_list_.end();
679  ++it) {
680  const auto d =
681  DistanceToDest(state_table_.Tuple(arc.nextstate).state_id, it->first);
682  fd = Plus(fd, Times(d, it->second));
683  }
684  Relax(s, arc, fd);
685  return less_(limit_, Times(Distance(s), Times(arc.weight, fd)));
686 }
687 
688 // Adds start state of efst_ to ofst_, enqueues it, and initializes the distance
689 // data structures.
690 template <class Arc>
692  const auto s = efst_.Start();
693  AddStateAndEnqueue(s);
694  ofst_->SetStart(s);
695  SetSourceState(s, ifst_->Start());
696  current_stack_id_ = 0;
697  current_paren_id_ = -1;
698  stack_length_.push_back(0);
699  const auto r = rfst_.Start() - 1;
700  cached_source_ = ifst_->Start();
701  cached_stack_id_ = 0;
702  cached_dest_list_.push_front(std::make_pair(r, Weight::One()));
703  const PdtStateTuple<StateId, StackId> tuple(r, 0);
704  SetFinalDistance(state_table_.FindState(tuple), Weight::One());
705  SetDistance(s, Weight::One());
706  const auto d = DistanceToDest(ifst_->Start(), r);
707  SetFinalDistance(s, d);
708  VLOG(2) << d;
709 }
710 
711 // Makes s final in ofst_ if shortest accepting path ending in s is below
712 // threshold.
713 template <class Arc>
715  const auto weight = efst_.Final(s);
716  if (weight == Weight::Zero()) return;
717  if (less_(limit_, Times(Distance(s), weight))) return;
718  ofst_->SetFinal(s, weight);
719 }
720 
721 // Returns true when an arc (or meta-arc) leaving state s in efst_ is below the
722 // threshold. When add_arc is true, arc is added to ofst_.
723 template <class Arc>
724 bool PdtPrunedExpand<Arc>::ProcNonParen(StateId s, const Arc &arc,
725  bool add_arc) {
726  VLOG(2) << "ProcNonParen: " << s << " to " << arc.nextstate << ", "
727  << arc.ilabel << ":" << arc.olabel << " / " << arc.weight
728  << ", add_arc = " << (add_arc ? "true" : "false");
729  if (PruneArc(s, arc)) return false;
730  if (add_arc) ofst_->AddArc(s, arc);
731  AddStateAndEnqueue(arc.nextstate);
732  return true;
733 }
734 
735 // Processes an open paren arc leaving state s in ofst_. When the arc is labeled
736 // with an open paren,
737 //
738 // 1. Considers each (shortest) balanced path starting in s by taking the arc
739 // and ending by a close paren balancing the open paren of as a meta-arc,
740 // processing and pruning each meta-arc as a non-paren arc, inserting its
741 // destination to the queue;
742 // 2. if at least one of these meta-arcs has not been pruned, adds the
743 // destination of arc to ofst_ as a new source state for the stack ID nsi, and
744 // inserts it in the queue.
745 template <class Arc>
746 bool PdtPrunedExpand<Arc>::ProcOpenParen(StateId s, const Arc &arc, StackId si,
747  StackId nsi) {
748  // Updates the stack length when needed.
749  while (stack_length_.size() <= nsi) stack_length_.push_back(-1);
750  if (stack_length_[nsi] == -1) stack_length_[nsi] = stack_length_[si] + 1;
751  const auto ns = arc.nextstate;
752  VLOG(2) << "Open paren: " << s << "(" << state_table_.Tuple(s).state_id
753  << ") to " << ns << "(" << state_table_.Tuple(ns).state_id << ")";
754  bool proc_arc = false;
755  auto fd = Weight::Zero();
756  const auto paren_id = stack_.ParenId(arc.ilabel);
757  std::forward_list<StateId> sources;
758  for (auto set_iter =
759  balance_data_->Find(paren_id, state_table_.Tuple(ns).state_id);
760  !set_iter.Done(); set_iter.Next()) {
761  sources.push_front(set_iter.Element());
762  }
763  for (const auto source : sources) {
764  VLOG(2) << "Close paren source: " << source;
765  const internal::ParenState<Arc> paren_state(paren_id, source);
766  for (auto it = close_paren_multimap_.find(paren_state);
767  it != close_paren_multimap_.end() && paren_state == it->first; ++it) {
768  auto meta_arc = it->second;
769  const PdtStateTuple<StateId, StackId> tuple(meta_arc.nextstate, si);
770  meta_arc.nextstate = state_table_.FindState(tuple);
771  const auto state_id = state_table_.Tuple(ns).state_id;
772  const auto d = DistanceToDest(state_id, source);
773  VLOG(2) << state_id << ", " << source;
774  VLOG(2) << "Meta arc weight = " << arc.weight << " Times " << d
775  << " Times " << meta_arc.weight;
776  meta_arc.weight = Times(arc.weight, Times(d, meta_arc.weight));
777  proc_arc |= ProcNonParen(s, meta_arc, false);
778  fd = Plus(
779  fd,
780  Times(Times(DistanceToDest(state_table_.Tuple(ns).state_id, source),
781  it->second.weight),
782  FinalDistance(meta_arc.nextstate)));
783  }
784  }
785  if (proc_arc) {
786  VLOG(2) << "Proc open paren " << s << " to " << arc.nextstate;
787  ofst_->AddArc(
788  s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
789  AddStateAndEnqueue(arc.nextstate);
790  const auto nd = Times(Distance(s), arc.weight);
791  if (less_(nd, Distance(arc.nextstate))) SetDistance(arc.nextstate, nd);
792  // FinalDistance not necessary for source state since pruning decided using
793  // meta-arcs above. But this is a problem with A*, hence the following.
794  if (less_(fd, FinalDistance(arc.nextstate)))
795  SetFinalDistance(arc.nextstate, fd);
796  SetFlags(arc.nextstate, kSourceState, kSourceState);
797  }
798  return proc_arc;
799 }
800 
801 // Checks that shortest path through close paren arc in efst_ is below
802 // threshold, and if so, adds it to ofst_.
803 template <class Arc>
804 bool PdtPrunedExpand<Arc>::ProcCloseParen(StateId s, const Arc &arc) {
805  const auto weight =
806  Times(Distance(s), Times(arc.weight, FinalDistance(arc.nextstate)));
807  if (less_(limit_, weight)) return false;
808  ofst_->AddArc(s,
809  keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
810  return true;
811 }
812 
813 // When state s in ofst_ is a source state for stack ID si, identifies all the
814 // corresponding possible destination states, that is, all the states in ifst_
815 // that have an outgoing close paren arc balancing the incoming open paren taken
816 // to get to s. For each such state t, computes the shortest distance from (t,
817 // si) to the final states in ofst_. Stores this information in dest_map_.
818 template <class Arc>
820  if (!(Flags(s) & kSourceState)) return;
821  if (si != current_stack_id_) {
822  dest_map_.clear();
823  current_stack_id_ = si;
824  current_paren_id_ = stack_.Top(current_stack_id_);
825  VLOG(2) << "StackID " << si << " dequeued for first time";
826  }
827  // TODO(allauzen): clean up source state business; rename current function to
828  // ProcSourceState.
829  SetSourceState(s, state_table_.Tuple(s).state_id);
830  const auto paren_id = stack_.Top(si);
831  for (auto set_iter =
832  balance_data_->Find(paren_id, state_table_.Tuple(s).state_id);
833  !set_iter.Done(); set_iter.Next()) {
834  const auto dest_state = set_iter.Element();
835  if (dest_map_.find(dest_state) != dest_map_.end()) continue;
836  auto dest_weight = Weight::Zero();
837  internal::ParenState<Arc> paren_state(paren_id, dest_state);
838  for (auto it = close_paren_multimap_.find(paren_state);
839  it != close_paren_multimap_.end() && paren_state == it->first; ++it) {
840  const auto &arc = it->second;
841  const PdtStateTuple<StateId, StackId> tuple(arc.nextstate,
842  stack_.Pop(si));
843  dest_weight =
844  Plus(dest_weight,
845  Times(arc.weight, FinalDistance(state_table_.FindState(tuple))));
846  }
847  dest_map_[dest_state] = dest_weight;
848  VLOG(2) << "State " << dest_state << " is a dest state for stack ID " << si
849  << " with weight " << dest_weight;
850  }
851 }
852 
853 // Expands and prunes the input PDT, writing the result in ofst.
854 template <class Arc>
856  const typename Arc::Weight &threshold) {
857  ofst_ = ofst;
858  if (error_) {
859  ofst_->SetProperties(kError, kError);
860  return;
861  }
862  ofst_->DeleteStates();
863  ofst_->SetInputSymbols(ifst_->InputSymbols());
864  ofst_->SetOutputSymbols(ifst_->OutputSymbols());
865  limit_ = Times(DistanceToDest(ifst_->Start(), rfst_.Start() - 1), threshold);
866  flags_.clear();
867  ProcStart();
868  while (!queue_.Empty()) {
869  const auto s = queue_.Head();
870  queue_.Dequeue();
871  SetFlags(s, kExpanded, kExpanded | kEnqueued);
872  VLOG(2) << s << " dequeued!";
873  ProcFinal(s);
874  StackId stack_id = state_table_.Tuple(s).stack_id;
875  ProcDestStates(s, stack_id);
876  for (ArcIterator<PdtExpandFst<Arc>> aiter(efst_, s); !aiter.Done();
877  aiter.Next()) {
878  const auto &arc = aiter.Value();
879  const auto nextstack_id = state_table_.Tuple(arc.nextstate).stack_id;
880  if (stack_id == nextstack_id) {
881  ProcNonParen(s, arc, true);
882  } else if (stack_id == stack_.Pop(nextstack_id)) {
883  ProcOpenParen(s, arc, stack_id, nextstack_id);
884  } else {
885  ProcCloseParen(s, arc);
886  }
887  }
888  VLOG(2) << "d[" << s << "] = " << Distance(s) << ", fd[" << s
889  << "] = " << FinalDistance(s);
890  }
891 }
892 
893 // Expand functions.
894 
895 template <class Arc>
897  using Weight = typename Arc::Weight;
898 
899  bool connect;
902 
903  explicit PdtExpandOptions(bool connect = true, bool keep_parentheses = false,
904  Weight weight_threshold = Weight::Zero())
905  : connect(connect),
906  keep_parentheses(keep_parentheses),
907  weight_threshold(std::move(weight_threshold)) {}
908 };
909 
910 // Expands a pushdown transducer (PDT) encoded as an FST into an FST. This
911 // version writes the expanded PDT to a mutable FST. In the PDT, some
912 // transitions are labeled with open or close parentheses. To be interpreted as
913 // a PDT, the parens must balance on a path. The open-close parenthesis label
914 // pairs are passed using the parens argument. Expansion enforces the
915 // parenthesis constraints. The PDT must be expandable as an FST.
916 template <class Arc>
917 void Expand(
918  const Fst<Arc> &ifst,
919  const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
920  &parens,
921  MutableFst<Arc> *ofst, const PdtExpandOptions<Arc> &opts) {
922  using Weight = typename Arc::Weight;
924  eopts.gc_limit = 0;
925  if (opts.weight_threshold == Arc::Weight::Zero()) {
926  eopts.keep_parentheses = opts.keep_parentheses;
927  *ofst = PdtExpandFst<Arc>(ifst, parens, eopts);
928  } else if constexpr (IsIdempotent<Weight>::value) {
929  PdtPrunedExpand<Arc> pruned_expand(ifst, parens, opts.keep_parentheses);
930  pruned_expand.Expand(ofst, opts.weight_threshold);
931  } else {
932  FSTERROR() << "Expand: non-Zero weight_threshold with non-idempotent"
933  << " Weight " << Weight::Type();
934  ofst->SetProperties(kError, kError);
935  }
936  if (opts.connect) Connect(ofst);
937 }
938 
939 // Expands a pushdown transducer (PDT) encoded as an FST into an FST. This
940 // version writes the expanded PDT result to a mutable FST. In the PDT, some
941 // transitions are labeled with open or close parentheses. To be interpreted as
942 // a PDT, the parens must balance on a path. The open-close parenthesis label
943 // pairs are passed using the parents argument. Expansion enforces the
944 // parenthesis constraints. The PDT must be expandable as an FST.
945 template <class Arc>
946 void Expand(
947  const Fst<Arc> &ifst,
948  const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
949  &parens,
950  MutableFst<Arc> *ofst, bool connect = true, bool keep_parentheses = false) {
951  const PdtExpandOptions<Arc> opts(connect, keep_parentheses);
952  Expand(ifst, parens, ofst, opts);
953 }
954 
955 } // namespace fst
956 
957 #endif // FST_EXTENSIONS_PDT_EXPAND_H_
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:114
typename Arc::StateId StateId
Definition: expand.h:224
const PdtStateTable< StateId, StackId > & GetStateTable() const
Definition: expand.h:189
void InitStateIterator(StateIteratorData< Arc > *data) const override
Definition: expand.h:298
typename Arc::Label Label
Definition: expand.h:64
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
Definition: expand.h:160
ArcIterator(const PdtExpandFst< Arc > &fst, StateId s)
Definition: expand.h:291
CacheOptions(bool gc=FST_FLAGS_fst_default_cache_gc, size_t gc_limit=FST_FLAGS_fst_default_cache_gc_limit)
Definition: cache.h:47
virtual uint64_t Properties(uint64_t mask, bool test) const =0
StackId stack_id
Definition: pdt.h:139
typename Store::State State
Definition: expand.h:229
Weight Final(StateId s)
Definition: expand.h:133
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:60
size_t NumOutputEpsilons(StateId s)
Definition: expand.h:155
size_t NumInputEpsilons(StateId s)
Definition: expand.h:150
PdtExpandFst(const PdtExpandFst< Arc > &fst, bool safe=false)
Definition: expand.h:246
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:97
const SymbolTable * OutputSymbols() const
Definition: fst.h:762
void SetFlags(const char *usage, int *argc, char ***argv, bool remove_flags, const char *src="")
Definition: flags.cc:56
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:63
PdtExpandFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens)
Definition: expand.h:235
constexpr uint64_t kError
Definition: properties.h:51
typename Arc::StateId StateId
Definition: expand.h:289
constexpr uint64_t kInitialAcyclic
Definition: properties.h:115
typename Arc::StateId StateId
Definition: expand.h:65
PdtStack< typename Arc::StateId, typename Arc::Label > * stack
Definition: expand.h:44
SetType
Definition: set-weight.h:52
typename PdtExpandFst< Arc >::Arc Arc
Definition: cache.h:1149
void Connect(MutableFst< Arc > *fst)
Definition: connect.h:278
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:103
std::bool_constant<(W::Properties()&kIdempotent)!=0 > IsIdempotent
Definition: weight.h:156
typename internal::PdtBalanceData< Arc >::SetIterator SetIterator
Definition: expand.h:329
constexpr int kNoStateId
Definition: fst.h:202
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
Definition: expand.h:256
typename Arc::Weight Weight
Definition: expand.h:897
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
Definition: reverse.h:34
void ExpandState(StateId s)
Definition: expand.h:167
const PdtStack< StackId, Label > & GetStack() const
Definition: expand.h:187
PdtExpandFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, const PdtExpandFstOptions< Arc > &opts)
Definition: expand.h:240
bool Error() const
Definition: expand.h:358
#define FSTERROR()
Definition: util.h:53
typename Arc::Weight Weight
Definition: expand.h:225
typename Collection< ssize_t, StateId >::SetIterator SetIterator
Definition: paren.h:354
size_t NumArcs(StateId s)
Definition: expand.h:145
PdtStateTable< typename Arc::StateId, typename Arc::StateId > * state_table
Definition: expand.h:45
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:108
PdtExpandFstOptions(const CacheOptions &opts=CacheOptions(), bool keep_parentheses=false, PdtStack< typename Arc::StateId, typename Arc::Label > *stack=nullptr, PdtStateTable< typename Arc::StateId, typename Arc::StateId > *state_table=nullptr)
Definition: expand.h:47
virtual uint64_t Properties() const
Definition: fst.h:702
typename Arc::StateId StateId
Definition: expand.h:322
constexpr uint64_t kCopyProperties
Definition: properties.h:162
constexpr uint64_t kAcyclic
Definition: properties.h:110
virtual void SetProperties(uint64_t props, uint64_t mask)=0
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
Definition: cache.h:666
PdtExpandOptions(bool connect=true, bool keep_parentheses=false, Weight weight_threshold=Weight::Zero())
Definition: expand.h:903
void Expand(MutableFst< Arc > *ofst, const Weight &threshold)
Definition: expand.h:855
#define VLOG(level)
Definition: log.h:50
std::unique_ptr< StateIteratorBase< Arc > > base
Definition: fst.h:382
typename Arc::Weight Weight
Definition: expand.h:66
PdtExpandFst< Arc > * Copy(bool safe=false) const override
Definition: expand.h:250
PdtExpandFstImpl(const PdtExpandFstImpl &impl)
Definition: expand.h:103
StateIterator(const PdtExpandFst< Arc > &fst)
Definition: expand.h:280
Weight weight_threshold
Definition: expand.h:901
constexpr uint64_t kFstProperties
Definition: properties.h:325
constexpr uint64_t kUnweighted
Definition: properties.h:105
const PdtStateTable< StateId, StackId > & GetStateTable() const
Definition: expand.h:264
typename PdtExpandFst< Arc >::Arc Arc
Definition: cache.h:1195
virtual const SymbolTable * InputSymbols() const =0
const SymbolTable * InputSymbols() const
Definition: fst.h:760
StateId state_id
Definition: pdt.h:138
PdtExpandFstImpl(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, const PdtExpandFstOptions< Arc > &opts)
Definition: expand.h:85
const PdtStack< StackId, Label > & GetStack() const
Definition: expand.h:260
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
Definition: expand.h:316
typename Arc::Weight Weight
Definition: expand.h:323
typename CacheState< Arc >::Arc Arc
Definition: cache.h:852
Impl * GetMutableImpl() const
Definition: fst.h:1028
typename Arc::Label Label
Definition: expand.h:223
constexpr uint64_t kExpanded
Definition: properties.h:45
typename Arc::Label Label
Definition: expand.h:321
size_t gc_limit
Definition: cache.h:45
uint64_t Properties(uint64_t mask, bool test) const override
Definition: fst.h:966
const Impl * GetImpl() const
Definition: fst.h:1026
PdtPrunedExpand(const Fst< Arc > &ifst, const std::vector< std::pair< Label, Label >> &parens, bool keep_parentheses=false, const CacheOptions &opts=CacheOptions())
Definition: expand.h:335
constexpr uint64_t kAcceptor
Definition: properties.h:63
virtual const SymbolTable * OutputSymbols() const =0