FST  openfst-1.7.2
OpenFst Library
replace.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Recursively replaces FST arcs with other FSTs, returning a PDT.
5 
6 #ifndef FST_EXTENSIONS_PDT_REPLACE_H_
7 #define FST_EXTENSIONS_PDT_REPLACE_H_
8 
9 #include <map>
10 #include <memory>
11 #include <set>
12 #include <unordered_map>
13 #include <utility>
14 #include <vector>
15 
16 #include <fst/replace.h>
17 #include <fst/replace-util.h>
18 #include <fst/symbol-table-ops.h>
19 
20 namespace fst {
21 namespace internal {
22 
23 // Hash to paren IDs
template <typename S>
struct ReplaceParenHash {
  // Hashes a (paren ID, state) key as first + second * prime. The arithmetic
  // is carried out in size_t so that large state IDs wrap around (well-defined
  // for unsigned types) instead of overflowing a signed integer.
  size_t operator()(const std::pair<size_t, S> &paren) const {
    static constexpr size_t prime = 7853;
    return paren.first + static_cast<size_t>(paren.second) * prime;
  }
};
31 
32 } // namespace internal
33 
34 // Parser types characterize the PDT construction method. When applied to a CFG,
35 // each non-terminal is encoded as a DFA that accepts precisely the RHS's of
36 // productions of that non-terminal. For parsing (rather than just recognition),
37 // production numbers can be used as outputs (placed as early as possible) in the
38 // DFAs promoted to DFTs. For more information on the strongly regular
39 // construction, see:
40 //
41 // Mohri, M., and Pereira, F. 1998. Dynamic compilation of weighted context-free
42 // grammars. In Proc. ACL, pages 891-897.
enum PdtParserType {
  // Top-down construction. Applied to a simple LL(1) grammar (among others),
  // gives a DPDA. If promoted to a DPDT, with outputs being production
  // numbers, gives a leftmost derivation. Left recursive grammars are
  // problematic in use.
  PDT_LEFT_PARSER,

  // Top-down construction. Similar to PDT_LEFT_PARSER except bounded-stack
  // (expandable as an FST) result with regular or, more generally, strongly
  // regular grammars. Epsilons may replace some parentheses, which may
  // introduce some non-determinism.
  PDT_LEFT_SR_PARSER,

  /* TODO(riley):
  // Bottom-up construction. Applied to a LR(0) grammar, gives a DPDA.
  // If promoted to a DPDT, with outputs being the production numbers,
  // gives the reverse of a rightmost derivation.
  PDT_RIGHT_PARSER,
  */
};
63 
64 template <class Arc>
66  using Label = typename Arc::Label;
67 
68  explicit PdtReplaceOptions(Label root,
70  Label start_paren_labels = kNoLabel,
71  string left_paren_prefix = "(_",
72  string right_paren_prefix = ")_") :
73  root(root), type(type), start_paren_labels(start_paren_labels),
74  left_paren_prefix(std::move(left_paren_prefix)),
75  right_paren_prefix(std::move(right_paren_prefix)) {}
76 
80  const string left_paren_prefix;
81  const string right_paren_prefix;
82 };
83 
84 // PdtParser: Base PDT parser class common to specific parsers.
85 
86 template <class Arc>
87 class PdtParser {
88  public:
89  using Label = typename Arc::Label;
90  using StateId = typename Arc::StateId;
91  using Weight = typename Arc::Weight;
92  using LabelFstPair = std::pair<Label, const Fst<Arc> *>;
93  using LabelPair = std::pair<Label, Label>;
94  using LabelStatePair = std::pair<Label, StateId>;
95  using StateWeightPair = std::pair<StateId, Weight>;
96  using ParenKey = std::pair<size_t, StateId>;
97  using ParenMap =
98  std::unordered_map<ParenKey, size_t, internal::ReplaceParenHash<StateId>>;
99 
100  PdtParser(const std::vector<LabelFstPair> &fst_array,
101  const PdtReplaceOptions<Arc> &opts) :
102  root_(opts.root), start_paren_labels_(opts.start_paren_labels),
103  left_paren_prefix_(std::move(opts.left_paren_prefix)),
104  right_paren_prefix_(std::move(opts.right_paren_prefix)),
105  error_(false) {
106  for (size_t i = 0; i < fst_array.size(); ++i) {
107  if (!CompatSymbols(fst_array[0].second->InputSymbols(),
108  fst_array[i].second->InputSymbols())) {
109  FSTERROR() << "PdtParser: Input symbol table of input FST " << i
110  << " does not match input symbol table of 0th input FST";
111  error_ = true;
112  }
113  if (!CompatSymbols(fst_array[0].second->OutputSymbols(),
114  fst_array[i].second->OutputSymbols())) {
115  FSTERROR() << "PdtParser: Output symbol table of input FST " << i
116  << " does not match input symbol table of 0th input FST";
117  error_ = true;
118  }
119  fst_array_.emplace_back(fst_array[i].first, fst_array[i].second->Copy());
120  // Builds map from non-terminal label to FST ID.
121  label2id_[fst_array[i].first] = i;
122  }
123  }
124 
125  virtual ~PdtParser() {
126  for (auto &pair : fst_array_) delete pair.second;
127  }
128 
129  // Constructs the output PDT, dependent on the derived parser type.
130  virtual void GetParser(MutableFst<Arc> *ofst,
131  std::vector<LabelPair> *parens) = 0;
132 
133  protected:
134  const std::vector<LabelFstPair> &FstArray() const { return fst_array_; }
135 
136  Label Root() const { return root_; }
137 
138  // Maps from non-terminal label to corresponding FST ID, or returns
139  // kNoStateId to signal lookup failure.
140  StateId Label2Id(Label l) const {
141  auto it = label2id_.find(l);
142  return it == label2id_.end() ? kNoStateId : it->second;
143  }
144 
145  // Maps from output state to input FST label, state pair, or returns a
146  // (kNoLabel, kNoStateId) pair to signal lookup failure.
148  if (os >= label_state_pairs_.size()) {
149  static const LabelStatePair no_pair(kNoLabel, kNoLabel);
150  return no_pair;
151  } else {
152  return label_state_pairs_[os];
153  }
154  }
155 
156  // Maps to output state from input FST (label, state) pair, or returns
157  // kNoStateId to signal lookup failure.
158  StateId GetState(const LabelStatePair &lsp) const {
159  auto it = state_map_.find(lsp);
160  if (it == state_map_.end()) {
161  return kNoStateId;
162  } else {
163  return it->second;
164  }
165  }
166 
167  // Builds single FST combining all referenced input FSTs, leaving in the
168  // non-termnals for now; also tabulates the PDT states that correspond to the
169  // start and final states of the input FSTs.
170  void CreateFst(MutableFst<Arc> *ofst, std::vector<StateId> *open_dest,
171  std::vector<std::vector<StateWeightPair>> *close_src);
172 
173  // Assigns parenthesis labels from total allocated paren IDs.
174  void AssignParenLabels(size_t total_nparens, std::vector<LabelPair> *parens) {
175  parens->clear();
176  for (size_t paren_id = 0; paren_id < total_nparens; ++paren_id) {
177  const auto open_paren = start_paren_labels_ + paren_id;
178  const auto close_paren = open_paren + total_nparens;
179  parens->emplace_back(open_paren, close_paren);
180  }
181  }
182 
183  // Determines how non-terminal instances are assigned parentheses IDs.
184  virtual size_t AssignParenIds(const Fst<Arc> &ofst,
185  ParenMap *paren_map) const = 0;
186 
187  // Changes a non-terminal transition to an open parenthesis transition
188  // redirected to the PDT state specified in the open_dest argument, when
189  // indexed by the input FST ID for the non-terminal. Adds close parenthesis
190  // transitions (with specified weights) from the PDT states specified in the
191  // close_src argument, when indexed by the input FST ID for the non-terminal,
192  // to the former destination state of the non-terminal transition. The
193  // paren_map argument gives the parenthesis ID for a given non-terminal FST ID
194  // and destination state pair. The close_non_term_weight vector specifies
195  // non-terminals for which the non-terminal arc weight should be applied on
196  // the close parenthesis (multiplying the close_src weight above) rather than
197  // on the open parenthesis. If no paren ID is found, then an epsilon replaces
198  // the parenthesis that would carry the non-terminal arc weight and the other
199  // parenthesis is omitted (appropriate for the strongly-regular case).
200  void AddParensToFst(
201  const std::vector<LabelPair> &parens,
202  const ParenMap &paren_map,
203  const std::vector<StateId> &open_dest,
204  const std::vector<std::vector<StateWeightPair>> &close_src,
205  const std::vector<bool> &close_non_term_weight,
206  MutableFst<Arc> *ofst);
207 
208  // Ensures that parentheses arcs are added to the symbol table.
209  void AddParensToSymbolTables(const std::vector<LabelPair> &parens,
210  MutableFst<Arc> *ofst);
211 
212  private:
213  std::vector<LabelFstPair> fst_array_;
214  Label root_;
215  // Index to use for the first parenthesis.
216  Label start_paren_labels_;
217  const string left_paren_prefix_;
218  const string right_paren_prefix_;
219  // Maps from non-terminal label to FST ID.
220  std::unordered_map<Label, StateId> label2id_;
221  // Given an output state, specifies the input FST (label, state) pair.
222  std::vector<LabelStatePair> label_state_pairs_;
223  // Given an FST (label, state) pair, specifies the output FST state ID.
224  std::map<LabelStatePair, StateId> state_map_;
225  bool error_;
226 };
227 
228 template <class Arc>
230  MutableFst<Arc> *ofst, std::vector<StateId> *open_dest,
231  std::vector<std::vector<StateWeightPair>> *close_src) {
232  ofst->DeleteStates();
233  if (error_) {
234  ofst->SetProperties(kError, kError);
235  return;
236  }
237  open_dest->resize(fst_array_.size(), kNoStateId);
238  close_src->resize(fst_array_.size());
239  // Queue of non-terminals to replace.
240  std::deque<Label> non_term_queue;
241  non_term_queue.push_back(root_);
242  // Has a non-terminal been enqueued?
243  std::vector<bool> enqueued(fst_array_.size(), false);
244  enqueued[label2id_[root_]] = true;
245  Label max_label = kNoLabel;
246  for (StateId soff = 0; !non_term_queue.empty(); soff = ofst->NumStates()) {
247  const auto label = non_term_queue.front();
248  non_term_queue.pop_front();
249  StateId fst_id = Label2Id(label);
250  const auto *ifst = fst_array_[fst_id].second;
251  for (StateIterator<Fst<Arc>> siter(*ifst); !siter.Done(); siter.Next()) {
252  const auto is = siter.Value();
253  const auto os = ofst->AddState();
254  const LabelStatePair lsp(label, is);
255  label_state_pairs_.push_back(lsp);
256  state_map_[lsp] = os;
257  if (is == ifst->Start()) {
258  (*open_dest)[fst_id] = os;
259  if (label == root_) ofst->SetStart(os);
260  }
261  if (ifst->Final(is) != Weight::Zero()) {
262  if (label == root_) ofst->SetFinal(os, ifst->Final(is));
263  (*close_src)[fst_id].emplace_back(os, ifst->Final(is));
264  }
265  for (ArcIterator<Fst<Arc>> aiter(*ifst, is); !aiter.Done();
266  aiter.Next()) {
267  auto arc = aiter.Value();
268  arc.nextstate += soff;
269  if (max_label == kNoLabel || arc.olabel > max_label)
270  max_label = arc.olabel;
271  const auto nfst_id = Label2Id(arc.olabel);
272  if (nfst_id != kNoStateId) {
273  if (fst_array_[nfst_id].second->Start() == kNoStateId) continue;
274  if (!enqueued[nfst_id]) {
275  non_term_queue.push_back(arc.olabel);
276  enqueued[nfst_id] = true;
277  }
278  }
279  ofst->AddArc(os, arc);
280  }
281  }
282  }
283  if (start_paren_labels_ == kNoLabel)
284  start_paren_labels_ = max_label + 1;
285 }
286 
287 template <class Arc>
289  const std::vector<LabelPair> &parens,
290  const ParenMap &paren_map,
291  const std::vector<StateId> &open_dest,
292  const std::vector<std::vector<StateWeightPair>> &close_src,
293  const std::vector<bool> &close_non_term_weight,
294  MutableFst<Arc> *ofst) {
295  StateId dead_state = kNoStateId;
296  using MIter = MutableArcIterator<MutableFst<Arc>>;
297  for (StateIterator<Fst<Arc>> siter(*ofst); !siter.Done(); siter.Next()) {
298  StateId os = siter.Value();
299  std::unique_ptr<MIter> aiter(new MIter(ofst, os));
300  for (auto n = 0; !aiter->Done(); aiter->Next(), ++n) {
301  const auto arc = aiter->Value(); // A reference here may go stale.
302  StateId nfst_id = Label2Id(arc.olabel);
303  if (nfst_id != kNoStateId) {
304  // Gets parentheses.
305  const ParenKey paren_key(nfst_id, arc.nextstate);
306  auto it = paren_map.find(paren_key);
307  Label open_paren = 0;
308  Label close_paren = 0;
309  if (it != paren_map.end()) {
310  const auto paren_id = it->second;
311  open_paren = parens[paren_id].first;
312  close_paren = parens[paren_id].second;
313  }
314  // Sets open parenthesis.
315  if (open_paren != 0 || !close_non_term_weight[nfst_id]) {
316  const auto open_weight =
317  close_non_term_weight[nfst_id] ? Weight::One() : arc.weight;
318  const Arc sarc(open_paren, open_paren, open_weight,
319  open_dest[nfst_id]);
320  aiter->SetValue(sarc);
321  } else {
322  if (dead_state == kNoStateId) {
323  dead_state = ofst->AddState();
324  }
325  const Arc sarc(0, 0, Weight::One(), dead_state);
326  aiter->SetValue(sarc);
327  }
328  // Adds close parentheses.
329  if (close_paren != 0 || close_non_term_weight[nfst_id]) {
330  for (size_t i = 0; i < close_src[nfst_id].size(); ++i) {
331  const auto &pair = close_src[nfst_id][i];
332  const auto close_weight = close_non_term_weight[nfst_id]
333  ? Times(arc.weight, pair.second)
334  : pair.second;
335  const Arc farc(close_paren, close_paren, close_weight,
336  arc.nextstate);
337 
338  ofst->AddArc(pair.first, farc);
339  if (os == pair.first) { // Invalidated iterator.
340  aiter.reset(new MIter(ofst, os));
341  aiter->Seek(n);
342  }
343  }
344  }
345  }
346  }
347  }
348 }
349 
350 template <class Arc>
352  const std::vector<LabelPair> &parens, MutableFst<Arc> *ofst) {
353  auto size = parens.size();
354  if (ofst->InputSymbols()) {
355  if (!AddAuxiliarySymbols(left_paren_prefix_, start_paren_labels_, size,
356  ofst->MutableInputSymbols())) {
357  ofst->SetProperties(kError, kError);
358  return;
359  }
360  if (!AddAuxiliarySymbols(right_paren_prefix_, start_paren_labels_ + size,
361  size, ofst->MutableInputSymbols())) {
362  ofst->SetProperties(kError, kError);
363  return;
364  }
365  }
366  if (ofst->OutputSymbols()) {
367  if (!AddAuxiliarySymbols(left_paren_prefix_, start_paren_labels_, size,
368  ofst->MutableOutputSymbols())) {
369  ofst->SetProperties(kError, kError);
370  return;
371  }
372  if (!AddAuxiliarySymbols(right_paren_prefix_, start_paren_labels_ + size,
373  size, ofst->MutableOutputSymbols())) {
374  ofst->SetProperties(kError, kError);
375  return;
376  }
377  }
378 }
379 
380 // Builds a PDT by recursive replacement top-down, where the call and return are
381 // encoded in the parentheses.
382 template <class Arc>
383 class PdtLeftParser final : public PdtParser<Arc> {
384  public:
385  using Label = typename Arc::Label;
386  using StateId = typename Arc::StateId;
387  using Weight = typename Arc::Weight;
394 
403  using PdtParser<Arc>::Root;
404 
405  PdtLeftParser(const std::vector<LabelFstPair> &fst_array,
406  const PdtReplaceOptions<Arc> &opts) :
407  PdtParser<Arc>(fst_array, opts) { }
408 
409  void GetParser(MutableFst<Arc> *ofst,
410  std::vector<LabelPair> *parens) override;
411 
412  protected:
413  // Assigns a unique parenthesis ID for each non-terminal, destination
414  // state pair.
415  size_t AssignParenIds(const Fst<Arc> &ofst,
416  ParenMap *paren_map) const override;
417 };
418 
419 template <class Arc>
421  MutableFst<Arc> *ofst,
422  std::vector<LabelPair> *parens) {
423  ofst->DeleteStates();
424  parens->clear();
425  const auto &fst_array = FstArray();
426  // Map that gives the paren ID for a (non-terminal, dest. state) pair
427  // (which can be unique).
428  ParenMap paren_map;
429  // Specifies the open parenthesis destination state for a given non-terminal.
430  // The source is the non-terminal instance source state.
431  std::vector<StateId> open_dest(fst_array.size(), kNoStateId);
432  // Specifies close parenthesis source states and weights for a given
433  // non-terminal. The destination is the non-terminal instance destination
434  // state.
435  std::vector<std::vector<StateWeightPair>> close_src(fst_array.size());
436  // Specifies non-terminals for which the non-terminal arc weight
437  // should be applied on the close parenthesis (multiplying the
438  // 'close_src' weight above) rather than on the open parenthesis.
439  std::vector<bool> close_non_term_weight(fst_array.size(), false);
440  CreateFst(ofst, &open_dest, &close_src);
441  auto total_nparens = AssignParenIds(*ofst, &paren_map);
442  AssignParenLabels(total_nparens, parens);
443  AddParensToFst(*parens, paren_map, open_dest, close_src,
444  close_non_term_weight, ofst);
445  if (!fst_array.empty()) {
446  ofst->SetInputSymbols(fst_array[0].second->InputSymbols());
447  ofst->SetOutputSymbols(fst_array[0].second->OutputSymbols());
448  }
449  AddParensToSymbolTables(*parens, ofst);
450 }
451 
452 template <class Arc>
454  const Fst<Arc> &ofst,
455  ParenMap *paren_map) const {
456  // Number of distinct parenthesis pairs per FST.
457  std::vector<size_t> nparens(FstArray().size(), 0);
458  // Number of distinct parenthesis pairs overall.
459  size_t total_nparens = 0;
460  for (StateIterator<Fst<Arc>> siter(ofst); !siter.Done(); siter.Next()) {
461  const auto os = siter.Value();
462  for (ArcIterator<Fst<Arc>> aiter(ofst, os); !aiter.Done(); aiter.Next()) {
463  const auto &arc = aiter.Value();
464  const auto nfst_id = Label2Id(arc.olabel);
465  if (nfst_id != kNoStateId) {
466  const ParenKey paren_key(nfst_id, arc.nextstate);
467  auto it = paren_map->find(paren_key);
468  if (it == paren_map->end()) {
469  // Assigns new paren ID for this (FST, dest state) pair.
470  (*paren_map)[paren_key] = nparens[nfst_id]++;
471  if (nparens[nfst_id] > total_nparens)
472  total_nparens = nparens[nfst_id];
473  }
474  }
475  }
476  }
477  return total_nparens;
478 }
479 
480 // Similar to PdtLeftParser but:
481 //
482 // 1. Uses epsilons rather than parentheses labels for any non-terminal
483 // instances within a left- (right-) linear dependency SCC,
484 // 2. Allocates a paren ID uniquely for each such dependency SCC (rather than
485 // non-terminal = dependency state) and destination state.
486 template <class Arc>
487 class PdtLeftSRParser final : public PdtParser<Arc> {
488  public:
489  using Label = typename Arc::Label;
490  using StateId = typename Arc::StateId;
491  using Weight = typename Arc::Weight;
498 
507  using PdtParser<Arc>::Root;
508 
509  PdtLeftSRParser(const std::vector<LabelFstPair> &fst_array,
510  const PdtReplaceOptions<Arc> &opts) :
511  PdtParser<Arc>(fst_array, opts),
512  replace_util_(fst_array, ReplaceUtilOptions(opts.root)) { }
513 
514  void GetParser(MutableFst<Arc> *ofst,
515  std::vector<LabelPair> *parens) override;
516 
517  protected:
518  // Assigns a unique parenthesis ID for each non-terminal, destination state
519  // pair when the non-terminal refers to a non-linear FST. Otherwise, assigns
520  // a unique parenthesis ID for each dependency SCC, destination state pair if
521  // the non-terminal instance is between
522  // SCCs. Otherwise does nothing.
523  size_t AssignParenIds(const Fst<Arc> &ofst,
524  ParenMap *paren_map) const override;
525 
526  // Returns dependency SCC for given label.
527  size_t SCC(Label label) const { return replace_util_.SCC(label); }
528 
529  // Is a given dependency SCC left-linear?
530  bool SCCLeftLinear(size_t scc_id) const {
531  const auto ll_props = kReplaceSCCLeftLinear | kReplaceSCCNonTrivial;
532  const auto scc_props = replace_util_.SCCProperties(scc_id);
533  return (scc_props & ll_props) == ll_props;
534  }
535 
536  // Is a given dependency SCC right-linear?
537  bool SCCRightLinear(size_t scc_id) const {
538  const auto lr_props = kReplaceSCCRightLinear | kReplaceSCCNonTrivial;
539  const auto scc_props = replace_util_.SCCProperties(scc_id);
540  return (scc_props & lr_props) == lr_props;
541  }
542 
543  // Components of left- (right-) linear dependency SCC; empty o.w.
544  const std::vector<size_t> &SCCComps(size_t scc_id) const {
545  if (scc_comps_.empty()) GetSCCComps();
546  return scc_comps_[scc_id];
547  }
548 
549  // Returns the representative state of an SCC. For left-linear grammars, it
550  // is one of the initial states. For right-linear grammars, it is one of the
551  // non-terminal destination states; otherwise, it is kNoStateId.
552  StateId RepState(size_t scc_id) const {
553  if (SCCComps(scc_id).empty()) return kNoStateId;
554  const auto fst_id = SCCComps(scc_id).front();
555  const auto &fst_array = FstArray();
556  const auto label = fst_array[fst_id].first;
557  const auto *ifst = fst_array[fst_id].second;
558  if (SCCLeftLinear(scc_id)) {
559  const LabelStatePair lsp(label, ifst->Start());
560  return GetState(lsp);
561  } else { // Right-linear.
562  const LabelStatePair lsp(label, *NonTermDests(fst_id).begin());
563  return GetState(lsp);
564  }
565  return kNoStateId;
566  }
567 
568  private:
569  // Merges initial (final) states of in a left- (right-) linear dependency SCC
570  // after dealing with the non-terminal arc and final weights.
571  void ProcSCCs(MutableFst<Arc> *ofst,
572  std::vector<StateId> *open_dest,
573  std::vector<std::vector<StateWeightPair>> *close_src,
574  std::vector<bool> *close_non_term_weight) const;
575 
576  // Computes components of left- (right-) linear dependency SCC.
577  void GetSCCComps() const {
578  const std::vector<LabelFstPair> &fst_array = FstArray();
579  for (size_t i = 0; i < fst_array.size(); ++i) {
580  const auto label = fst_array[i].first;
581  const auto scc_id = SCC(label);
582  if (scc_comps_.size() <= scc_id) scc_comps_.resize(scc_id + 1);
583  if (SCCLeftLinear(scc_id) || SCCRightLinear(scc_id)) {
584  scc_comps_[scc_id].push_back(i);
585  }
586  }
587  }
588 
589  const std::set<StateId> &NonTermDests(StateId fst_id) const {
590  if (non_term_dests_.empty()) GetNonTermDests();
591  return non_term_dests_[fst_id];
592  }
593 
594  // Finds non-terminal destination states for right-linear FSTS, or does
595  // nothing if not found.
596  void GetNonTermDests() const;
597 
598  // Dependency SCC info.
599  mutable ReplaceUtil<Arc> replace_util_;
600  // Components of left- (right-) linear dependency SCCs, or empty otherwise.
601  mutable std::vector<std::vector<size_t>> scc_comps_;
602  // States that have non-terminals entering them for each (right-linear) FST.
603  mutable std::vector<std::set<StateId>> non_term_dests_;
604 };
605 
606 template <class Arc>
608  MutableFst<Arc> *ofst,
609  std::vector<LabelPair> *parens) {
610  ofst->DeleteStates();
611  parens->clear();
612  const auto &fst_array = FstArray();
613  // Map that gives the paren ID for a (non-terminal, dest. state) pair.
614  ParenMap paren_map;
615  // Specifies the open parenthesis destination state for a given non-terminal.
616  // The source is the non-terminal instance source state.
617  std::vector<StateId> open_dest(fst_array.size(), kNoStateId);
618  // Specifies close parenthesis source states and weights for a given
619  // non-terminal. The destination is the non-terminal instance destination
620  // state.
621  std::vector<std::vector<StateWeightPair>> close_src(fst_array.size());
622  // Specifies non-terminals for which the non-terminal arc weight should be
623  // applied on the close parenthesis (multiplying the close_src weight above)
624  // rather than on the open parenthesis.
625  std::vector<bool> close_non_term_weight(fst_array.size(), false);
626  CreateFst(ofst, &open_dest, &close_src);
627  ProcSCCs(ofst, &open_dest, &close_src, &close_non_term_weight);
628  const auto total_nparens = AssignParenIds(*ofst, &paren_map);
629  AssignParenLabels(total_nparens, parens);
630  AddParensToFst(*parens, paren_map, open_dest, close_src,
631  close_non_term_weight, ofst);
632  if (!fst_array.empty()) {
633  ofst->SetInputSymbols(fst_array[0].second->InputSymbols());
634  ofst->SetOutputSymbols(fst_array[0].second->OutputSymbols());
635  }
636  AddParensToSymbolTables(*parens, ofst);
637  Connect(ofst);
638 }
639 
640 template <class Arc>
642  MutableFst<Arc> *ofst,
643  std::vector<StateId> *open_dest,
644  std::vector<std::vector<StateWeightPair>> *close_src,
645  std::vector<bool> *close_non_term_weight) const {
646  const auto &fst_array = FstArray();
647  for (StateIterator<Fst<Arc>> siter(*ofst); !siter.Done(); siter.Next()) {
648  const auto os = siter.Value();
649  const auto label = GetLabelStatePair(os).first;
650  const auto is = GetLabelStatePair(os).second;
651  const auto fst_id = Label2Id(label);
652  const auto scc_id = SCC(label);
653  const auto rs = RepState(scc_id);
654  const auto *ifst = fst_array[fst_id].second;
655  // SCC LEFT-LINEAR: puts non-terminal weights on close parentheses. Merges
656  // initial states into SCC representative state and updates open_dest.
657  if (SCCLeftLinear(scc_id)) {
658  (*close_non_term_weight)[fst_id] = true;
659  if (is == ifst->Start() && os != rs) {
660  for (ArcIterator<Fst<Arc>> aiter(*ofst, os); !aiter.Done();
661  aiter.Next()) {
662  const auto &arc = aiter.Value();
663  ofst->AddArc(rs, arc);
664  }
665  ofst->DeleteArcs(os);
666  if (os == ofst->Start())
667  ofst->SetStart(rs);
668  (*open_dest)[fst_id] = rs;
669  }
670  }
671  // SCC RIGHT-LINEAR: pushes back final weights onto non-terminals, if
672  // possible, or adds weighted epsilons to the SCC representative state.
673  // Merges final states into SCC representative state and updates close_src.
674  if (SCCRightLinear(scc_id)) {
675  for (MutableArcIterator<MutableFst<Arc>> aiter(ofst, os); !aiter.Done();
676  aiter.Next()) {
677  auto arc = aiter.Value();
678  const auto idest = GetLabelStatePair(arc.nextstate).second;
679  if (NonTermDests(fst_id).count(idest) > 0) {
680  if (ofst->Final(arc.nextstate) != Weight::Zero()) {
681  ofst->SetFinal(arc.nextstate, Weight::Zero());
682  ofst->SetFinal(rs, Weight::One());
683  }
684  arc.weight = Times(arc.weight, ifst->Final(idest));
685  arc.nextstate = rs;
686  aiter.SetValue(arc);
687  }
688  }
689  const auto final_weight = ifst->Final(is);
690  if (final_weight != Weight::Zero() &&
691  NonTermDests(fst_id).count(is) == 0) {
692  ofst->AddArc(os, Arc(0, 0, final_weight, rs));
693  if (ofst->Final(os) != Weight::Zero()) {
694  ofst->SetFinal(os, Weight::Zero());
695  ofst->SetFinal(rs, Weight::One());
696  }
697  }
698  if (is == ifst->Start()) {
699  (*close_src)[fst_id].clear();
700  (*close_src)[fst_id].emplace_back(rs, Weight::One());
701  }
702  }
703  }
704 }
705 
706 template <class Arc>
708  const auto &fst_array = FstArray();
709  non_term_dests_.resize(fst_array.size());
710  for (size_t fst_id = 0; fst_id < fst_array.size(); ++fst_id) {
711  const auto label = fst_array[fst_id].first;
712  const auto scc_id = SCC(label);
713  if (SCCRightLinear(scc_id)) {
714  const auto *ifst = fst_array[fst_id].second;
715  for (StateIterator<Fst<Arc>> siter(*ifst); !siter.Done(); siter.Next()) {
716  const auto is = siter.Value();
717  for (ArcIterator<Fst<Arc>> aiter(*ifst, is); !aiter.Done();
718  aiter.Next()) {
719  const auto &arc = aiter.Value();
720  if (Label2Id(arc.olabel) != kNoStateId) {
721  non_term_dests_[fst_id].insert(arc.nextstate);
722  }
723  }
724  }
725  }
726  }
727 }
728 
729 template <class Arc>
731  const Fst<Arc> &ofst,
732  ParenMap *paren_map) const {
733  const auto &fst_array = FstArray();
734  // Number of distinct parenthesis pairs per FST.
735  std::vector<size_t> nparens(fst_array.size(), 0);
736  // Number of distinct parenthesis pairs overall.
737  size_t total_nparens = 0;
738  for (StateIterator<Fst<Arc>> siter(ofst); !siter.Done(); siter.Next()) {
739  const auto os = siter.Value();
740  const auto label = GetLabelStatePair(os).first;
741  const auto scc_id = SCC(label);
742  for (ArcIterator<Fst<Arc>> aiter(ofst, os); !aiter.Done(); aiter.Next()) {
743  const auto &arc = aiter.Value();
744  const auto nfst_id = Label2Id(arc.olabel);
745  if (nfst_id != kNoStateId) {
746  size_t nscc_id = SCC(arc.olabel);
747  bool nscc_linear = !SCCComps(nscc_id).empty();
748  // Assigns a parenthesis ID for the non-terminal transition
749  // if the non-terminal belongs to a (left-/right-) linear dependency
750  // SCC or if the transition is in an FST from a different SCC
751  if (!nscc_linear || scc_id != nscc_id) {
752  // For (left-/right-) linear SCCs instead of using nfst_id, we
753  // will use its SCC prototype pfst_id for assigning distinct
754  // parenthesis IDs.
755  const auto pfst_id =
756  nscc_linear ? SCCComps(nscc_id).front() : nfst_id;
757  ParenKey paren_key(pfst_id, arc.nextstate);
758  const auto it = paren_map->find(paren_key);
759  if (it == paren_map->end()) {
760  // Assigns new paren ID for this (FST/SCC, dest. state) pair.
761  if (nscc_linear) {
762  // This is mapping we'll need, but we also store (harmlessly)
763  // for the prototype below so we can easily keep count per SCC.
764  const ParenKey nparen_key(nfst_id, arc.nextstate);
765  (*paren_map)[nparen_key] = nparens[pfst_id];
766  }
767  (*paren_map)[paren_key] = nparens[pfst_id]++;
768  if (nparens[pfst_id] > total_nparens) {
769  total_nparens = nparens[pfst_id];
770  }
771  }
772  }
773  }
774  }
775  }
776  return total_nparens;
777 }
778 
779 // Builds a pushdown transducer (PDT) from an RTN specification. The result is
780 // a PDT written to a mutable FST where some transitions are labeled with
781 // open or close parentheses. To be interpreted as a PDT, the parens must
782 // balance on a path (see PdtExpand()). The open/close parenthesis label pairs
783 // are returned in the parens argument.
784 template <class Arc>
785 void Replace(
786  const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
787  &ifst_array,
788  MutableFst<Arc> *ofst,
789  std::vector<std::pair<typename Arc::Label, typename Arc::Label>> *parens,
790  const PdtReplaceOptions<Arc> &opts) {
791  switch (opts.type) {
792  case PDT_LEFT_PARSER:
793  {
794  PdtLeftParser<Arc> pr(ifst_array, opts);
795  pr.GetParser(ofst, parens);
796  return;
797  }
798  case PDT_LEFT_SR_PARSER:
799  {
800  PdtLeftSRParser<Arc> pr(ifst_array, opts);
801  pr.GetParser(ofst, parens);
802  return;
803  }
804  default:
805  FSTERROR() << "Replace: Unknown PDT parser type: " << opts.type;
806  ofst->DeleteStates();
807  ofst->SetProperties(kError, kError);
808  parens->clear();
809  return;
810  }
811 }
812 
813 // Variant where the only user-controlled argument is the root ID.
814 template <class Arc>
815 void Replace(
816  const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
817  &ifst_array,
818  MutableFst<Arc> *ofst,
819  std::vector<std::pair<typename Arc::Label, typename Arc::Label>> *parens,
820  typename Arc::Label root) {
821  PdtReplaceOptions<Arc> opts(root);
822  Replace(ifst_array, ofst, parens, opts);
823 }
824 
825 } // namespace fst
826 
827 #endif // FST_EXTENSIONS_PDT_REPLACE_H_
virtual ~PdtParser()
Definition: replace.h:125
void GetParser(MutableFst< Arc > *ofst, std::vector< LabelPair > *parens) override
Definition: replace.h:607
std::unordered_map< ParenKey, size_t, internal::ReplaceParenHash< StateId >> ParenMap
Definition: replace.h:98
constexpr int kNoLabel
Definition: fst.h:179
PdtReplaceOptions(Label root, PdtParserType type=PDT_LEFT_PARSER, Label start_paren_labels=kNoLabel, string left_paren_prefix="(_", string right_paren_prefix=")_")
Definition: replace.h:68
virtual SymbolTable * MutableOutputSymbols()=0
PdtLeftSRParser(const std::vector< LabelFstPair > &fst_array, const PdtReplaceOptions< Arc > &opts)
Definition: replace.h:509
size_t AssignParenIds(const Fst< Arc > &ofst, ParenMap *paren_map) const override
Definition: replace.h:453
virtual void AddArc(StateId, const Arc &arc)=0
StateId RepState(size_t scc_id) const
Definition: replace.h:552
const SymbolTable * InputSymbols() const override=0
virtual void SetInputSymbols(const SymbolTable *isyms)=0
StateId Label2Id(Label l) const
Definition: replace.h:140
Label Root() const
Definition: replace.h:136
size_t AssignParenIds(const Fst< Arc > &ofst, ParenMap *paren_map) const override
Definition: replace.h:730
virtual Weight Final(StateId) const =0
ExpectationWeight< X1, X2 > Times(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
virtual void SetStart(StateId)=0
bool SCCLeftLinear(size_t scc_id) const
Definition: replace.h:530
void Connect(MutableFst< Arc > *fst)
Definition: connect.h:268
void Replace(const std::vector< std::pair< typename Arc::Label, const Fst< Arc > * >> &ifst_array, MutableFst< Arc > *ofst, std::vector< std::pair< typename Arc::Label, typename Arc::Label >> *parens, const PdtReplaceOptions< Arc > &opts)
Definition: replace.h:785
size_t operator()(const std::pair< size_t, S > &paren) const
Definition: replace.h:26
void AssignParenLabels(size_t total_nparens, std::vector< LabelPair > *parens)
Definition: replace.h:174
constexpr uint8 kReplaceSCCNonTrivial
Definition: replace-util.h:74
constexpr int kNoStateId
Definition: fst.h:180
std::pair< Label, StateId > LabelStatePair
Definition: replace.h:94
bool SCCRightLinear(size_t scc_id) const
Definition: replace.h:537
const Arc & Value() const
Definition: mutable-fst.h:213
typename Arc::StateId StateId
Definition: replace.h:90
typename Arc::Label Label
Definition: replace.h:89
void CreateFst(MutableFst< Arc > *ofst, std::vector< StateId > *open_dest, std::vector< std::vector< StateWeightPair >> *close_src)
Definition: replace.h:229
#define FSTERROR()
Definition: util.h:35
const SymbolTable * OutputSymbols() const override=0
const string right_paren_prefix
Definition: replace.h:81
std::pair< StateId, Weight > StateWeightPair
Definition: replace.h:95
const std::vector< size_t > & SCCComps(size_t scc_id) const
Definition: replace.h:544
virtual void SetFinal(StateId, Weight)=0
bool AddAuxiliarySymbols(const string &prefix, int64 start_label, int64 nlabels, SymbolTable *syms)
PdtParserType type
Definition: replace.h:78
const string left_paren_prefix
Definition: replace.h:80
void GetParser(MutableFst< Arc > *ofst, std::vector< LabelPair > *parens) override
Definition: replace.h:420
virtual StateId Start() const =0
StateId GetState(const LabelStatePair &lsp) const
Definition: replace.h:158
PdtParser(const std::vector< LabelFstPair > &fst_array, const PdtReplaceOptions< Arc > &opts)
Definition: replace.h:100
virtual void DeleteArcs(StateId, size_t n)=0
std::pair< size_t, StateId > ParenKey
Definition: replace.h:96
const std::vector< LabelFstPair > & FstArray() const
Definition: replace.h:134
constexpr uint8 kReplaceSCCLeftLinear
Definition: replace-util.h:67
std::pair< Label, Label > LabelPair
Definition: replace.h:93
void AddParensToSymbolTables(const std::vector< LabelPair > &parens, MutableFst< Arc > *ofst)
Definition: replace.h:351
constexpr uint8 kReplaceSCCRightLinear
Definition: replace-util.h:71
typename Arc::Label Label
Definition: replace.h:66
virtual SymbolTable * MutableInputSymbols()=0
Label start_paren_labels
Definition: replace.h:79
constexpr uint64 kError
Definition: properties.h:33
PdtParserType
Definition: replace.h:43
virtual StateId AddState()=0
std::pair< Label, const Fst< Arc > * > LabelFstPair
Definition: replace.h:92
void AddParensToFst(const std::vector< LabelPair > &parens, const ParenMap &paren_map, const std::vector< StateId > &open_dest, const std::vector< std::vector< StateWeightPair >> &close_src, const std::vector< bool > &close_non_term_weight, MutableFst< Arc > *ofst)
Definition: replace.h:288
virtual void DeleteStates(const std::vector< StateId > &)=0
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
typename Arc::Weight Weight
Definition: replace.h:91
virtual StateId NumStates() const =0
LabelStatePair GetLabelStatePair(StateId os) const
Definition: replace.h:147
PdtLeftParser(const std::vector< LabelFstPair > &fst_array, const PdtReplaceOptions< Arc > &opts)
Definition: replace.h:405
size_t SCC(Label label) const
Definition: replace.h:527
virtual void SetProperties(uint64 props, uint64 mask)=0