FST  openfst-1.8.2
OpenFst Library
replace.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Recursively replaces FST arcs with other FSTs, returning a PDT.
19 
20 #ifndef FST_EXTENSIONS_PDT_REPLACE_H_
21 #define FST_EXTENSIONS_PDT_REPLACE_H_
22 
23 #include <cstdint>
24 #include <map>
25 #include <memory>
26 #include <set>
27 #include <string>
28 #include <type_traits>
29 #include <utility>
30 #include <vector>
31 
32 #include <fst/replace-util.h>
33 #include <fst/replace.h>
34 #include <fst/symbol-table-ops.h>
35 #include <unordered_map>
36 
37 namespace fst {
38 namespace internal {
39 
40 // Hash to paren IDs
41 template <typename S>
43  size_t operator()(const std::pair<size_t, S> &paren) const {
44  static constexpr auto prime = 7853;
45  return paren.first + paren.second * prime;
46  }
47 };
48 
49 } // namespace internal
50 
51 // Parser types characterize the PDT construction method. When applied to a CFG,
52 // each non-terminal is encoded as a DFA that accepts precisely the RHS's of
53 // productions of that non-terminal. For parsing (rather than just recognition),
54 // production numbers can be used as outputs (placed as early as possible) in the
55 // DFAs promoted to DFTs. For more information on the strongly regular
56 // construction, see:
57 //
58 // Mohri, M., and Pereira, F. 1998. Dynamic compilation of weighted context-free
59 // grammars. In Proc. ACL, pages 891-897.
60 enum class PdtParserType : uint8_t {
61  // Top-down construction. Applied to a simple LL(1) grammar (among others),
62  // gives a DPDA. If promoted to a DPDT, with outputs being production
63  // numbers, gives a leftmost derivation. Left recursive grammars are
64  // problematic in use.
65  LEFT,
66 
67  // Top-down construction. Similar to LEFT except bounded-stack
68  // (expandable as an FST) result with regular or, more generally, strongly
69  // regular grammars. Epsilons may replace some parentheses, which may
70  // introduce some non-determinism.
71  LEFT_SR,
72 
73  /* TODO(riley):
74  // Bottom-up construction. Applied to an LR(0) grammar, gives a DPDA.
75  // If promoted to a DPDT, with outputs being the production numbers,
76  // gives the reverse of a rightmost derivation.
77  RIGHT,
78  */
79 };
80 
81 template <class Arc>
83  using Label = typename Arc::Label;
84 
85  explicit PdtReplaceOptions(Label root,
87  Label start_paren_labels = kNoLabel,
88  std::string left_paren_prefix = "(_",
89  std::string right_paren_prefix = ")_")
90  : root(root),
91  type(type),
92  start_paren_labels(start_paren_labels),
93  left_paren_prefix(std::move(left_paren_prefix)),
94  right_paren_prefix(std::move(right_paren_prefix)) {}
95 
99  const std::string left_paren_prefix;
100  const std::string right_paren_prefix;
101 };
102 
103 // PdtParser: Base PDT parser class common to specific parsers.
104 
105 template <class Arc>
106 class PdtParser {
107  public:
108  using Label = typename Arc::Label;
109  using StateId = typename Arc::StateId;
110  using Weight = typename Arc::Weight;
111  using LabelFstPair = std::pair<Label, const Fst<Arc> *>;
112  using LabelPair = std::pair<Label, Label>;
113  using LabelStatePair = std::pair<Label, StateId>;
114  using StateWeightPair = std::pair<StateId, Weight>;
115  using ParenKey = std::pair<size_t, StateId>;
116  using ParenMap = std::unordered_map<ParenKey, size_t,
118 
119  PdtParser(const std::vector<LabelFstPair> &fst_array,
120  const PdtReplaceOptions<Arc> &opts)
121  : root_(opts.root),
122  start_paren_labels_(opts.start_paren_labels),
123  left_paren_prefix_(std::move(opts.left_paren_prefix)),
124  right_paren_prefix_(std::move(opts.right_paren_prefix)),
125  error_(false) {
126  for (size_t i = 0; i < fst_array.size(); ++i) {
127  if (!CompatSymbols(fst_array[0].second->InputSymbols(),
128  fst_array[i].second->InputSymbols())) {
129  FSTERROR() << "PdtParser: Input symbol table of input FST " << i
130  << " does not match input symbol table of 0th input FST";
131  error_ = true;
132  }
133  if (!CompatSymbols(fst_array[0].second->OutputSymbols(),
134  fst_array[i].second->OutputSymbols())) {
135  FSTERROR() << "PdtParser: Output symbol table of input FST " << i
136  << " does not match output symbol table of 0th input FST";
137  error_ = true;
138  }
139  fst_array_.emplace_back(fst_array[i].first, fst_array[i].second->Copy());
140  // Builds map from non-terminal label to FST ID.
141  label2id_[fst_array[i].first] = i;
142  }
143  }
144 
145  virtual ~PdtParser() {
146  for (auto &pair : fst_array_) delete pair.second;
147  }
148 
149  // Constructs the output PDT, dependent on the derived parser type.
150  virtual void GetParser(MutableFst<Arc> *ofst,
151  std::vector<LabelPair> *parens) = 0;
152 
153  protected:
154  const std::vector<LabelFstPair> &FstArray() const { return fst_array_; }
155 
156  Label Root() const { return root_; }
157 
158  // Maps from non-terminal label to corresponding FST ID, or returns
159  // kNoStateId to signal lookup failure.
160  StateId Label2Id(Label l) const {
161  auto it = label2id_.find(l);
162  return it == label2id_.end() ? kNoStateId : it->second;
163  }
164 
165  // Maps from output state to input FST label, state pair, or returns a
166  // (kNoLabel, kNoStateId) pair to signal lookup failure.
168  if (os >= label_state_pairs_.size()) {
169  static const LabelStatePair no_pair(kNoLabel, kNoLabel);
170  return no_pair;
171  } else {
172  return label_state_pairs_[os];
173  }
174  }
175 
176  // Maps to output state from input FST (label, state) pair, or returns
177  // kNoStateId to signal lookup failure.
178  StateId GetState(const LabelStatePair &lsp) const {
179  auto it = state_map_.find(lsp);
180  if (it == state_map_.end()) {
181  return kNoStateId;
182  } else {
183  return it->second;
184  }
185  }
186 
187  // Builds a single FST combining all referenced input FSTs, leaving in the
188  // non-terminals for now; also tabulates the PDT states that correspond to the
189  // start and final states of the input FSTs (in open_dest and close_src).
190  void CreateFst(MutableFst<Arc> *ofst, std::vector<StateId> *open_dest,
191  std::vector<std::vector<StateWeightPair>> *close_src);
192 
193  // Assigns parenthesis labels from total allocated paren IDs.
194  void AssignParenLabels(size_t total_nparens, std::vector<LabelPair> *parens) {
195  parens->clear();
196  for (size_t paren_id = 0; paren_id < total_nparens; ++paren_id) {
197  const auto open_paren = start_paren_labels_ + paren_id;
198  const auto close_paren = open_paren + total_nparens;
199  parens->emplace_back(open_paren, close_paren);
200  }
201  }
202 
203  // Determines how non-terminal instances are assigned parentheses IDs.
204  virtual size_t AssignParenIds(const Fst<Arc> &ofst,
205  ParenMap *paren_map) const = 0;
206 
207  // Changes a non-terminal transition to an open parenthesis transition
208  // redirected to the PDT state specified in the open_dest argument, when
209  // indexed by the input FST ID for the non-terminal. Adds close parenthesis
210  // transitions (with specified weights) from the PDT states specified in the
211  // close_src argument, when indexed by the input FST ID for the non-terminal,
212  // to the former destination state of the non-terminal transition. The
213  // paren_map argument gives the parenthesis ID for a given non-terminal FST ID
214  // and destination state pair. The close_non_term_weight vector specifies
215  // non-terminals for which the non-terminal arc weight should be applied on
216  // the close parenthesis (multiplying the close_src weight above) rather than
217  // on the open parenthesis. If no paren ID is found, then an epsilon replaces
218  // the parenthesis that would carry the non-terminal arc weight and the other
219  // parenthesis is omitted (appropriate for the strongly-regular case).
220  void AddParensToFst(
221  const std::vector<LabelPair> &parens, const ParenMap &paren_map,
222  const std::vector<StateId> &open_dest,
223  const std::vector<std::vector<StateWeightPair>> &close_src,
224  const std::vector<bool> &close_non_term_weight, MutableFst<Arc> *ofst);
225 
226  // Ensures that parentheses arcs are added to the symbol table.
227  void AddParensToSymbolTables(const std::vector<LabelPair> &parens,
228  MutableFst<Arc> *ofst);
229 
230  private:
231  std::vector<LabelFstPair> fst_array_;
232  Label root_;
233  // Index to use for the first parenthesis.
234  Label start_paren_labels_;
235  const std::string left_paren_prefix_;
236  const std::string right_paren_prefix_;
237  // Maps from non-terminal label to FST ID.
238  std::unordered_map<Label, StateId> label2id_;
239  // Given an output state, specifies the input FST (label, state) pair.
240  std::vector<LabelStatePair> label_state_pairs_;
241  // Given an FST (label, state) pair, specifies the output FST state ID.
242  std::map<LabelStatePair, StateId> state_map_;
243  bool error_;
244 };
245 
246 template <class Arc>
248  MutableFst<Arc> *ofst, std::vector<StateId> *open_dest,
249  std::vector<std::vector<StateWeightPair>> *close_src) {
250  ofst->DeleteStates();
251  if (error_) {
252  ofst->SetProperties(kError, kError);
253  return;
254  }
255  open_dest->resize(fst_array_.size(), kNoStateId);
256  close_src->resize(fst_array_.size());
257  // Queue of non-terminals to replace.
258  std::deque<Label> non_term_queue;
259  non_term_queue.push_back(root_);
260  // Has a non-terminal been enqueued?
261  std::vector<bool> enqueued(fst_array_.size(), false);
262  enqueued[label2id_[root_]] = true;
263  Label max_label = kNoLabel;
264  for (StateId soff = 0; !non_term_queue.empty(); soff = ofst->NumStates()) {
265  const auto label = non_term_queue.front();
266  non_term_queue.pop_front();
267  StateId fst_id = Label2Id(label);
268  const auto *ifst = fst_array_[fst_id].second;
269  for (StateIterator<Fst<Arc>> siter(*ifst); !siter.Done(); siter.Next()) {
270  const auto is = siter.Value();
271  const auto os = ofst->AddState();
272  const LabelStatePair lsp(label, is);
273  label_state_pairs_.push_back(lsp);
274  state_map_[lsp] = os;
275  if (is == ifst->Start()) {
276  (*open_dest)[fst_id] = os;
277  if (label == root_) ofst->SetStart(os);
278  }
279  if (ifst->Final(is) != Weight::Zero()) {
280  if (label == root_) ofst->SetFinal(os, ifst->Final(is));
281  (*close_src)[fst_id].emplace_back(os, ifst->Final(is));
282  }
283  for (ArcIterator<Fst<Arc>> aiter(*ifst, is); !aiter.Done();
284  aiter.Next()) {
285  auto arc = aiter.Value();
286  arc.nextstate += soff;
287  if (max_label == kNoLabel || arc.olabel > max_label)
288  max_label = arc.olabel;
289  const auto nfst_id = Label2Id(arc.olabel);
290  if (nfst_id != kNoStateId) {
291  if (fst_array_[nfst_id].second->Start() == kNoStateId) continue;
292  if (!enqueued[nfst_id]) {
293  non_term_queue.push_back(arc.olabel);
294  enqueued[nfst_id] = true;
295  }
296  }
297  ofst->AddArc(os, arc);
298  }
299  }
300  }
301  if (start_paren_labels_ == kNoLabel) start_paren_labels_ = max_label + 1;
302 }
303 
304 template <class Arc>
306  const std::vector<LabelPair> &parens, const ParenMap &paren_map,
307  const std::vector<StateId> &open_dest,
308  const std::vector<std::vector<StateWeightPair>> &close_src,
309  const std::vector<bool> &close_non_term_weight, MutableFst<Arc> *ofst) {
310  StateId dead_state = kNoStateId;
311  using MIter = MutableArcIterator<MutableFst<Arc>>;
312  for (StateIterator<Fst<Arc>> siter(*ofst); !siter.Done(); siter.Next()) {
313  StateId os = siter.Value();
314  std::unique_ptr<MIter> aiter(new MIter(ofst, os));
315  for (auto n = 0; !aiter->Done(); aiter->Next(), ++n) {
316  const auto arc = aiter->Value(); // A reference here may go stale.
317  StateId nfst_id = Label2Id(arc.olabel);
318  if (nfst_id != kNoStateId) {
319  // Gets parentheses.
320  const ParenKey paren_key(nfst_id, arc.nextstate);
321  auto it = paren_map.find(paren_key);
322  Label open_paren = 0;
323  Label close_paren = 0;
324  if (it != paren_map.end()) {
325  const auto paren_id = it->second;
326  open_paren = parens[paren_id].first;
327  close_paren = parens[paren_id].second;
328  }
329  // Sets open parenthesis.
330  if (open_paren != 0 || !close_non_term_weight[nfst_id]) {
331  const auto open_weight =
332  close_non_term_weight[nfst_id] ? Weight::One() : arc.weight;
333  const Arc sarc(open_paren, open_paren, open_weight,
334  open_dest[nfst_id]);
335  aiter->SetValue(sarc);
336  } else {
337  if (dead_state == kNoStateId) {
338  dead_state = ofst->AddState();
339  }
340  const Arc sarc(0, 0, Weight::One(), dead_state);
341  aiter->SetValue(sarc);
342  }
343  // Adds close parentheses.
344  if (close_paren != 0 || close_non_term_weight[nfst_id]) {
345  for (size_t i = 0; i < close_src[nfst_id].size(); ++i) {
346  const auto &pair = close_src[nfst_id][i];
347  const auto close_weight = close_non_term_weight[nfst_id]
348  ? Times(arc.weight, pair.second)
349  : pair.second;
350  const Arc farc(close_paren, close_paren, close_weight,
351  arc.nextstate);
352 
353  ofst->AddArc(pair.first, farc);
354  if (os == pair.first) { // Invalidated iterator.
355  aiter.reset(new MIter(ofst, os));
356  aiter->Seek(n);
357  }
358  }
359  }
360  }
361  }
362  }
363 }
364 
365 template <class Arc>
367  const std::vector<LabelPair> &parens, MutableFst<Arc> *ofst) {
368  auto size = parens.size();
369  if (ofst->InputSymbols()) {
370  if (!AddAuxiliarySymbols(left_paren_prefix_, start_paren_labels_, size,
371  ofst->MutableInputSymbols())) {
372  ofst->SetProperties(kError, kError);
373  return;
374  }
375  if (!AddAuxiliarySymbols(right_paren_prefix_, start_paren_labels_ + size,
376  size, ofst->MutableInputSymbols())) {
377  ofst->SetProperties(kError, kError);
378  return;
379  }
380  }
381  if (ofst->OutputSymbols()) {
382  if (!AddAuxiliarySymbols(left_paren_prefix_, start_paren_labels_, size,
383  ofst->MutableOutputSymbols())) {
384  ofst->SetProperties(kError, kError);
385  return;
386  }
387  if (!AddAuxiliarySymbols(right_paren_prefix_, start_paren_labels_ + size,
388  size, ofst->MutableOutputSymbols())) {
389  ofst->SetProperties(kError, kError);
390  return;
391  }
392  }
393 }
394 
395 // Builds a PDT by recursive replacement top-down, where the call and return are
396 // encoded in the parentheses.
397 template <class Arc>
398 class PdtLeftParser final : public PdtParser<Arc> {
399  public:
400  using Label = typename Arc::Label;
401  using StateId = typename Arc::StateId;
402  using Weight = typename Arc::Weight;
409 
418  using PdtParser<Arc>::Root;
419 
420  PdtLeftParser(const std::vector<LabelFstPair> &fst_array,
421  const PdtReplaceOptions<Arc> &opts)
422  : PdtParser<Arc>(fst_array, opts) {}
423 
424  void GetParser(MutableFst<Arc> *ofst,
425  std::vector<LabelPair> *parens) override;
426 
427  protected:
428  // Assigns a unique parenthesis ID for each non-terminal, destination
429  // state pair.
430  size_t AssignParenIds(const Fst<Arc> &ofst,
431  ParenMap *paren_map) const override;
432 };
433 
434 template <class Arc>
436  std::vector<LabelPair> *parens) {
437  ofst->DeleteStates();
438  parens->clear();
439  const auto &fst_array = FstArray();
440  // Map that gives the paren ID for a (non-terminal, dest. state) pair
441  // (which can be unique).
442  ParenMap paren_map;
443  // Specifies the open parenthesis destination state for a given non-terminal.
444  // The source is the non-terminal instance source state.
445  std::vector<StateId> open_dest(fst_array.size(), kNoStateId);
446  // Specifies close parenthesis source states and weights for a given
447  // non-terminal. The destination is the non-terminal instance destination
448  // state.
449  std::vector<std::vector<StateWeightPair>> close_src(fst_array.size());
450  // Specifies non-terminals for which the non-terminal arc weight
451  // should be applied on the close parenthesis (multiplying the
452  // 'close_src' weight above) rather than on the open parenthesis.
453  std::vector<bool> close_non_term_weight(fst_array.size(), false);
454  CreateFst(ofst, &open_dest, &close_src);
455  auto total_nparens = AssignParenIds(*ofst, &paren_map);
456  AssignParenLabels(total_nparens, parens);
457  AddParensToFst(*parens, paren_map, open_dest, close_src,
458  close_non_term_weight, ofst);
459  if (!fst_array.empty()) {
460  ofst->SetInputSymbols(fst_array[0].second->InputSymbols());
461  ofst->SetOutputSymbols(fst_array[0].second->OutputSymbols());
462  }
463  AddParensToSymbolTables(*parens, ofst);
464 }
465 
466 template <class Arc>
468  ParenMap *paren_map) const {
469  // Number of distinct parenthesis pairs per FST.
470  std::vector<size_t> nparens(FstArray().size(), 0);
471  // Number of distinct parenthesis pairs overall.
472  size_t total_nparens = 0;
473  for (StateIterator<Fst<Arc>> siter(ofst); !siter.Done(); siter.Next()) {
474  const auto os = siter.Value();
475  for (ArcIterator<Fst<Arc>> aiter(ofst, os); !aiter.Done(); aiter.Next()) {
476  const auto &arc = aiter.Value();
477  const auto nfst_id = Label2Id(arc.olabel);
478  if (nfst_id != kNoStateId) {
479  const ParenKey paren_key(nfst_id, arc.nextstate);
480  auto it = paren_map->find(paren_key);
481  if (it == paren_map->end()) {
482  // Assigns new paren ID for this (FST, dest state) pair.
483  (*paren_map)[paren_key] = nparens[nfst_id]++;
484  if (nparens[nfst_id] > total_nparens)
485  total_nparens = nparens[nfst_id];
486  }
487  }
488  }
489  }
490  return total_nparens;
491 }
492 
493 // Similar to PdtLeftParser but:
494 //
495 // 1. Uses epsilons rather than parentheses labels for any non-terminal
496 // instances within a left- (right-) linear dependency SCC,
497 // 2. Allocates a paren ID uniquely for each such dependency SCC (rather than
498 // non-terminal = dependency state) and destination state.
499 template <class Arc>
500 class PdtLeftSRParser final : public PdtParser<Arc> {
501  public:
502  using Label = typename Arc::Label;
503  using StateId = typename Arc::StateId;
504  using Weight = typename Arc::Weight;
511 
520  using PdtParser<Arc>::Root;
521 
522  PdtLeftSRParser(const std::vector<LabelFstPair> &fst_array,
523  const PdtReplaceOptions<Arc> &opts)
524  : PdtParser<Arc>(fst_array, opts),
525  replace_util_(fst_array, ReplaceUtilOptions(opts.root)) {}
526 
527  void GetParser(MutableFst<Arc> *ofst,
528  std::vector<LabelPair> *parens) override;
529 
530  protected:
531  // Assigns a unique parenthesis ID for each non-terminal, destination state
532  // pair when the non-terminal refers to a non-linear FST. Otherwise, assigns
533  // a unique parenthesis ID for each dependency SCC, destination state pair if
534  // the non-terminal instance is between
535  // SCCs. Otherwise does nothing.
536  size_t AssignParenIds(const Fst<Arc> &ofst,
537  ParenMap *paren_map) const override;
538 
539  // Returns dependency SCC for given label.
540  size_t SCC(Label label) const { return replace_util_.SCC(label); }
541 
542  // Is a given dependency SCC left-linear?
543  bool SCCLeftLinear(size_t scc_id) const {
544  const auto ll_props = kReplaceSCCLeftLinear | kReplaceSCCNonTrivial;
545  const auto scc_props = replace_util_.SCCProperties(scc_id);
546  return (scc_props & ll_props) == ll_props;
547  }
548 
549  // Is a given dependency SCC right-linear?
550  bool SCCRightLinear(size_t scc_id) const {
551  const auto lr_props = kReplaceSCCRightLinear | kReplaceSCCNonTrivial;
552  const auto scc_props = replace_util_.SCCProperties(scc_id);
553  return (scc_props & lr_props) == lr_props;
554  }
555 
556  // Components of left- (right-) linear dependency SCC; empty o.w.
557  const std::vector<size_t> &SCCComps(size_t scc_id) const {
558  if (scc_comps_.empty()) GetSCCComps();
559  return scc_comps_[scc_id];
560  }
561 
562  // Returns the representative state of an SCC. For left-linear grammars, it
563  // is one of the initial states. For right-linear grammars, it is one of the
564  // non-terminal destination states; otherwise, it is kNoStateId.
565  StateId RepState(size_t scc_id) const {
566  if (SCCComps(scc_id).empty()) return kNoStateId;
567  const auto fst_id = SCCComps(scc_id).front();
568  const auto &fst_array = FstArray();
569  const auto label = fst_array[fst_id].first;
570  const auto *ifst = fst_array[fst_id].second;
571  if (SCCLeftLinear(scc_id)) {
572  const LabelStatePair lsp(label, ifst->Start());
573  return GetState(lsp);
574  } else { // Right-linear.
575  const LabelStatePair lsp(label, *NonTermDests(fst_id).begin());
576  return GetState(lsp);
577  }
578  return kNoStateId;
579  }
580 
581  private:
582  // Merges initial (final) states of a left- (right-) linear dependency SCC
583  // after dealing with the non-terminal arc and final weights.
584  void ProcSCCs(MutableFst<Arc> *ofst, std::vector<StateId> *open_dest,
585  std::vector<std::vector<StateWeightPair>> *close_src,
586  std::vector<bool> *close_non_term_weight) const;
587 
588  // Computes components of left- (right-) linear dependency SCC.
589  void GetSCCComps() const {
590  const std::vector<LabelFstPair> &fst_array = FstArray();
591  for (size_t i = 0; i < fst_array.size(); ++i) {
592  const auto label = fst_array[i].first;
593  const auto scc_id = SCC(label);
594  if (scc_comps_.size() <= scc_id) scc_comps_.resize(scc_id + 1);
595  if (SCCLeftLinear(scc_id) || SCCRightLinear(scc_id)) {
596  scc_comps_[scc_id].push_back(i);
597  }
598  }
599  }
600 
601  const std::set<StateId> &NonTermDests(StateId fst_id) const {
602  if (non_term_dests_.empty()) GetNonTermDests();
603  return non_term_dests_[fst_id];
604  }
605 
606  // Finds non-terminal destination states for FSTs in right-linear SCCs;
607  // the sets for all other FSTs are left empty.
608  void GetNonTermDests() const;
609 
610  // Dependency SCC info.
611  mutable ReplaceUtil<Arc> replace_util_;
612  // Components of left- (right-) linear dependency SCCs, or empty otherwise.
613  mutable std::vector<std::vector<size_t>> scc_comps_;
614  // States that have non-terminals entering them for each (right-linear) FST.
615  mutable std::vector<std::set<StateId>> non_term_dests_;
616 };
617 
618 template <class Arc>
620  std::vector<LabelPair> *parens) {
621  ofst->DeleteStates();
622  parens->clear();
623  const auto &fst_array = FstArray();
624  // Map that gives the paren ID for a (non-terminal, dest. state) pair.
625  ParenMap paren_map;
626  // Specifies the open parenthesis destination state for a given non-terminal.
627  // The source is the non-terminal instance source state.
628  std::vector<StateId> open_dest(fst_array.size(), kNoStateId);
629  // Specifies close parenthesis source states and weights for a given
630  // non-terminal. The destination is the non-terminal instance destination
631  // state.
632  std::vector<std::vector<StateWeightPair>> close_src(fst_array.size());
633  // Specifies non-terminals for which the non-terminal arc weight should be
634  // applied on the close parenthesis (multiplying the close_src weight above)
635  // rather than on the open parenthesis.
636  std::vector<bool> close_non_term_weight(fst_array.size(), false);
637  CreateFst(ofst, &open_dest, &close_src);
638  ProcSCCs(ofst, &open_dest, &close_src, &close_non_term_weight);
639  const auto total_nparens = AssignParenIds(*ofst, &paren_map);
640  AssignParenLabels(total_nparens, parens);
641  AddParensToFst(*parens, paren_map, open_dest, close_src,
642  close_non_term_weight, ofst);
643  if (!fst_array.empty()) {
644  ofst->SetInputSymbols(fst_array[0].second->InputSymbols());
645  ofst->SetOutputSymbols(fst_array[0].second->OutputSymbols());
646  }
647  AddParensToSymbolTables(*parens, ofst);
648  Connect(ofst);
649 }
650 
651 template <class Arc>
653  MutableFst<Arc> *ofst, std::vector<StateId> *open_dest,
654  std::vector<std::vector<StateWeightPair>> *close_src,
655  std::vector<bool> *close_non_term_weight) const {
656  const auto &fst_array = FstArray();
657  for (StateIterator<Fst<Arc>> siter(*ofst); !siter.Done(); siter.Next()) {
658  const auto os = siter.Value();
659  const auto label = GetLabelStatePair(os).first;
660  const auto is = GetLabelStatePair(os).second;
661  const auto fst_id = Label2Id(label);
662  const auto scc_id = SCC(label);
663  const auto rs = RepState(scc_id);
664  const auto *ifst = fst_array[fst_id].second;
665  // SCC LEFT-LINEAR: puts non-terminal weights on close parentheses. Merges
666  // initial states into SCC representative state and updates open_dest.
667  if (SCCLeftLinear(scc_id)) {
668  (*close_non_term_weight)[fst_id] = true;
669  if (is == ifst->Start() && os != rs) {
670  for (ArcIterator<Fst<Arc>> aiter(*ofst, os); !aiter.Done();
671  aiter.Next()) {
672  const auto &arc = aiter.Value();
673  ofst->AddArc(rs, arc);
674  }
675  ofst->DeleteArcs(os);
676  if (os == ofst->Start()) ofst->SetStart(rs);
677  (*open_dest)[fst_id] = rs;
678  }
679  }
680  // SCC RIGHT-LINEAR: pushes back final weights onto non-terminals, if
681  // possible, or adds weighted epsilons to the SCC representative state.
682  // Merges final states into SCC representative state and updates close_src.
683  if (SCCRightLinear(scc_id)) {
684  for (MutableArcIterator<MutableFst<Arc>> aiter(ofst, os); !aiter.Done();
685  aiter.Next()) {
686  auto arc = aiter.Value();
687  const auto idest = GetLabelStatePair(arc.nextstate).second;
688  if (NonTermDests(fst_id).count(idest) > 0) {
689  if (ofst->Final(arc.nextstate) != Weight::Zero()) {
690  ofst->SetFinal(arc.nextstate, Weight::Zero());
691  ofst->SetFinal(rs, Weight::One());
692  }
693  arc.weight = Times(arc.weight, ifst->Final(idest));
694  arc.nextstate = rs;
695  aiter.SetValue(arc);
696  }
697  }
698  const auto final_weight = ifst->Final(is);
699  if (final_weight != Weight::Zero() &&
700  NonTermDests(fst_id).count(is) == 0) {
701  ofst->AddArc(os, Arc(0, 0, final_weight, rs));
702  if (ofst->Final(os) != Weight::Zero()) {
703  ofst->SetFinal(os, Weight::Zero());
704  ofst->SetFinal(rs, Weight::One());
705  }
706  }
707  if (is == ifst->Start()) {
708  (*close_src)[fst_id].clear();
709  (*close_src)[fst_id].emplace_back(rs, Weight::One());
710  }
711  }
712  }
713 }
714 
715 template <class Arc>
717  const auto &fst_array = FstArray();
718  non_term_dests_.resize(fst_array.size());
719  for (size_t fst_id = 0; fst_id < fst_array.size(); ++fst_id) {
720  const auto label = fst_array[fst_id].first;
721  const auto scc_id = SCC(label);
722  if (SCCRightLinear(scc_id)) {
723  const auto *ifst = fst_array[fst_id].second;
724  for (StateIterator<Fst<Arc>> siter(*ifst); !siter.Done(); siter.Next()) {
725  const auto is = siter.Value();
726  for (ArcIterator<Fst<Arc>> aiter(*ifst, is); !aiter.Done();
727  aiter.Next()) {
728  const auto &arc = aiter.Value();
729  if (Label2Id(arc.olabel) != kNoStateId) {
730  non_term_dests_[fst_id].insert(arc.nextstate);
731  }
732  }
733  }
734  }
735  }
736 }
737 
738 template <class Arc>
740  ParenMap *paren_map) const {
741  const auto &fst_array = FstArray();
742  // Number of distinct parenthesis pairs per FST.
743  std::vector<size_t> nparens(fst_array.size(), 0);
744  // Number of distinct parenthesis pairs overall.
745  size_t total_nparens = 0;
746  for (StateIterator<Fst<Arc>> siter(ofst); !siter.Done(); siter.Next()) {
747  const auto os = siter.Value();
748  const auto label = GetLabelStatePair(os).first;
749  const auto scc_id = SCC(label);
750  for (ArcIterator<Fst<Arc>> aiter(ofst, os); !aiter.Done(); aiter.Next()) {
751  const auto &arc = aiter.Value();
752  const auto nfst_id = Label2Id(arc.olabel);
753  if (nfst_id != kNoStateId) {
754  size_t nscc_id = SCC(arc.olabel);
755  bool nscc_linear = !SCCComps(nscc_id).empty();
756  // Assigns a parenthesis ID for the non-terminal transition
757  // if the non-terminal belongs to a (left-/right-) linear dependency
758  // SCC or if the transition is in an FST from a different SCC
759  if (!nscc_linear || scc_id != nscc_id) {
760  // For (left-/right-) linear SCCs instead of using nfst_id, we
761  // will use its SCC prototype pfst_id for assigning distinct
762  // parenthesis IDs.
763  const auto pfst_id =
764  nscc_linear ? SCCComps(nscc_id).front() : nfst_id;
765  ParenKey paren_key(pfst_id, arc.nextstate);
766  const auto it = paren_map->find(paren_key);
767  if (it == paren_map->end()) {
768  // Assigns new paren ID for this (FST/SCC, dest. state) pair.
769  if (nscc_linear) {
770  // This is mapping we'll need, but we also store (harmlessly)
771  // for the prototype below so we can easily keep count per SCC.
772  const ParenKey nparen_key(nfst_id, arc.nextstate);
773  (*paren_map)[nparen_key] = nparens[pfst_id];
774  }
775  (*paren_map)[paren_key] = nparens[pfst_id]++;
776  if (nparens[pfst_id] > total_nparens) {
777  total_nparens = nparens[pfst_id];
778  }
779  }
780  }
781  }
782  }
783  }
784  return total_nparens;
785 }
786 
787 // Builds a pushdown transducer (PDT) from an RTN specification. The result is
788 // a PDT written to a mutable FST where some transitions are labeled with
789 // open or close parentheses. To be interpreted as a PDT, the parens must
790 // balance on a path (see PdtExpand()). The open/close parenthesis label pairs
791 // are returned in the parens argument.
792 template <class Arc>
793 void Replace(
794  const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
795  &ifst_array,
796  MutableFst<Arc> *ofst,
797  std::vector<std::pair<typename Arc::Label, typename Arc::Label>> *parens,
798  const PdtReplaceOptions<Arc> &opts) {
799  switch (opts.type) {
800  case PdtParserType::LEFT: {
801  PdtLeftParser<Arc> pr(ifst_array, opts);
802  pr.GetParser(ofst, parens);
803  return;
804  }
805  case PdtParserType::LEFT_SR: {
806  PdtLeftSRParser<Arc> pr(ifst_array, opts);
807  pr.GetParser(ofst, parens);
808  return;
809  }
810  default:
811  FSTERROR() << "Replace: Unknown PDT parser type: "
812  << static_cast<std::underlying_type_t<PdtParserType>>(
813  opts.type);
814  ofst->DeleteStates();
815  ofst->SetProperties(kError, kError);
816  parens->clear();
817  return;
818  }
819 }
820 
821 // Variant where the only user-controlled arguments is the root ID.
822 template <class Arc>
823 void Replace(
824  const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
825  &ifst_array,
826  MutableFst<Arc> *ofst,
827  std::vector<std::pair<typename Arc::Label, typename Arc::Label>> *parens,
828  typename Arc::Label root) {
829  PdtReplaceOptions<Arc> opts(root);
830  Replace(ifst_array, ofst, parens, opts);
831 }
832 
833 } // namespace fst
834 
835 #endif // FST_EXTENSIONS_PDT_REPLACE_H_
virtual ~PdtParser()
Definition: replace.h:145
const std::string right_paren_prefix
Definition: replace.h:100
void GetParser(MutableFst< Arc > *ofst, std::vector< LabelPair > *parens) override
Definition: replace.h:619
std::unordered_map< ParenKey, size_t, internal::ReplaceParenHash< StateId >> ParenMap
Definition: replace.h:117
constexpr int kNoLabel
Definition: fst.h:201
virtual SymbolTable * MutableOutputSymbols()=0
PdtLeftSRParser(const std::vector< LabelFstPair > &fst_array, const PdtReplaceOptions< Arc > &opts)
Definition: replace.h:522
size_t AssignParenIds(const Fst< Arc > &ofst, ParenMap *paren_map) const override
Definition: replace.h:467
StateId RepState(size_t scc_id) const
Definition: replace.h:565
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:63
const SymbolTable * InputSymbols() const override=0
constexpr uint64_t kError
Definition: properties.h:51
virtual void SetInputSymbols(const SymbolTable *isyms)=0
StateId Label2Id(Label l) const
Definition: replace.h:160
Label Root() const
Definition: replace.h:156
size_t AssignParenIds(const Fst< Arc > &ofst, ParenMap *paren_map) const override
Definition: replace.h:739
virtual Weight Final(StateId) const =0
PdtParserType
Definition: replace.h:60
virtual void SetStart(StateId)=0
bool SCCLeftLinear(size_t scc_id) const
Definition: replace.h:543
void Connect(MutableFst< Arc > *fst)
Definition: connect.h:278
constexpr uint8_t kReplaceSCCLeftLinear
Definition: replace-util.h:82
void Replace(const std::vector< std::pair< typename Arc::Label, const Fst< Arc > * >> &ifst_array, MutableFst< Arc > *ofst, std::vector< std::pair< typename Arc::Label, typename Arc::Label >> *parens, const PdtReplaceOptions< Arc > &opts)
Definition: replace.h:793
size_t operator()(const std::pair< size_t, S > &paren) const
Definition: replace.h:43
const std::string left_paren_prefix
Definition: replace.h:99
void AssignParenLabels(size_t total_nparens, std::vector< LabelPair > *parens)
Definition: replace.h:194
constexpr int kNoStateId
Definition: fst.h:202
std::pair< Label, StateId > LabelStatePair
Definition: replace.h:113
bool SCCRightLinear(size_t scc_id) const
Definition: replace.h:550
const Arc & Value() const
Definition: mutable-fst.h:231
typename Arc::StateId StateId
Definition: replace.h:109
PdtReplaceOptions(Label root, PdtParserType type=PdtParserType::LEFT, Label start_paren_labels=kNoLabel, std::string left_paren_prefix="(_", std::string right_paren_prefix=")_")
Definition: replace.h:85
typename Arc::Label Label
Definition: replace.h:108
void CreateFst(MutableFst< Arc > *ofst, std::vector< StateId > *open_dest, std::vector< std::vector< StateWeightPair >> *close_src)
Definition: replace.h:247
#define FSTERROR()
Definition: util.h:53
const SymbolTable * OutputSymbols() const override=0
std::pair< StateId, Weight > StateWeightPair
Definition: replace.h:114
constexpr uint8_t kReplaceSCCNonTrivial
Definition: replace-util.h:89
const std::vector< size_t > & SCCComps(size_t scc_id) const
Definition: replace.h:557
PdtParserType type
Definition: replace.h:97
virtual void SetProperties(uint64_t props, uint64_t mask)=0
void GetParser(MutableFst< Arc > *ofst, std::vector< LabelPair > *parens) override
Definition: replace.h:435
virtual void DeleteArcs(StateId, size_t)=0
virtual StateId Start() const =0
StateId GetState(const LabelStatePair &lsp) const
Definition: replace.h:178
PdtParser(const std::vector< LabelFstPair > &fst_array, const PdtReplaceOptions< Arc > &opts)
Definition: replace.h:119
std::pair< size_t, StateId > ParenKey
Definition: replace.h:115
const std::vector< LabelFstPair > & FstArray() const
Definition: replace.h:154
std::pair< Label, Label > LabelPair
Definition: replace.h:112
void AddParensToSymbolTables(const std::vector< LabelPair > &parens, MutableFst< Arc > *ofst)
Definition: replace.h:366
typename Arc::Label Label
Definition: replace.h:83
virtual SymbolTable * MutableInputSymbols()=0
virtual void AddArc(StateId, const Arc &)=0
constexpr uint8_t kReplaceSCCRightLinear
Definition: replace-util.h:86
Label start_paren_labels
Definition: replace.h:98
virtual StateId AddState()=0
std::pair< Label, const Fst< Arc > * > LabelFstPair
Definition: replace.h:111
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
void AddParensToFst(const std::vector< LabelPair > &parens, const ParenMap &paren_map, const std::vector< StateId > &open_dest, const std::vector< std::vector< StateWeightPair >> &close_src, const std::vector< bool > &close_non_term_weight, MutableFst< Arc > *ofst)
Definition: replace.h:305
virtual void DeleteStates(const std::vector< StateId > &)=0
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
typename Arc::Weight Weight
Definition: replace.h:110
virtual StateId NumStates() const =0
typename Arc::StateId StateId
Definition: replace.h:503
LabelStatePair GetLabelStatePair(StateId os) const
Definition: replace.h:167
PdtLeftParser(const std::vector< LabelFstPair > &fst_array, const PdtReplaceOptions< Arc > &opts)
Definition: replace.h:420
bool AddAuxiliarySymbols(const std::string &prefix, int64_t start_label, int64_t nlabels, SymbolTable *syms)
size_t SCC(Label label) const
Definition: replace.h:540