FST  openfst-1.8.3
OpenFst Library
expand.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Expands a PDT to an FST.
19 
20 #ifndef FST_EXTENSIONS_PDT_EXPAND_H_
21 #define FST_EXTENSIONS_PDT_EXPAND_H_
22 
23 #include <sys/types.h>
24 
25 #include <cstddef>
26 #include <cstdint>
27 #include <forward_list>
28 #include <memory>
29 #include <unordered_map>
30 #include <utility>
31 #include <vector>
32 
33 #include <fst/log.h>
35 #include <fst/extensions/pdt/pdt.h>
38 #include <fst/cache.h>
39 #include <fst/connect.h>
40 #include <fst/fst.h>
41 #include <fst/impl-to-fst.h>
42 #include <fst/mutable-fst.h>
43 #include <fst/properties.h>
44 #include <fst/queue.h>
45 #include <fst/state-table.h>
46 #include <fst/util.h>
47 #include <fst/vector-fst.h>
48 #include <fst/weight.h>
49 #include <unordered_map>
50 
51 namespace fst {
52 
53 template <class Arc>
58 
60  const CacheOptions &opts = CacheOptions(), bool keep_parentheses = false,
63  nullptr)
64  : CacheOptions(opts),
65  keep_parentheses(keep_parentheses),
66  stack(stack),
67  state_table(state_table) {}
68 };
69 
70 namespace internal {
71 
72 // Implementation class for PdtExpandFst.
73 template <class Arc>
74 class PdtExpandFstImpl : public CacheImpl<Arc> {
75  public:
76  using Label = typename Arc::Label;
77  using StateId = typename Arc::StateId;
78  using Weight = typename Arc::Weight;
79 
80  using StackId = StateId;
82 
88 
89  using CacheBaseImpl<CacheState<Arc>>::PushArc;
90  using CacheBaseImpl<CacheState<Arc>>::HasArcs;
91  using CacheBaseImpl<CacheState<Arc>>::HasFinal;
92  using CacheBaseImpl<CacheState<Arc>>::HasStart;
93  using CacheBaseImpl<CacheState<Arc>>::SetArcs;
94  using CacheBaseImpl<CacheState<Arc>>::SetFinal;
95  using CacheBaseImpl<CacheState<Arc>>::SetStart;
96 
98  const std::vector<std::pair<Label, Label>> &parens,
99  const PdtExpandFstOptions<Arc> &opts)
100  : CacheImpl<Arc>(opts),
101  fst_(fst.Copy()),
102  stack_(opts.stack ? opts.stack : new PdtStack<StateId, Label>(parens)),
103  state_table_(opts.state_table ? opts.state_table
104  : new PdtStateTable<StateId, StackId>()),
105  own_stack_(opts.stack == nullptr),
106  own_state_table_(opts.state_table == nullptr),
107  keep_parentheses_(opts.keep_parentheses) {
108  SetType("expand");
109  const auto props = fst.Properties(kFstProperties, false);
110  SetProperties(PdtExpandProperties(props), kCopyProperties);
111  SetInputSymbols(fst.InputSymbols());
112  SetOutputSymbols(fst.OutputSymbols());
113  }
114 
116  : CacheImpl<Arc>(impl),
117  fst_(impl.fst_->Copy(true)),
118  stack_(new PdtStack<StateId, Label>(*impl.stack_)),
119  state_table_(new PdtStateTable<StateId, StackId>()),
120  own_stack_(true),
121  own_state_table_(true),
122  keep_parentheses_(impl.keep_parentheses_) {
123  SetType("expand");
124  SetProperties(impl.Properties(), kCopyProperties);
125  SetInputSymbols(impl.InputSymbols());
126  SetOutputSymbols(impl.OutputSymbols());
127  }
128 
129  ~PdtExpandFstImpl() override {
130  if (own_stack_) delete stack_;
131  if (own_state_table_) delete state_table_;
132  }
133 
135  if (!HasStart()) {
136  const auto s = fst_->Start();
137  if (s == kNoStateId) return kNoStateId;
138  StateTuple tuple(s, 0);
139  const auto start = state_table_->FindState(tuple);
140  SetStart(start);
141  }
142  return CacheImpl<Arc>::Start();
143  }
144 
146  if (!HasFinal(s)) {
147  const auto &tuple = state_table_->Tuple(s);
148  const auto weight = fst_->Final(tuple.state_id);
149  if (weight != Weight::Zero() && tuple.stack_id == 0)
150  SetFinal(s, weight);
151  else
152  SetFinal(s, Weight::Zero());
153  }
154  return CacheImpl<Arc>::Final(s);
155  }
156 
157  size_t NumArcs(StateId s) {
158  if (!HasArcs(s)) ExpandState(s);
159  return CacheImpl<Arc>::NumArcs(s);
160  }
161 
163  if (!HasArcs(s)) ExpandState(s);
165  }
166 
168  if (!HasArcs(s)) ExpandState(s);
170  }
171 
173  if (!HasArcs(s)) ExpandState(s);
175  }
176 
177  // Computes the outgoing transitions from a state, creating new destination
178  // states as needed.
180  StateTuple tuple = state_table_->Tuple(s);
181  for (ArcIterator<Fst<Arc>> aiter(*fst_, tuple.state_id); !aiter.Done();
182  aiter.Next()) {
183  auto arc = aiter.Value();
184  const auto stack_id = stack_->Find(tuple.stack_id, arc.ilabel);
185  if (stack_id == -1) { // Non-matching close parenthesis.
186  continue;
187  } else if ((stack_id != tuple.stack_id) && !keep_parentheses_) {
188  // Stack push/pop.
189  arc.ilabel = 0;
190  arc.olabel = 0;
191  }
192  StateTuple ntuple(arc.nextstate, stack_id);
193  arc.nextstate = state_table_->FindState(ntuple);
194  PushArc(s, arc);
195  }
196  SetArcs(s);
197  }
198 
199  const PdtStack<StackId, Label> &GetStack() const { return *stack_; }
200 
202  return *state_table_;
203  }
204 
205  private:
206  // Properties for an expanded PDT.
207  inline uint64_t PdtExpandProperties(uint64_t inprops) {
208  return inprops & (kAcceptor | kAcyclic | kInitialAcyclic | kUnweighted);
209  }
210 
211  std::unique_ptr<const Fst<Arc>> fst_;
212  PdtStack<StackId, Label> *stack_;
213  PdtStateTable<StateId, StackId> *state_table_;
214  bool own_stack_;
215  bool own_state_table_;
216  bool keep_parentheses_;
217 };
218 
219 } // namespace internal
220 
221 // Expands a pushdown transducer (PDT) encoded as an FST into an FST. This
222 // version is a delayed FST. In the PDT, some transitions are labeled with open
223 // or close parentheses. To be interpreted as a PDT, the parens must balance on
224 // a path. The open-close parenthesis label pairs are passed using the parens
225 // argument. The expansion enforces the parenthesis constraints. The PDT must be
226 // expandable as an FST.
227 //
228 // This class attaches interface to implementation and handles reference
229 // counting, delegating most methods to ImplToFst.
230 template <class A>
231 class PdtExpandFst : public ImplToFst<internal::PdtExpandFstImpl<A>> {
232  public:
233  using Arc = A;
234 
235  using Label = typename Arc::Label;
236  using StateId = typename Arc::StateId;
237  using Weight = typename Arc::Weight;
238 
239  using StackId = StateId;
241  using State = typename Store::State;
243 
244  friend class ArcIterator<PdtExpandFst<Arc>>;
246 
248  const std::vector<std::pair<Label, Label>> &parens)
249  : ImplToFst<Impl>(
250  std::make_shared<Impl>(fst, parens, PdtExpandFstOptions<A>())) {}
251 
253  const std::vector<std::pair<Label, Label>> &parens,
254  const PdtExpandFstOptions<Arc> &opts)
255  : ImplToFst<Impl>(std::make_shared<Impl>(fst, parens, opts)) {}
256 
257  // See Fst<>::Copy() for doc.
258  PdtExpandFst(const PdtExpandFst<Arc> &fst, bool safe = false)
259  : ImplToFst<Impl>(fst, safe) {}
260 
261  // Gets a copy of this ExpandFst. See Fst<>::Copy() for further doc.
262  PdtExpandFst<Arc> *Copy(bool safe = false) const override {
263  return new PdtExpandFst<Arc>(*this, safe);
264  }
265 
266  inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
267 
268  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
269  GetMutableImpl()->InitArcIterator(s, data);
270  }
271 
273  return GetImpl()->GetStack();
274  }
275 
277  return GetImpl()->GetStateTable();
278  }
279 
280  private:
283 
284  void operator=(const PdtExpandFst &) = delete;
285 };
286 
287 // Specialization for PdtExpandFst.
288 template <class Arc>
290  : public CacheStateIterator<PdtExpandFst<Arc>> {
291  public:
293  : CacheStateIterator<PdtExpandFst<Arc>>(fst, fst.GetMutableImpl()) {}
294 };
295 
296 // Specialization for PdtExpandFst.
297 template <class Arc>
299  : public CacheArcIterator<PdtExpandFst<Arc>> {
300  public:
301  using StateId = typename Arc::StateId;
302 
304  : CacheArcIterator<PdtExpandFst<Arc>>(fst.GetMutableImpl(), s) {
305  if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->ExpandState(s);
306  }
307 };
308 
309 template <class Arc>
311  StateIteratorData<Arc> *data) const {
312  data->base = std::make_unique<StateIterator<PdtExpandFst<Arc>>>(*this);
313 }
314 
315 // PrunedExpand prunes the delayed expansion of a pushdown transducer (PDT)
316 // encoded as an FST into an FST. In the PDT, some transitions are labeled with
317 // open or close parentheses. To be interpreted as a PDT, the parens must
318 // balance on a path. The open-close parenthesis label pairs are passed
319 // using the parens argument. The expansion enforces the parenthesis
320 // constraints.
321 //
322 // The algorithm works by visiting the delayed ExpandFst using a shortest-stack
323 // first queue discipline and relies on the shortest-distance information
324 // computed using a reverse shortest-path call to perform the pruning.
325 // Requires Arc::Weight is idempotent.
326 //
327 // The algorithm maintains the same state ordering between the ExpandFst being
328 // visited (efst_) and the result of pruning written into the MutableFst (ofst_)
329 // to improve readability.
330 template <class Arc>
332  public:
333  using Label = typename Arc::Label;
334  using StateId = typename Arc::StateId;
335  using Weight = typename Arc::Weight;
336  static_assert(IsIdempotent<Weight>::value, "Weight must be idempotent.");
337 
338  using StackId = StateId;
342 
  // Constructor taking as input a PDT specified by an input FST and a vector
344  // of parentheses. The keep_parentheses argument specifies whether parentheses
345  // are replaced by epsilons or not during the expansion. The cache options are
346  // passed to the underlying ExpandFst.
348  const std::vector<std::pair<Label, Label>> &parens,
349  bool keep_parentheses = false,
350  const CacheOptions &opts = CacheOptions())
351  : ifst_(ifst.Copy()),
352  keep_parentheses_(keep_parentheses),
353  stack_(parens),
354  efst_(ifst, parens,
355  PdtExpandFstOptions<Arc>(opts, true, &stack_, &state_table_)),
356  queue_(state_table_, stack_, stack_length_, distance_, fdistance_),
357  error_(false) {
358  Reverse(*ifst_, parens, &rfst_);
359  VectorFst<Arc> path;
360  reverse_shortest_path_.reset(new PdtShortestPath<Arc, FifoQueue<StateId>>(
361  rfst_, parens,
362  PdtShortestPathOptions<Arc, FifoQueue<StateId>>(true, false)));
363  reverse_shortest_path_->ShortestPath(&path);
364  error_ = (path.Properties(kError, true) == kError);
365  balance_data_.reset(reverse_shortest_path_->GetBalanceData()->Reverse(
366  rfst_.NumStates(), 10, -1));
367  InitCloseParenMultimap(parens);
368  }
369 
370  bool Error() const { return error_; }
371 
372  // Expands and prunes the input PDT according to the provided weight
373  // threshold, wirting the result into an output mutable FST.
374  void Expand(MutableFst<Arc> *ofst, const Weight &threshold);
375 
376  private:
377  static constexpr uint8_t kEnqueued = 0x01;
378  static constexpr uint8_t kExpanded = 0x02;
379  static constexpr uint8_t kSourceState = 0x04;
380 
381  // Comparison functor used by the queue:
382  //
383  // 1. States corresponding to shortest stack first, and
384  // 2. for stacks of matching length, reverse lexicographic order is used, and
385  // 3. for states with the same stack, shortest-first order is used.
386  class StackCompare {
387  public:
388  StackCompare(const StateTable &state_table, const Stack &stack,
389  const std::vector<StackId> &stack_length,
390  const std::vector<Weight> &distance,
391  const std::vector<Weight> &fdistance)
392  : state_table_(state_table),
393  stack_(stack),
394  stack_length_(stack_length),
395  distance_(distance),
396  fdistance_(fdistance),
397  less_() {}
398 
399  bool operator()(StateId s1, StateId s2) const {
400  auto si1 = state_table_.Tuple(s1).stack_id;
401  auto si2 = state_table_.Tuple(s2).stack_id;
402  if (stack_length_[si1] < stack_length_[si2]) return true;
403  if (stack_length_[si1] > stack_length_[si2]) return false;
404  // If stack IDs are equal, use A*.
405  if (si1 == si2) {
406  return less_(Distance(s1), Distance(s2));
407  }
408  // If lengths are equal, uses reverse lexicographic order.
409  for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) {
410  if (stack_.Top(si1) < stack_.Top(si2)) return true;
411  if (stack_.Top(si1) > stack_.Top(si2)) return false;
412  }
413  return false;
414  }
415 
416  private:
417  Weight Distance(StateId s) const {
418  return (s < distance_.size()) && (s < fdistance_.size())
419  ? Times(distance_[s], fdistance_[s])
420  : Weight::Zero();
421  }
422 
423  const StateTable &state_table_;
424  const Stack &stack_;
425  const std::vector<StackId> &stack_length_;
426  const std::vector<Weight> &distance_;
427  const std::vector<Weight> &fdistance_;
428  const NaturalLess<Weight> less_;
429  };
430 
431  // Requires Weight is idempotent.
432  class ShortestStackFirstQueue
433  : public ShortestFirstQueue<StateId, StackCompare> {
434  public:
435  ShortestStackFirstQueue(const PdtStateTable<StateId, StackId> &state_table,
436  const Stack &stack,
437  const std::vector<StackId> &stack_length,
438  const std::vector<Weight> &distance,
439  const std::vector<Weight> &fdistance)
441  state_table, stack, stack_length, distance, fdistance)) {}
442  };
443 
444  void InitCloseParenMultimap(
445  const std::vector<std::pair<Label, Label>> &parens);
446 
447  Weight DistanceToDest(StateId source, StateId dest) const;
448 
449  uint8_t Flags(StateId s) const;
450 
451  void SetFlags(StateId s, uint8_t flags, uint8_t mask);
452 
453  Weight Distance(StateId s) const;
454 
455  void SetDistance(StateId s, Weight weight);
456 
457  Weight FinalDistance(StateId s) const;
458 
459  void SetFinalDistance(StateId s, Weight weight);
460 
461  StateId SourceState(StateId s) const;
462 
463  void SetSourceState(StateId s, StateId p);
464 
465  void AddStateAndEnqueue(StateId s);
466 
467  void Relax(StateId s, const Arc &arc, Weight weight);
468 
469  bool PruneArc(StateId s, const Arc &arc);
470 
471  void ProcStart();
472 
473  void ProcFinal(StateId s);
474 
475  bool ProcNonParen(StateId s, const Arc &arc, bool add_arc);
476 
477  bool ProcOpenParen(StateId s, const Arc &arc, StackId si, StackId nsi);
478 
479  bool ProcCloseParen(StateId s, const Arc &arc);
480 
481  void ProcDestStates(StateId s, StackId si);
482 
483  // Input PDT.
484  std::unique_ptr<Fst<Arc>> ifst_;
485  // Reversed PDT.
486  VectorFst<Arc> rfst_;
487  // Keep parentheses in ofst?
488  const bool keep_parentheses_;
489  // State table for efst_.
490  StateTable state_table_;
491  // Stack trie.
492  Stack stack_;
493  // Expanded PDT.
494  PdtExpandFst<Arc> efst_;
495  // Length of stack for given stack ID.
496  std::vector<StackId> stack_length_;
497  // Distance from initial state in efst_/ofst.
498  std::vector<Weight> distance_;
499  // Distance to final states in efst_/ofst.
500  std::vector<Weight> fdistance_;
501  // Queue used to visit efst_.
502  ShortestStackFirstQueue queue_;
503  // Construction time failure?
504  bool error_;
505  // Status flags for states in efst_/ofst.
506  std::vector<uint8_t> flags_;
507  // PDT source state for each expanded state.
508  std::vector<StateId> sources_;
509  // Shortest path for rfst_.
510  std::unique_ptr<PdtShortestPath<Arc, FifoQueue<StateId>>>
511  reverse_shortest_path_;
512  std::unique_ptr<internal::PdtBalanceData<Arc>> balance_data_;
513  // Maps open paren arcs to balancing close paren arcs.
514  typename PdtShortestPath<Arc, FifoQueue<StateId>>::CloseParenMultimap
515  close_paren_multimap_;
516  MutableFst<Arc> *ofst_; // Output FST.
517  Weight limit_; // Weight limit.
518 
  // Maps a state s in ifst_ (i.e., the source of a close parenthesis matching
  // the top of current_stack_id_) to its distance to the final states in efst_.
521  std::unordered_map<StateId, Weight> dest_map_;
522  // Stack ID of the states currently at the top of the queue, i.e., the states
523  // currently being popped and processed.
524  StackId current_stack_id_;
525  ssize_t current_paren_id_; // Paren ID at top of current stack.
526  ssize_t cached_stack_id_;
527  StateId cached_source_;
  // The set of pairs of destination states and weights to final states for the
  // source state cached_source_ and the stack ID cached_stack_id_ (whose top
  // paren ID is current_paren_id_ when the cache is built); i.e., the set of
  // source states of a close parenthesis balancing an incoming open
  // parenthesis with that paren ID in state cached_source_.
533  std::forward_list<std::pair<StateId, Weight>> cached_dest_list_;
534  NaturalLess<Weight> less_;
535 };
536 
// Initializes close paren multimap, mapping pairs (s, paren_id) to all the arcs
// out of s labeled with the close parenthesis for paren_id.
539 template <class Arc>
541  const std::vector<std::pair<Label, Label>> &parens) {
542  std::unordered_map<Label, Label> paren_map;
543  for (size_t i = 0; i < parens.size(); ++i) {
544  const auto &pair = parens[i];
545  paren_map[pair.first] = i;
546  paren_map[pair.second] = i;
547  }
548  for (StateIterator<Fst<Arc>> siter(*ifst_); !siter.Done(); siter.Next()) {
549  const auto s = siter.Value();
550  for (ArcIterator<Fst<Arc>> aiter(*ifst_, s); !aiter.Done(); aiter.Next()) {
551  const auto &arc = aiter.Value();
552  const auto it = paren_map.find(arc.ilabel);
553  if (it == paren_map.end()) continue;
554  if (arc.ilabel == parens[it->second].second) { // Close paren.
555  const internal::ParenState<Arc> key(it->second, s);
556  close_paren_multimap_.emplace(key, arc);
557  }
558  }
559  }
560 }
561 
562 // Returns the weight of the shortest balanced path from source to dest
563 // in ifst_; dest must be the source state of a close paren arc.
564 template <class Arc>
565 typename Arc::Weight PdtPrunedExpand<Arc>::DistanceToDest(StateId source,
566  StateId dest) const {
567  using SearchState =
568  typename PdtShortestPath<Arc, FifoQueue<StateId>>::SearchState;
569  const SearchState ss(source + 1, dest + 1);
570  const auto distance =
571  reverse_shortest_path_->GetShortestPathData().Distance(ss);
572  VLOG(2) << "D(" << source << ", " << dest << ") =" << distance;
573  return distance;
574 }
575 
576 // Returns the flags for state s in ofst_.
577 template <class Arc>
578 uint8_t PdtPrunedExpand<Arc>::Flags(StateId s) const {
579  return s < flags_.size() ? flags_[s] : 0;
580 }
581 
582 // Modifies the flags for state s in ofst_.
583 template <class Arc>
584 void PdtPrunedExpand<Arc>::SetFlags(StateId s, uint8_t flags, uint8_t mask) {
585  while (flags_.size() <= s) flags_.push_back(0);
586  flags_[s] &= ~mask;
587  flags_[s] |= flags & mask;
588 }
589 
590 // Returns the shortest distance from the initial state to s in ofst_.
591 template <class Arc>
592 typename Arc::Weight PdtPrunedExpand<Arc>::Distance(StateId s) const {
593  return s < distance_.size() ? distance_[s] : Weight::Zero();
594 }
595 
596 // Sets the shortest distance from the initial state to s in ofst_.
597 template <class Arc>
599  while (distance_.size() <= s) distance_.push_back(Weight::Zero());
600  distance_[s] = std::move(weight);
601 }
602 
603 // Returns the shortest distance from s to the final states in ofst_.
604 template <class Arc>
605 typename Arc::Weight PdtPrunedExpand<Arc>::FinalDistance(StateId s) const {
606  return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
607 }
608 
609 // Sets the shortest distance from s to the final states in ofst_.
610 template <class Arc>
612  while (fdistance_.size() <= s) fdistance_.push_back(Weight::Zero());
613  fdistance_[s] = std::move(weight);
614 }
615 
616 // Returns the PDT source state of state s in ofst_.
617 template <class Arc>
618 typename Arc::StateId PdtPrunedExpand<Arc>::SourceState(StateId s) const {
619  return s < sources_.size() ? sources_[s] : kNoStateId;
620 }
621 
// Sets the PDT source state of state s in ofst_ to state p in ifst_.
623 template <class Arc>
625  while (sources_.size() <= s) sources_.push_back(kNoStateId);
626  sources_[s] = p;
627 }
628 
629 // Adds state s of efst_ to ofst_ and inserts it in the queue, modifying the
630 // flags for s accordingly.
631 template <class Arc>
633  if (!(Flags(s) & (kEnqueued | kExpanded))) {
634  while (ofst_->NumStates() <= s) ofst_->AddState();
635  queue_.Enqueue(s);
636  SetFlags(s, kEnqueued, kEnqueued);
637  } else if (Flags(s) & kEnqueued) {
638  queue_.Update(s);
639  }
640  // TODO(allauzen): Check everything is fine when kExpanded?
641 }
642 
643 // Relaxes arc out of state s in ofst_ as follows:
644 //
645 // 1. If the distance to s times the weight of arc is smaller than
646 // the currently stored distance for arc.nextstate, updates
647 // Distance(arc.nextstate) with a new estimate
648 // 2. If fd is less than the currently stored distance from arc.nextstate to the
649 // final state, updates with new estimate.
650 template <class Arc>
651 void PdtPrunedExpand<Arc>::Relax(StateId s, const Arc &arc, Weight fd) {
652  const auto nd = Times(Distance(s), arc.weight);
653  if (less_(nd, Distance(arc.nextstate))) {
654  SetDistance(arc.nextstate, nd);
655  SetSourceState(arc.nextstate, SourceState(s));
656  }
657  if (less_(fd, FinalDistance(arc.nextstate))) {
658  SetFinalDistance(arc.nextstate, fd);
659  }
660  VLOG(2) << "Relax: " << s << ", d[s] = " << Distance(s) << ", to "
661  << arc.nextstate << ", d[ns] = " << Distance(arc.nextstate)
662  << ", nd = " << nd;
663 }
664 
665 // Returns whether the arc out of state s in efst needs pruned.
666 template <class Arc>
667 bool PdtPrunedExpand<Arc>::PruneArc(StateId s, const Arc &arc) {
668  VLOG(2) << "Prune ?";
669  auto fd = Weight::Zero();
670  if ((cached_source_ != SourceState(s)) ||
671  (cached_stack_id_ != current_stack_id_)) {
672  cached_source_ = SourceState(s);
673  cached_stack_id_ = current_stack_id_;
674  cached_dest_list_.clear();
675  if (cached_source_ != ifst_->Start()) {
676  for (auto set_iter =
677  balance_data_->Find(current_paren_id_, cached_source_);
678  !set_iter.Done(); set_iter.Next()) {
679  auto dest = set_iter.Element();
680  const auto it = dest_map_.find(dest);
681  cached_dest_list_.push_front(*it);
682  }
683  } else {
684  // TODO(allauzen): queue discipline should prevent this from ever
685  // happening.
686  // Replace by a check.
687  cached_dest_list_.push_front(
688  std::make_pair(rfst_.Start() - 1, Weight::One()));
689  }
690  }
691  for (auto it = cached_dest_list_.begin(); it != cached_dest_list_.end();
692  ++it) {
693  const auto d =
694  DistanceToDest(state_table_.Tuple(arc.nextstate).state_id, it->first);
695  fd = Plus(fd, Times(d, it->second));
696  }
697  Relax(s, arc, fd);
698  return less_(limit_, Times(Distance(s), Times(arc.weight, fd)));
699 }
700 
701 // Adds start state of efst_ to ofst_, enqueues it, and initializes the distance
702 // data structures.
703 template <class Arc>
705  const auto s = efst_.Start();
706  AddStateAndEnqueue(s);
707  ofst_->SetStart(s);
708  SetSourceState(s, ifst_->Start());
709  current_stack_id_ = 0;
710  current_paren_id_ = -1;
711  stack_length_.push_back(0);
712  const auto r = rfst_.Start() - 1;
713  cached_source_ = ifst_->Start();
714  cached_stack_id_ = 0;
715  cached_dest_list_.push_front(std::make_pair(r, Weight::One()));
716  const PdtStateTuple<StateId, StackId> tuple(r, 0);
717  SetFinalDistance(state_table_.FindState(tuple), Weight::One());
718  SetDistance(s, Weight::One());
719  const auto d = DistanceToDest(ifst_->Start(), r);
720  SetFinalDistance(s, d);
721  VLOG(2) << d;
722 }
723 
724 // Makes s final in ofst_ if shortest accepting path ending in s is below
725 // threshold.
726 template <class Arc>
728  const auto weight = efst_.Final(s);
729  if (weight == Weight::Zero()) return;
730  if (less_(limit_, Times(Distance(s), weight))) return;
731  ofst_->SetFinal(s, weight);
732 }
733 
734 // Returns true when an arc (or meta-arc) leaving state s in efst_ is below the
735 // threshold. When add_arc is true, arc is added to ofst_.
736 template <class Arc>
737 bool PdtPrunedExpand<Arc>::ProcNonParen(StateId s, const Arc &arc,
738  bool add_arc) {
739  VLOG(2) << "ProcNonParen: " << s << " to " << arc.nextstate << ", "
740  << arc.ilabel << ":" << arc.olabel << " / " << arc.weight
741  << ", add_arc = " << (add_arc ? "true" : "false");
742  if (PruneArc(s, arc)) return false;
743  if (add_arc) ofst_->AddArc(s, arc);
744  AddStateAndEnqueue(arc.nextstate);
745  return true;
746 }
747 
748 // Processes an open paren arc leaving state s in ofst_. When the arc is labeled
749 // with an open paren,
750 //
751 // 1. Considers each (shortest) balanced path starting in s by taking the arc
752 // and ending by a close paren balancing the open paren of as a meta-arc,
753 // processing and pruning each meta-arc as a non-paren arc, inserting its
754 // destination to the queue;
755 // 2. if at least one of these meta-arcs has not been pruned, adds the
756 // destination of arc to ofst_ as a new source state for the stack ID nsi, and
757 // inserts it in the queue.
758 template <class Arc>
759 bool PdtPrunedExpand<Arc>::ProcOpenParen(StateId s, const Arc &arc, StackId si,
760  StackId nsi) {
761  // Updates the stack length when needed.
762  while (stack_length_.size() <= nsi) stack_length_.push_back(-1);
763  if (stack_length_[nsi] == -1) stack_length_[nsi] = stack_length_[si] + 1;
764  const auto ns = arc.nextstate;
765  VLOG(2) << "Open paren: " << s << "(" << state_table_.Tuple(s).state_id
766  << ") to " << ns << "(" << state_table_.Tuple(ns).state_id << ")";
767  bool proc_arc = false;
768  auto fd = Weight::Zero();
769  const auto paren_id = stack_.ParenId(arc.ilabel);
770  std::forward_list<StateId> sources;
771  for (auto set_iter =
772  balance_data_->Find(paren_id, state_table_.Tuple(ns).state_id);
773  !set_iter.Done(); set_iter.Next()) {
774  sources.push_front(set_iter.Element());
775  }
776  for (const auto source : sources) {
777  VLOG(2) << "Close paren source: " << source;
778  const internal::ParenState<Arc> paren_state(paren_id, source);
779  for (auto it = close_paren_multimap_.find(paren_state);
780  it != close_paren_multimap_.end() && paren_state == it->first; ++it) {
781  auto meta_arc = it->second;
782  const PdtStateTuple<StateId, StackId> tuple(meta_arc.nextstate, si);
783  meta_arc.nextstate = state_table_.FindState(tuple);
784  const auto state_id = state_table_.Tuple(ns).state_id;
785  const auto d = DistanceToDest(state_id, source);
786  VLOG(2) << state_id << ", " << source;
787  VLOG(2) << "Meta arc weight = " << arc.weight << " Times " << d
788  << " Times " << meta_arc.weight;
789  meta_arc.weight = Times(arc.weight, Times(d, meta_arc.weight));
790  proc_arc |= ProcNonParen(s, meta_arc, false);
791  fd = Plus(
792  fd,
793  Times(Times(DistanceToDest(state_table_.Tuple(ns).state_id, source),
794  it->second.weight),
795  FinalDistance(meta_arc.nextstate)));
796  }
797  }
798  if (proc_arc) {
799  VLOG(2) << "Proc open paren " << s << " to " << arc.nextstate;
800  ofst_->AddArc(
801  s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
802  AddStateAndEnqueue(arc.nextstate);
803  const auto nd = Times(Distance(s), arc.weight);
804  if (less_(nd, Distance(arc.nextstate))) SetDistance(arc.nextstate, nd);
805  // FinalDistance not necessary for source state since pruning decided using
806  // meta-arcs above. But this is a problem with A*, hence the following.
807  if (less_(fd, FinalDistance(arc.nextstate)))
808  SetFinalDistance(arc.nextstate, fd);
809  SetFlags(arc.nextstate, kSourceState, kSourceState);
810  }
811  return proc_arc;
812 }
813 
814 // Checks that shortest path through close paren arc in efst_ is below
815 // threshold, and if so, adds it to ofst_.
816 template <class Arc>
817 bool PdtPrunedExpand<Arc>::ProcCloseParen(StateId s, const Arc &arc) {
818  const auto weight =
819  Times(Distance(s), Times(arc.weight, FinalDistance(arc.nextstate)));
820  if (less_(limit_, weight)) return false;
821  ofst_->AddArc(s,
822  keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
823  return true;
824 }
825 
826 // When state s in ofst_ is a source state for stack ID si, identifies all the
827 // corresponding possible destination states, that is, all the states in ifst_
828 // that have an outgoing close paren arc balancing the incoming open paren taken
829 // to get to s. For each such state t, computes the shortest distance from (t,
830 // si) to the final states in ofst_. Stores this information in dest_map_.
831 template <class Arc>
833  if (!(Flags(s) & kSourceState)) return;
834  if (si != current_stack_id_) {
835  dest_map_.clear();
836  current_stack_id_ = si;
837  current_paren_id_ = stack_.Top(current_stack_id_);
838  VLOG(2) << "StackID " << si << " dequeued for first time";
839  }
840  // TODO(allauzen): clean up source state business; rename current function to
841  // ProcSourceState.
842  SetSourceState(s, state_table_.Tuple(s).state_id);
843  const auto paren_id = stack_.Top(si);
844  for (auto set_iter =
845  balance_data_->Find(paren_id, state_table_.Tuple(s).state_id);
846  !set_iter.Done(); set_iter.Next()) {
847  const auto dest_state = set_iter.Element();
848  if (dest_map_.find(dest_state) != dest_map_.end()) continue;
849  auto dest_weight = Weight::Zero();
850  internal::ParenState<Arc> paren_state(paren_id, dest_state);
851  for (auto it = close_paren_multimap_.find(paren_state);
852  it != close_paren_multimap_.end() && paren_state == it->first; ++it) {
853  const auto &arc = it->second;
854  const PdtStateTuple<StateId, StackId> tuple(arc.nextstate,
855  stack_.Pop(si));
856  dest_weight =
857  Plus(dest_weight,
858  Times(arc.weight, FinalDistance(state_table_.FindState(tuple))));
859  }
860  dest_map_[dest_state] = dest_weight;
861  VLOG(2) << "State " << dest_state << " is a dest state for stack ID " << si
862  << " with weight " << dest_weight;
863  }
864 }
865 
866 // Expands and prunes the input PDT, writing the result in ofst.
867 template <class Arc>
869  const typename Arc::Weight &threshold) {
870  ofst_ = ofst;
871  if (error_) {
872  ofst_->SetProperties(kError, kError);
873  return;
874  }
875  ofst_->DeleteStates();
876  ofst_->SetInputSymbols(ifst_->InputSymbols());
877  ofst_->SetOutputSymbols(ifst_->OutputSymbols());
878  limit_ = Times(DistanceToDest(ifst_->Start(), rfst_.Start() - 1), threshold);
879  flags_.clear();
880  ProcStart();
881  while (!queue_.Empty()) {
882  const auto s = queue_.Head();
883  queue_.Dequeue();
884  SetFlags(s, kExpanded, kExpanded | kEnqueued);
885  VLOG(2) << s << " dequeued!";
886  ProcFinal(s);
887  StackId stack_id = state_table_.Tuple(s).stack_id;
888  ProcDestStates(s, stack_id);
889  for (ArcIterator<PdtExpandFst<Arc>> aiter(efst_, s); !aiter.Done();
890  aiter.Next()) {
891  const auto &arc = aiter.Value();
892  const auto nextstack_id = state_table_.Tuple(arc.nextstate).stack_id;
893  if (stack_id == nextstack_id) {
894  ProcNonParen(s, arc, true);
895  } else if (stack_id == stack_.Pop(nextstack_id)) {
896  ProcOpenParen(s, arc, stack_id, nextstack_id);
897  } else {
898  ProcCloseParen(s, arc);
899  }
900  }
901  VLOG(2) << "d[" << s << "] = " << Distance(s) << ", fd[" << s
902  << "] = " << FinalDistance(s);
903  }
904 }
905 
906 // Expand functions.
907 
908 template <class Arc>
910  using Weight = typename Arc::Weight;
911 
912  bool connect;
915 
916  explicit PdtExpandOptions(bool connect = true, bool keep_parentheses = false,
917  Weight weight_threshold = Weight::Zero())
918  : connect(connect),
919  keep_parentheses(keep_parentheses),
920  weight_threshold(std::move(weight_threshold)) {}
921 };
922 
923 // Expands a pushdown transducer (PDT) encoded as an FST into an FST. This
924 // version writes the expanded PDT to a mutable FST. In the PDT, some
925 // transitions are labeled with open or close parentheses. To be interpreted as
926 // a PDT, the parens must balance on a path. The open-close parenthesis label
927 // pairs are passed using the parens argument. Expansion enforces the
928 // parenthesis constraints. The PDT must be expandable as an FST.
929 template <class Arc>
930 void Expand(
931  const Fst<Arc> &ifst,
932  const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
933  &parens,
934  MutableFst<Arc> *ofst, const PdtExpandOptions<Arc> &opts) {
935  using Weight = typename Arc::Weight;
937  eopts.gc_limit = 0;
938  if (opts.weight_threshold == Arc::Weight::Zero()) {
939  eopts.keep_parentheses = opts.keep_parentheses;
940  *ofst = PdtExpandFst<Arc>(ifst, parens, eopts);
941  } else if constexpr (IsIdempotent<Weight>::value) {
942  PdtPrunedExpand<Arc> pruned_expand(ifst, parens, opts.keep_parentheses);
943  pruned_expand.Expand(ofst, opts.weight_threshold);
944  } else {
945  FSTERROR() << "Expand: non-Zero weight_threshold with non-idempotent"
946  << " Weight " << Weight::Type();
947  ofst->SetProperties(kError, kError);
948  }
949  if (opts.connect) Connect(ofst);
950 }
951 
952 // Expands a pushdown transducer (PDT) encoded as an FST into an FST. This
953 // version writes the expanded PDT result to a mutable FST. In the PDT, some
954 // transitions are labeled with open or close parentheses. To be interpreted as
955 // a PDT, the parens must balance on a path. The open-close parenthesis label
956 // pairs are passed using the parents argument. Expansion enforces the
957 // parenthesis constraints. The PDT must be expandable as an FST.
958 template <class Arc>
959 void Expand(
960  const Fst<Arc> &ifst,
961  const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
962  &parens,
963  MutableFst<Arc> *ofst, bool connect = true, bool keep_parentheses = false) {
964  const PdtExpandOptions<Arc> opts(connect, keep_parentheses);
965  Expand(ifst, parens, ofst, opts);
966 }
967 
968 } // namespace fst
969 
970 #endif // FST_EXTENSIONS_PDT_EXPAND_H_
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:124
typename Arc::StateId StateId
Definition: expand.h:236
const PdtStateTable< StateId, StackId > & GetStateTable() const
Definition: expand.h:201
void InitStateIterator(StateIteratorData< Arc > *data) const override
Definition: expand.h:310
typename Arc::Label Label
Definition: expand.h:76
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
Definition: expand.h:172
ArcIterator(const PdtExpandFst< Arc > &fst, StateId s)
Definition: expand.h:303
CacheOptions(bool gc=FST_FLAGS_fst_default_cache_gc, size_t gc_limit=FST_FLAGS_fst_default_cache_gc_limit)
Definition: cache.h:54
virtual uint64_t Properties(uint64_t mask, bool test) const =0
StackId stack_id
Definition: pdt.h:144
typename Store::State State
Definition: expand.h:241
Weight Final(StateId s)
Definition: expand.h:145
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:61
size_t NumOutputEpsilons(StateId s)
Definition: expand.h:167
size_t NumInputEpsilons(StateId s)
Definition: expand.h:162
PdtExpandFst(const PdtExpandFst< Arc > &fst, bool safe=false)
Definition: expand.h:258
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:107
const SymbolTable * OutputSymbols() const
Definition: fst.h:761
void SetFlags(const char *usage, int *argc, char ***argv, bool remove_flags, const char *src="")
Definition: flags.cc:57
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:64
PdtExpandFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens)
Definition: expand.h:247
constexpr uint64_t kError
Definition: properties.h:52
typename Arc::StateId StateId
Definition: expand.h:301
constexpr uint64_t kInitialAcyclic
Definition: properties.h:116
typename Arc::StateId StateId
Definition: expand.h:77
PdtStack< typename Arc::StateId, typename Arc::Label > * stack
Definition: expand.h:56
SetType
Definition: set-weight.h:59
typename PdtExpandFst< Arc >::Arc Arc
Definition: cache.h:1156
void Connect(MutableFst< Arc > *fst)
Definition: connect.h:47
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:113
std::bool_constant<(W::Properties()&kIdempotent)!=0 > IsIdempotent
Definition: weight.h:159
typename internal::PdtBalanceData< Arc >::SetIterator SetIterator
Definition: expand.h:341
constexpr int kNoStateId
Definition: fst.h:196
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
Definition: expand.h:268
typename Arc::Weight Weight
Definition: expand.h:910
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
Definition: reverse.h:36
void ExpandState(StateId s)
Definition: expand.h:179
const PdtStack< StackId, Label > & GetStack() const
Definition: expand.h:199
PdtExpandFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, const PdtExpandFstOptions< Arc > &opts)
Definition: expand.h:252
bool Error() const
Definition: expand.h:370
#define FSTERROR()
Definition: util.h:56
typename Arc::Weight Weight
Definition: expand.h:237
typename Collection< ssize_t, StateId >::SetIterator SetIterator
Definition: paren.h:359
size_t NumArcs(StateId s)
Definition: expand.h:157
PdtStateTable< typename Arc::StateId, typename Arc::StateId > * state_table
Definition: expand.h:57
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:118
PdtExpandFstOptions(const CacheOptions &opts=CacheOptions(), bool keep_parentheses=false, PdtStack< typename Arc::StateId, typename Arc::Label > *stack=nullptr, PdtStateTable< typename Arc::StateId, typename Arc::StateId > *state_table=nullptr)
Definition: expand.h:59
virtual uint64_t Properties() const
Definition: fst.h:701
typename Arc::StateId StateId
Definition: expand.h:334
constexpr uint64_t kCopyProperties
Definition: properties.h:163
constexpr uint64_t kAcyclic
Definition: properties.h:111
virtual void SetProperties(uint64_t props, uint64_t mask)=0
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
Definition: cache.h:673
PdtExpandOptions(bool connect=true, bool keep_parentheses=false, Weight weight_threshold=Weight::Zero())
Definition: expand.h:916
void Expand(MutableFst< Arc > *ofst, const Weight &threshold)
Definition: expand.h:868
#define VLOG(level)
Definition: log.h:54
std::unique_ptr< StateIteratorBase< Arc > > base
Definition: fst.h:382
typename Arc::Weight Weight
Definition: expand.h:78
PdtExpandFst< Arc > * Copy(bool safe=false) const override
Definition: expand.h:262
PdtExpandFstImpl(const PdtExpandFstImpl &impl)
Definition: expand.h:115
StateIterator(const PdtExpandFst< Arc > &fst)
Definition: expand.h:292
Weight weight_threshold
Definition: expand.h:914
constexpr uint64_t kFstProperties
Definition: properties.h:326
constexpr uint64_t kUnweighted
Definition: properties.h:106
const PdtStateTable< StateId, StackId > & GetStateTable() const
Definition: expand.h:276
typename PdtExpandFst< Arc >::Arc Arc
Definition: cache.h:1202
virtual const SymbolTable * InputSymbols() const =0
const SymbolTable * InputSymbols() const
Definition: fst.h:759
StateId state_id
Definition: pdt.h:143
PdtExpandFstImpl(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, const PdtExpandFstOptions< Arc > &opts)
Definition: expand.h:97
const PdtStack< StackId, Label > & GetStack() const
Definition: expand.h:272
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
Definition: expand.h:323
typename Arc::Weight Weight
Definition: expand.h:335
typename CacheState< Arc >::Arc Arc
Definition: cache.h:859
Impl * GetMutableImpl() const
Definition: impl-to-fst.h:125
typename Arc::Label Label
Definition: expand.h:235
constexpr uint64_t kExpanded
Definition: properties.h:46
typename Arc::Label Label
Definition: expand.h:333
size_t gc_limit
Definition: cache.h:52
uint64_t Properties(uint64_t mask, bool test) const override
Definition: impl-to-fst.h:63
const Impl * GetImpl() const
Definition: impl-to-fst.h:123
PdtPrunedExpand(const Fst< Arc > &ifst, const std::vector< std::pair< Label, Label >> &parens, bool keep_parentheses=false, const CacheOptions &opts=CacheOptions())
Definition: expand.h:347
constexpr uint64_t kAcceptor
Definition: properties.h:64
virtual const SymbolTable * OutputSymbols() const =0