FST  openfst-1.8.3
OpenFst Library
replace.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Recursively replaces FST arcs with other FSTs, returning a PDT.
19 
20 #ifndef FST_EXTENSIONS_PDT_REPLACE_H_
21 #define FST_EXTENSIONS_PDT_REPLACE_H_
22 
23 #include <cstddef>
24 #include <cstdint>
25 #include <deque>
26 #include <map>
27 #include <memory>
28 #include <set>
29 #include <string>
30 #include <type_traits>
31 #include <utility>
32 #include <vector>
33 
34 #include <fst/log.h>
35 #include <fst/connect.h>
36 #include <fst/fst.h>
37 #include <fst/mutable-fst.h>
38 #include <fst/properties.h>
39 #include <fst/replace-util.h>
40 #include <fst/replace.h>
41 #include <fst/symbol-table-ops.h>
42 #include <fst/symbol-table.h>
43 #include <fst/util.h>
44 #include <unordered_map>
45 
46 namespace fst {
47 namespace internal {
48 
// Hash functor for parenthesis keys, i.e., (ID, state) pairs, suitable for
// use as the hasher of a std::unordered_map.
template <typename S>
struct ReplaceParenHash {
  // Combines both components of the key. The state component is widened to
  // size_t before scaling so that large state IDs wrap around (well-defined
  // unsigned arithmetic) instead of overflowing a signed integer type, which
  // would be undefined behavior. For inputs that did not overflow, the result
  // is unchanged.
  size_t operator()(const std::pair<size_t, S> &paren) const {
    static constexpr size_t kPrime = 7853;
    return paren.first + static_cast<size_t>(paren.second) * kPrime;
  }
};
57 
58 } // namespace internal
59 
60 // Parser types characterize the PDT construction method. When applied to a CFG,
61 // each non-terminal is encoded as a DFA that accepts precisely the RHS's of
62 // productions of that non-terminal. For parsing (rather than just recognition),
63 // production numbers can be used as outputs (placed as early as possible) in the
64 // DFAs promoted to DFTs. For more information on the strongly regular
65 // construction, see:
66 //
67 // Mohri, M., and Pereira, F. 1998. Dynamic compilation of weighted context-free
68 // grammars. In Proc. ACL, pages 891-897.
// Selects the PDT construction method. Replace() switches on this value:
// LEFT builds the result with PdtLeftParser and LEFT_SR with PdtLeftSRParser;
// any other value is reported as an error.
69 enum class PdtParserType : uint8_t {
70  // Top-down construction. Applied to a simple LL(1) grammar (among others),
71  // gives a DPDA. If promoted to a DPDT, with outputs being production
72  // numbers, gives a leftmost derivation. Left recursive grammars are
73  // problematic in use.
74  LEFT,
75 
76  // Top-down construction. Similar to LEFT except bounded-stack
77  // (expandable as an FST) result with regular or, more generally, strongly
78  // regular grammars. Epsilons may replace some parentheses, which may
79  // introduce some non-determinism.
80  LEFT_SR,
81 
      // The bottom-up variant below is not implemented; it is kept commented out.
82  /* TODO(riley):
83  // Bottom-up construction. Applied to a LR(0) grammar, gives a DPDA.
84  // If promoted to a DPDT, with outputs being the production numbers,
85  // gives the reverse of a rightmost derivation.
86  RIGHT,
87  */
88 };
89 
90 template <class Arc>
92  using Label = typename Arc::Label;
93 
94  explicit PdtReplaceOptions(Label root,
96  Label start_paren_labels = kNoLabel,
97  std::string left_paren_prefix = "(_",
98  std::string right_paren_prefix = ")_")
99  : root(root),
100  type(type),
101  start_paren_labels(start_paren_labels),
102  left_paren_prefix(std::move(left_paren_prefix)),
103  right_paren_prefix(std::move(right_paren_prefix)) {}
104 
108  const std::string left_paren_prefix;
109  const std::string right_paren_prefix;
110 };
111 
112 // PdtParser: Base PDT parser class common to specific parsers.
113 
114 template <class Arc>
115 class PdtParser {
116  public:
117  using Label = typename Arc::Label;
118  using StateId = typename Arc::StateId;
119  using Weight = typename Arc::Weight;
120  using LabelFstPair = std::pair<Label, const Fst<Arc> *>;
121  using LabelPair = std::pair<Label, Label>;
122  using LabelStatePair = std::pair<Label, StateId>;
123  using StateWeightPair = std::pair<StateId, Weight>;
124  using ParenKey = std::pair<size_t, StateId>;
125  using ParenMap = std::unordered_map<ParenKey, size_t,
127 
128  PdtParser(const std::vector<LabelFstPair> &fst_array,
129  const PdtReplaceOptions<Arc> &opts)
130  : root_(opts.root),
131  start_paren_labels_(opts.start_paren_labels),
132  left_paren_prefix_(std::move(opts.left_paren_prefix)),
133  right_paren_prefix_(std::move(opts.right_paren_prefix)),
134  error_(false) {
135  for (size_t i = 0; i < fst_array.size(); ++i) {
136  if (!CompatSymbols(fst_array[0].second->InputSymbols(),
137  fst_array[i].second->InputSymbols())) {
138  FSTERROR() << "PdtParser: Input symbol table of input FST " << i
139  << " does not match input symbol table of 0th input FST";
140  error_ = true;
141  }
142  if (!CompatSymbols(fst_array[0].second->OutputSymbols(),
143  fst_array[i].second->OutputSymbols())) {
144  FSTERROR() << "PdtParser: Output symbol table of input FST " << i
145  << " does not match output symbol table of 0th input FST";
146  error_ = true;
147  }
148  fst_array_.emplace_back(fst_array[i].first, fst_array[i].second->Copy());
149  // Builds map from non-terminal label to FST ID.
150  label2id_[fst_array[i].first] = i;
151  }
152  }
153 
154  virtual ~PdtParser() {
155  for (auto &pair : fst_array_) delete pair.second;
156  }
157 
158  // Constructs the output PDT, dependent on the derived parser type.
159  virtual void GetParser(MutableFst<Arc> *ofst,
160  std::vector<LabelPair> *parens) = 0;
161 
162  protected:
163  const std::vector<LabelFstPair> &FstArray() const { return fst_array_; }
164 
165  Label Root() const { return root_; }
166 
167  // Maps from non-terminal label to corresponding FST ID, or returns
168  // kNoStateId to signal lookup failure.
169  StateId Label2Id(Label l) const {
170  auto it = label2id_.find(l);
171  return it == label2id_.end() ? kNoStateId : it->second;
172  }
173 
174  // Maps from output state to input FST label, state pair, or returns a
175  // (kNoLabel, kNoStateId) pair to signal lookup failure.
177  if (os >= label_state_pairs_.size()) {
178  static const LabelStatePair no_pair(kNoLabel, kNoLabel);
179  return no_pair;
180  } else {
181  return label_state_pairs_[os];
182  }
183  }
184 
185  // Maps to output state from input FST (label, state) pair, or returns
186  // kNoStateId to signal lookup failure.
187  StateId GetState(const LabelStatePair &lsp) const {
188  auto it = state_map_.find(lsp);
189  if (it == state_map_.end()) {
190  return kNoStateId;
191  } else {
192  return it->second;
193  }
194  }
195 
196  // Builds single FST combining all referenced input FSTs, leaving in the
197  // non-termnals for now; also tabulates the PDT states that correspond to the
198  // start and final states of the input FSTs.
199  void CreateFst(MutableFst<Arc> *ofst, std::vector<StateId> *open_dest,
200  std::vector<std::vector<StateWeightPair>> *close_src);
201 
202  // Assigns parenthesis labels from total allocated paren IDs.
203  void AssignParenLabels(size_t total_nparens, std::vector<LabelPair> *parens) {
204  parens->clear();
205  for (size_t paren_id = 0; paren_id < total_nparens; ++paren_id) {
206  const auto open_paren = start_paren_labels_ + paren_id;
207  const auto close_paren = open_paren + total_nparens;
208  parens->emplace_back(open_paren, close_paren);
209  }
210  }
211 
212  // Determines how non-terminal instances are assigned parentheses IDs.
213  virtual size_t AssignParenIds(const Fst<Arc> &ofst,
214  ParenMap *paren_map) const = 0;
215 
216  // Changes a non-terminal transition to an open parenthesis transition
217  // redirected to the PDT state specified in the open_dest argument, when
218  // indexed by the input FST ID for the non-terminal. Adds close parenthesis
219  // transitions (with specified weights) from the PDT states specified in the
220  // close_src argument, when indexed by the input FST ID for the non-terminal,
221  // to the former destination state of the non-terminal transition. The
222  // paren_map argument gives the parenthesis ID for a given non-terminal FST ID
223  // and destination state pair. The close_non_term_weight vector specifies
224  // non-terminals for which the non-terminal arc weight should be applied on
225  // the close parenthesis (multiplying the close_src weight above) rather than
226  // on the open parenthesis. If no paren ID is found, then an epsilon replaces
227  // the parenthesis that would carry the non-terminal arc weight and the other
228  // parenthesis is omitted (appropriate for the strongly-regular case).
229  void AddParensToFst(
230  const std::vector<LabelPair> &parens, const ParenMap &paren_map,
231  const std::vector<StateId> &open_dest,
232  const std::vector<std::vector<StateWeightPair>> &close_src,
233  const std::vector<bool> &close_non_term_weight, MutableFst<Arc> *ofst);
234 
235  // Ensures that parentheses arcs are added to the symbol table.
236  void AddParensToSymbolTables(const std::vector<LabelPair> &parens,
237  MutableFst<Arc> *ofst);
238 
239  private:
240  std::vector<LabelFstPair> fst_array_;
241  Label root_;
242  // Index to use for the first parenthesis.
243  Label start_paren_labels_;
244  const std::string left_paren_prefix_;
245  const std::string right_paren_prefix_;
246  // Maps from non-terminal label to FST ID.
247  std::unordered_map<Label, StateId> label2id_;
248  // Given an output state, specifies the input FST (label, state) pair.
249  std::vector<LabelStatePair> label_state_pairs_;
250  // Given an FST (label, state) pair, specifies the output FST state ID.
251  std::map<LabelStatePair, StateId> state_map_;
252  bool error_;
253 };
254 
255 template <class Arc>
257  MutableFst<Arc> *ofst, std::vector<StateId> *open_dest,
258  std::vector<std::vector<StateWeightPair>> *close_src) {
259  ofst->DeleteStates();
260  if (error_) {
261  ofst->SetProperties(kError, kError);
262  return;
263  }
264  open_dest->resize(fst_array_.size(), kNoStateId);
265  close_src->resize(fst_array_.size());
266  // Queue of non-terminals to replace.
267  std::deque<Label> non_term_queue;
268  non_term_queue.push_back(root_);
269  // Has a non-terminal been enqueued?
270  std::vector<bool> enqueued(fst_array_.size(), false);
271  enqueued[label2id_[root_]] = true;
272  Label max_label = kNoLabel;
273  for (StateId soff = 0; !non_term_queue.empty(); soff = ofst->NumStates()) {
274  const auto label = non_term_queue.front();
275  non_term_queue.pop_front();
276  StateId fst_id = Label2Id(label);
277  const auto *ifst = fst_array_[fst_id].second;
278  for (StateIterator<Fst<Arc>> siter(*ifst); !siter.Done(); siter.Next()) {
279  const auto is = siter.Value();
280  const auto os = ofst->AddState();
281  const LabelStatePair lsp(label, is);
282  label_state_pairs_.push_back(lsp);
283  state_map_[lsp] = os;
284  if (is == ifst->Start()) {
285  (*open_dest)[fst_id] = os;
286  if (label == root_) ofst->SetStart(os);
287  }
288  if (ifst->Final(is) != Weight::Zero()) {
289  if (label == root_) ofst->SetFinal(os, ifst->Final(is));
290  (*close_src)[fst_id].emplace_back(os, ifst->Final(is));
291  }
292  for (ArcIterator<Fst<Arc>> aiter(*ifst, is); !aiter.Done();
293  aiter.Next()) {
294  auto arc = aiter.Value();
295  arc.nextstate += soff;
296  if (max_label == kNoLabel || arc.olabel > max_label)
297  max_label = arc.olabel;
298  const auto nfst_id = Label2Id(arc.olabel);
299  if (nfst_id != kNoStateId) {
300  if (fst_array_[nfst_id].second->Start() == kNoStateId) continue;
301  if (!enqueued[nfst_id]) {
302  non_term_queue.push_back(arc.olabel);
303  enqueued[nfst_id] = true;
304  }
305  }
306  ofst->AddArc(os, arc);
307  }
308  }
309  }
310  if (start_paren_labels_ == kNoLabel) start_paren_labels_ = max_label + 1;
311 }
312 
313 template <class Arc>
315  const std::vector<LabelPair> &parens, const ParenMap &paren_map,
316  const std::vector<StateId> &open_dest,
317  const std::vector<std::vector<StateWeightPair>> &close_src,
318  const std::vector<bool> &close_non_term_weight, MutableFst<Arc> *ofst) {
319  StateId dead_state = kNoStateId;
320  using MIter = MutableArcIterator<MutableFst<Arc>>;
321  for (StateIterator<Fst<Arc>> siter(*ofst); !siter.Done(); siter.Next()) {
322  StateId os = siter.Value();
323  std::unique_ptr<MIter> aiter(new MIter(ofst, os));
324  for (auto n = 0; !aiter->Done(); aiter->Next(), ++n) {
325  const auto arc = aiter->Value(); // A reference here may go stale.
326  StateId nfst_id = Label2Id(arc.olabel);
327  if (nfst_id != kNoStateId) {
328  // Gets parentheses.
329  const ParenKey paren_key(nfst_id, arc.nextstate);
330  auto it = paren_map.find(paren_key);
331  Label open_paren = 0;
332  Label close_paren = 0;
333  if (it != paren_map.end()) {
334  const auto paren_id = it->second;
335  open_paren = parens[paren_id].first;
336  close_paren = parens[paren_id].second;
337  }
338  // Sets open parenthesis.
339  if (open_paren != 0 || !close_non_term_weight[nfst_id]) {
340  const auto open_weight =
341  close_non_term_weight[nfst_id] ? Weight::One() : arc.weight;
342  const Arc sarc(open_paren, open_paren, open_weight,
343  open_dest[nfst_id]);
344  aiter->SetValue(sarc);
345  } else {
346  if (dead_state == kNoStateId) {
347  dead_state = ofst->AddState();
348  }
349  const Arc sarc(0, 0, Weight::One(), dead_state);
350  aiter->SetValue(sarc);
351  }
352  // Adds close parentheses.
353  if (close_paren != 0 || close_non_term_weight[nfst_id]) {
354  for (size_t i = 0; i < close_src[nfst_id].size(); ++i) {
355  const auto &pair = close_src[nfst_id][i];
356  const auto close_weight = close_non_term_weight[nfst_id]
357  ? Times(arc.weight, pair.second)
358  : pair.second;
359  const Arc farc(close_paren, close_paren, close_weight,
360  arc.nextstate);
361 
362  ofst->AddArc(pair.first, farc);
363  if (os == pair.first) { // Invalidated iterator.
364  aiter.reset(new MIter(ofst, os));
365  aiter->Seek(n);
366  }
367  }
368  }
369  }
370  }
371  }
372 }
373 
374 template <class Arc>
376  const std::vector<LabelPair> &parens, MutableFst<Arc> *ofst) {
377  auto size = parens.size();
378  if (ofst->InputSymbols()) {
379  if (!AddAuxiliarySymbols(left_paren_prefix_, start_paren_labels_, size,
380  ofst->MutableInputSymbols())) {
381  ofst->SetProperties(kError, kError);
382  return;
383  }
384  if (!AddAuxiliarySymbols(right_paren_prefix_, start_paren_labels_ + size,
385  size, ofst->MutableInputSymbols())) {
386  ofst->SetProperties(kError, kError);
387  return;
388  }
389  }
390  if (ofst->OutputSymbols()) {
391  if (!AddAuxiliarySymbols(left_paren_prefix_, start_paren_labels_, size,
392  ofst->MutableOutputSymbols())) {
393  ofst->SetProperties(kError, kError);
394  return;
395  }
396  if (!AddAuxiliarySymbols(right_paren_prefix_, start_paren_labels_ + size,
397  size, ofst->MutableOutputSymbols())) {
398  ofst->SetProperties(kError, kError);
399  return;
400  }
401  }
402 }
403 
404 // Builds a PDT by recursive replacement top-down, where the call and return are
405 // encoded in the parentheses.
406 template <class Arc>
407 class PdtLeftParser final : public PdtParser<Arc> {
408  public:
409  using Label = typename Arc::Label;
410  using StateId = typename Arc::StateId;
411  using Weight = typename Arc::Weight;
418 
427  using PdtParser<Arc>::Root;
428 
429  PdtLeftParser(const std::vector<LabelFstPair> &fst_array,
430  const PdtReplaceOptions<Arc> &opts)
431  : PdtParser<Arc>(fst_array, opts) {}
432 
433  void GetParser(MutableFst<Arc> *ofst,
434  std::vector<LabelPair> *parens) override;
435 
436  protected:
437  // Assigns a unique parenthesis ID for each non-terminal, destination
438  // state pair.
439  size_t AssignParenIds(const Fst<Arc> &ofst,
440  ParenMap *paren_map) const override;
441 };
442 
443 template <class Arc>
445  std::vector<LabelPair> *parens) {
446  ofst->DeleteStates();
447  parens->clear();
448  const auto &fst_array = FstArray();
449  // Map that gives the paren ID for a (non-terminal, dest. state) pair
450  // (which can be unique).
451  ParenMap paren_map;
452  // Specifies the open parenthesis destination state for a given non-terminal.
453  // The source is the non-terminal instance source state.
454  std::vector<StateId> open_dest(fst_array.size(), kNoStateId);
455  // Specifies close parenthesis source states and weights for a given
456  // non-terminal. The destination is the non-terminal instance destination
457  // state.
458  std::vector<std::vector<StateWeightPair>> close_src(fst_array.size());
459  // Specifies non-terminals for which the non-terminal arc weight
460  // should be applied on the close parenthesis (multiplying the
461  // 'close_src' weight above) rather than on the open parenthesis.
462  std::vector<bool> close_non_term_weight(fst_array.size(), false);
463  CreateFst(ofst, &open_dest, &close_src);
464  auto total_nparens = AssignParenIds(*ofst, &paren_map);
465  AssignParenLabels(total_nparens, parens);
466  AddParensToFst(*parens, paren_map, open_dest, close_src,
467  close_non_term_weight, ofst);
468  if (!fst_array.empty()) {
469  ofst->SetInputSymbols(fst_array[0].second->InputSymbols());
470  ofst->SetOutputSymbols(fst_array[0].second->OutputSymbols());
471  }
472  AddParensToSymbolTables(*parens, ofst);
473 }
474 
475 template <class Arc>
477  ParenMap *paren_map) const {
478  // Number of distinct parenthesis pairs per FST.
479  std::vector<size_t> nparens(FstArray().size(), 0);
480  // Number of distinct parenthesis pairs overall.
481  size_t total_nparens = 0;
482  for (StateIterator<Fst<Arc>> siter(ofst); !siter.Done(); siter.Next()) {
483  const auto os = siter.Value();
484  for (ArcIterator<Fst<Arc>> aiter(ofst, os); !aiter.Done(); aiter.Next()) {
485  const auto &arc = aiter.Value();
486  const auto nfst_id = Label2Id(arc.olabel);
487  if (nfst_id != kNoStateId) {
488  const ParenKey paren_key(nfst_id, arc.nextstate);
489  auto it = paren_map->find(paren_key);
490  if (it == paren_map->end()) {
491  // Assigns new paren ID for this (FST, dest state) pair.
492  (*paren_map)[paren_key] = nparens[nfst_id]++;
493  if (nparens[nfst_id] > total_nparens)
494  total_nparens = nparens[nfst_id];
495  }
496  }
497  }
498  }
499  return total_nparens;
500 }
501 
502 // Similar to PdtLeftParser but:
503 //
504 // 1. Uses epsilons rather than parentheses labels for any non-terminal
505 // instances within a left- (right-) linear dependency SCC,
506 // 2. Allocates a paren ID uniquely for each such dependency SCC (rather than
507 // non-terminal = dependency state) and destination state.
508 template <class Arc>
509 class PdtLeftSRParser final : public PdtParser<Arc> {
510  public:
511  using Label = typename Arc::Label;
512  using StateId = typename Arc::StateId;
513  using Weight = typename Arc::Weight;
520 
529  using PdtParser<Arc>::Root;
530 
531  PdtLeftSRParser(const std::vector<LabelFstPair> &fst_array,
532  const PdtReplaceOptions<Arc> &opts)
533  : PdtParser<Arc>(fst_array, opts),
534  replace_util_(fst_array, ReplaceUtilOptions(opts.root)) {}
535 
536  void GetParser(MutableFst<Arc> *ofst,
537  std::vector<LabelPair> *parens) override;
538 
539  protected:
540  // Assigns a unique parenthesis ID for each non-terminal, destination state
541  // pair when the non-terminal refers to a non-linear FST. Otherwise, assigns
542  // a unique parenthesis ID for each dependency SCC, destination state pair if
543  // the non-terminal instance is between
544  // SCCs. Otherwise does nothing.
545  size_t AssignParenIds(const Fst<Arc> &ofst,
546  ParenMap *paren_map) const override;
547 
548  // Returns dependency SCC for given label.
549  size_t SCC(Label label) const { return replace_util_.SCC(label); }
550 
551  // Is a given dependency SCC left-linear?
552  bool SCCLeftLinear(size_t scc_id) const {
553  const auto ll_props = kReplaceSCCLeftLinear | kReplaceSCCNonTrivial;
554  const auto scc_props = replace_util_.SCCProperties(scc_id);
555  return (scc_props & ll_props) == ll_props;
556  }
557 
558  // Is a given dependency SCC right-linear?
559  bool SCCRightLinear(size_t scc_id) const {
560  const auto lr_props = kReplaceSCCRightLinear | kReplaceSCCNonTrivial;
561  const auto scc_props = replace_util_.SCCProperties(scc_id);
562  return (scc_props & lr_props) == lr_props;
563  }
564 
565  // Components of left- (right-) linear dependency SCC; empty o.w.
566  const std::vector<size_t> &SCCComps(size_t scc_id) const {
567  if (scc_comps_.empty()) GetSCCComps();
568  return scc_comps_[scc_id];
569  }
570 
571  // Returns the representative state of an SCC. For left-linear grammars, it
572  // is one of the initial states. For right-linear grammars, it is one of the
573  // non-terminal destination states; otherwise, it is kNoStateId.
574  StateId RepState(size_t scc_id) const {
575  if (SCCComps(scc_id).empty()) return kNoStateId;
576  const auto fst_id = SCCComps(scc_id).front();
577  const auto &fst_array = FstArray();
578  const auto label = fst_array[fst_id].first;
579  const auto *ifst = fst_array[fst_id].second;
580  if (SCCLeftLinear(scc_id)) {
581  const LabelStatePair lsp(label, ifst->Start());
582  return GetState(lsp);
583  } else { // Right-linear.
584  const LabelStatePair lsp(label, *NonTermDests(fst_id).begin());
585  return GetState(lsp);
586  }
587  return kNoStateId;
588  }
589 
590  private:
591  // Merges initial (final) states of in a left- (right-) linear dependency SCC
592  // after dealing with the non-terminal arc and final weights.
593  void ProcSCCs(MutableFst<Arc> *ofst, std::vector<StateId> *open_dest,
594  std::vector<std::vector<StateWeightPair>> *close_src,
595  std::vector<bool> *close_non_term_weight) const;
596 
597  // Computes components of left- (right-) linear dependency SCC.
598  void GetSCCComps() const {
599  const std::vector<LabelFstPair> &fst_array = FstArray();
600  for (size_t i = 0; i < fst_array.size(); ++i) {
601  const auto label = fst_array[i].first;
602  const auto scc_id = SCC(label);
603  if (scc_comps_.size() <= scc_id) scc_comps_.resize(scc_id + 1);
604  if (SCCLeftLinear(scc_id) || SCCRightLinear(scc_id)) {
605  scc_comps_[scc_id].push_back(i);
606  }
607  }
608  }
609 
610  const std::set<StateId> &NonTermDests(StateId fst_id) const {
611  if (non_term_dests_.empty()) GetNonTermDests();
612  return non_term_dests_[fst_id];
613  }
614 
615  // Finds non-terminal destination states for right-linear FSTS, or does
616  // nothing if not found.
617  void GetNonTermDests() const;
618 
619  // Dependency SCC info.
620  mutable ReplaceUtil<Arc> replace_util_;
621  // Components of left- (right-) linear dependency SCCs, or empty otherwise.
622  mutable std::vector<std::vector<size_t>> scc_comps_;
623  // States that have non-terminals entering them for each (right-linear) FST.
624  mutable std::vector<std::set<StateId>> non_term_dests_;
625 };
626 
627 template <class Arc>
629  std::vector<LabelPair> *parens) {
630  ofst->DeleteStates();
631  parens->clear();
632  const auto &fst_array = FstArray();
633  // Map that gives the paren ID for a (non-terminal, dest. state) pair.
634  ParenMap paren_map;
635  // Specifies the open parenthesis destination state for a given non-terminal.
636  // The source is the non-terminal instance source state.
637  std::vector<StateId> open_dest(fst_array.size(), kNoStateId);
638  // Specifies close parenthesis source states and weights for a given
639  // non-terminal. The destination is the non-terminal instance destination
640  // state.
641  std::vector<std::vector<StateWeightPair>> close_src(fst_array.size());
642  // Specifies non-terminals for which the non-terminal arc weight should be
643  // applied on the close parenthesis (multiplying the close_src weight above)
644  // rather than on the open parenthesis.
645  std::vector<bool> close_non_term_weight(fst_array.size(), false);
646  CreateFst(ofst, &open_dest, &close_src);
647  ProcSCCs(ofst, &open_dest, &close_src, &close_non_term_weight);
648  const auto total_nparens = AssignParenIds(*ofst, &paren_map);
649  AssignParenLabels(total_nparens, parens);
650  AddParensToFst(*parens, paren_map, open_dest, close_src,
651  close_non_term_weight, ofst);
652  if (!fst_array.empty()) {
653  ofst->SetInputSymbols(fst_array[0].second->InputSymbols());
654  ofst->SetOutputSymbols(fst_array[0].second->OutputSymbols());
655  }
656  AddParensToSymbolTables(*parens, ofst);
657  Connect(ofst);
658 }
659 
660 template <class Arc>
662  MutableFst<Arc> *ofst, std::vector<StateId> *open_dest,
663  std::vector<std::vector<StateWeightPair>> *close_src,
664  std::vector<bool> *close_non_term_weight) const {
665  const auto &fst_array = FstArray();
666  for (StateIterator<Fst<Arc>> siter(*ofst); !siter.Done(); siter.Next()) {
667  const auto os = siter.Value();
668  const auto label = GetLabelStatePair(os).first;
669  const auto is = GetLabelStatePair(os).second;
670  const auto fst_id = Label2Id(label);
671  const auto scc_id = SCC(label);
672  const auto rs = RepState(scc_id);
673  const auto *ifst = fst_array[fst_id].second;
674  // SCC LEFT-LINEAR: puts non-terminal weights on close parentheses. Merges
675  // initial states into SCC representative state and updates open_dest.
676  if (SCCLeftLinear(scc_id)) {
677  (*close_non_term_weight)[fst_id] = true;
678  if (is == ifst->Start() && os != rs) {
679  for (ArcIterator<Fst<Arc>> aiter(*ofst, os); !aiter.Done();
680  aiter.Next()) {
681  const auto &arc = aiter.Value();
682  ofst->AddArc(rs, arc);
683  }
684  ofst->DeleteArcs(os);
685  if (os == ofst->Start()) ofst->SetStart(rs);
686  (*open_dest)[fst_id] = rs;
687  }
688  }
689  // SCC RIGHT-LINEAR: pushes back final weights onto non-terminals, if
690  // possible, or adds weighted epsilons to the SCC representative state.
691  // Merges final states into SCC representative state and updates close_src.
692  if (SCCRightLinear(scc_id)) {
693  for (MutableArcIterator<MutableFst<Arc>> aiter(ofst, os); !aiter.Done();
694  aiter.Next()) {
695  auto arc = aiter.Value();
696  const auto idest = GetLabelStatePair(arc.nextstate).second;
697  if (NonTermDests(fst_id).count(idest) > 0) {
698  if (ofst->Final(arc.nextstate) != Weight::Zero()) {
699  ofst->SetFinal(arc.nextstate, Weight::Zero());
700  ofst->SetFinal(rs, Weight::One());
701  }
702  arc.weight = Times(arc.weight, ifst->Final(idest));
703  arc.nextstate = rs;
704  aiter.SetValue(arc);
705  }
706  }
707  const auto final_weight = ifst->Final(is);
708  if (final_weight != Weight::Zero() &&
709  NonTermDests(fst_id).count(is) == 0) {
710  ofst->AddArc(os, Arc(0, 0, final_weight, rs));
711  if (ofst->Final(os) != Weight::Zero()) {
712  ofst->SetFinal(os, Weight::Zero());
713  ofst->SetFinal(rs, Weight::One());
714  }
715  }
716  if (is == ifst->Start()) {
717  (*close_src)[fst_id].clear();
718  (*close_src)[fst_id].emplace_back(rs, Weight::One());
719  }
720  }
721  }
722 }
723 
724 template <class Arc>
726  const auto &fst_array = FstArray();
727  non_term_dests_.resize(fst_array.size());
728  for (size_t fst_id = 0; fst_id < fst_array.size(); ++fst_id) {
729  const auto label = fst_array[fst_id].first;
730  const auto scc_id = SCC(label);
731  if (SCCRightLinear(scc_id)) {
732  const auto *ifst = fst_array[fst_id].second;
733  for (StateIterator<Fst<Arc>> siter(*ifst); !siter.Done(); siter.Next()) {
734  const auto is = siter.Value();
735  for (ArcIterator<Fst<Arc>> aiter(*ifst, is); !aiter.Done();
736  aiter.Next()) {
737  const auto &arc = aiter.Value();
738  if (Label2Id(arc.olabel) != kNoStateId) {
739  non_term_dests_[fst_id].insert(arc.nextstate);
740  }
741  }
742  }
743  }
744  }
745 }
746 
747 template <class Arc>
749  ParenMap *paren_map) const {
750  const auto &fst_array = FstArray();
751  // Number of distinct parenthesis pairs per FST.
752  std::vector<size_t> nparens(fst_array.size(), 0);
753  // Number of distinct parenthesis pairs overall.
754  size_t total_nparens = 0;
755  for (StateIterator<Fst<Arc>> siter(ofst); !siter.Done(); siter.Next()) {
756  const auto os = siter.Value();
757  const auto label = GetLabelStatePair(os).first;
758  const auto scc_id = SCC(label);
759  for (ArcIterator<Fst<Arc>> aiter(ofst, os); !aiter.Done(); aiter.Next()) {
760  const auto &arc = aiter.Value();
761  const auto nfst_id = Label2Id(arc.olabel);
762  if (nfst_id != kNoStateId) {
763  size_t nscc_id = SCC(arc.olabel);
764  bool nscc_linear = !SCCComps(nscc_id).empty();
765  // Assigns a parenthesis ID for the non-terminal transition
766  // if the non-terminal belongs to a (left-/right-) linear dependency
767  // SCC or if the transition is in an FST from a different SCC
768  if (!nscc_linear || scc_id != nscc_id) {
769  // For (left-/right-) linear SCCs instead of using nfst_id, we
770  // will use its SCC prototype pfst_id for assigning distinct
771  // parenthesis IDs.
772  const auto pfst_id =
773  nscc_linear ? SCCComps(nscc_id).front() : nfst_id;
774  ParenKey paren_key(pfst_id, arc.nextstate);
775  const auto it = paren_map->find(paren_key);
776  if (it == paren_map->end()) {
777  // Assigns new paren ID for this (FST/SCC, dest. state) pair.
778  if (nscc_linear) {
779  // This is mapping we'll need, but we also store (harmlessly)
780  // for the prototype below so we can easily keep count per SCC.
781  const ParenKey nparen_key(nfst_id, arc.nextstate);
782  (*paren_map)[nparen_key] = nparens[pfst_id];
783  }
784  (*paren_map)[paren_key] = nparens[pfst_id]++;
785  if (nparens[pfst_id] > total_nparens) {
786  total_nparens = nparens[pfst_id];
787  }
788  }
789  }
790  }
791  }
792  }
793  return total_nparens;
794 }
795 
796 // Builds a pushdown transducer (PDT) from an RTN specification. The result is
797 // a PDT written to a mutable FST where some transitions are labeled with
798 // open or close parentheses. To be interpreted as a PDT, the parens must
799 // balance on a path (see PdtExpand()). The open/close parenthesis label pairs
800 // are returned in the parens argument.
801 template <class Arc>
802 void Replace(
803  const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
804  &ifst_array,
805  MutableFst<Arc> *ofst,
806  std::vector<std::pair<typename Arc::Label, typename Arc::Label>> *parens,
807  const PdtReplaceOptions<Arc> &opts) {
808  switch (opts.type) {
809  case PdtParserType::LEFT: {
810  PdtLeftParser<Arc> pr(ifst_array, opts);
811  pr.GetParser(ofst, parens);
812  return;
813  }
814  case PdtParserType::LEFT_SR: {
815  PdtLeftSRParser<Arc> pr(ifst_array, opts);
816  pr.GetParser(ofst, parens);
817  return;
818  }
819  default:
820  FSTERROR() << "Replace: Unknown PDT parser type: "
821  << static_cast<std::underlying_type_t<PdtParserType>>(
822  opts.type);
823  ofst->DeleteStates();
824  ofst->SetProperties(kError, kError);
825  parens->clear();
826  return;
827  }
828 }
829 
830 // Variant where the only user-controlled arguments is the root ID.
831 template <class Arc>
832 void Replace(
833  const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
834  &ifst_array,
835  MutableFst<Arc> *ofst,
836  std::vector<std::pair<typename Arc::Label, typename Arc::Label>> *parens,
837  typename Arc::Label root) {
838  PdtReplaceOptions<Arc> opts(root);
839  Replace(ifst_array, ofst, parens, opts);
840 }
841 
842 } // namespace fst
843 
844 #endif // FST_EXTENSIONS_PDT_REPLACE_H_
virtual ~PdtParser()
Definition: replace.h:154
const std::string right_paren_prefix
Definition: replace.h:109
void GetParser(MutableFst< Arc > *ofst, std::vector< LabelPair > *parens) override
Definition: replace.h:628
std::unordered_map< ParenKey, size_t, internal::ReplaceParenHash< StateId >> ParenMap
Definition: replace.h:126
constexpr int kNoLabel
Definition: fst.h:195
virtual SymbolTable * MutableOutputSymbols()=0
PdtLeftSRParser(const std::vector< LabelFstPair > &fst_array, const PdtReplaceOptions< Arc > &opts)
Definition: replace.h:531
size_t AssignParenIds(const Fst< Arc > &ofst, ParenMap *paren_map) const override
Definition: replace.h:476
StateId RepState(size_t scc_id) const
Definition: replace.h:574
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:64
const SymbolTable * InputSymbols() const override=0
constexpr uint64_t kError
Definition: properties.h:52
virtual void SetInputSymbols(const SymbolTable *isyms)=0
StateId Label2Id(Label l) const
Definition: replace.h:169
Label Root() const
Definition: replace.h:165
size_t AssignParenIds(const Fst< Arc > &ofst, ParenMap *paren_map) const override
Definition: replace.h:748
virtual Weight Final(StateId) const =0
PdtParserType
Definition: replace.h:69
virtual void SetStart(StateId)=0
bool SCCLeftLinear(size_t scc_id) const
Definition: replace.h:552
void Connect(MutableFst< Arc > *fst)
Definition: connect.h:47
constexpr uint8_t kReplaceSCCLeftLinear
Definition: replace-util.h:87
void Replace(const std::vector< std::pair< typename Arc::Label, const Fst< Arc > * >> &ifst_array, MutableFst< Arc > *ofst, std::vector< std::pair< typename Arc::Label, typename Arc::Label >> *parens, const PdtReplaceOptions< Arc > &opts)
Definition: replace.h:802
size_t operator()(const std::pair< size_t, S > &paren) const
Definition: replace.h:52
const std::string left_paren_prefix
Definition: replace.h:108
void AssignParenLabels(size_t total_nparens, std::vector< LabelPair > *parens)
Definition: replace.h:203
constexpr int kNoStateId
Definition: fst.h:196
std::pair< Label, StateId > LabelStatePair
Definition: replace.h:122
bool SCCRightLinear(size_t scc_id) const
Definition: replace.h:559
const Arc & Value() const
Definition: mutable-fst.h:236
typename Arc::StateId StateId
Definition: replace.h:118
PdtReplaceOptions(Label root, PdtParserType type=PdtParserType::LEFT, Label start_paren_labels=kNoLabel, std::string left_paren_prefix="(_", std::string right_paren_prefix=")_")
Definition: replace.h:94
typename Arc::Label Label
Definition: replace.h:117
void CreateFst(MutableFst< Arc > *ofst, std::vector< StateId > *open_dest, std::vector< std::vector< StateWeightPair >> *close_src)
Definition: replace.h:256
#define FSTERROR()
Definition: util.h:56
const SymbolTable * OutputSymbols() const override=0
std::pair< StateId, Weight > StateWeightPair
Definition: replace.h:123
constexpr uint8_t kReplaceSCCNonTrivial
Definition: replace-util.h:94
const std::vector< size_t > & SCCComps(size_t scc_id) const
Definition: replace.h:566
PdtParserType type
Definition: replace.h:106
virtual void SetProperties(uint64_t props, uint64_t mask)=0
void GetParser(MutableFst< Arc > *ofst, std::vector< LabelPair > *parens) override
Definition: replace.h:444
virtual void DeleteArcs(StateId, size_t)=0
virtual StateId Start() const =0
StateId GetState(const LabelStatePair &lsp) const
Definition: replace.h:187
PdtParser(const std::vector< LabelFstPair > &fst_array, const PdtReplaceOptions< Arc > &opts)
Definition: replace.h:128
std::pair< size_t, StateId > ParenKey
Definition: replace.h:124
const std::vector< LabelFstPair > & FstArray() const
Definition: replace.h:163
std::pair< Label, Label > LabelPair
Definition: replace.h:121
void AddParensToSymbolTables(const std::vector< LabelPair > &parens, MutableFst< Arc > *ofst)
Definition: replace.h:375
typename Arc::Label Label
Definition: replace.h:92
virtual SymbolTable * MutableInputSymbols()=0
virtual void AddArc(StateId, const Arc &)=0
constexpr uint8_t kReplaceSCCRightLinear
Definition: replace-util.h:91
virtual StateId AddState()=0
std::pair< Label, const Fst< Arc > * > LabelFstPair
Definition: replace.h:120
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
void AddParensToFst(const std::vector< LabelPair > &parens, const ParenMap &paren_map, const std::vector< StateId > &open_dest, const std::vector< std::vector< StateWeightPair >> &close_src, const std::vector< bool > &close_non_term_weight, MutableFst< Arc > *ofst)
Definition: replace.h:314
virtual void DeleteStates(const std::vector< StateId > &)=0
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
typename Arc::Weight Weight
Definition: replace.h:119
virtual StateId NumStates() const =0
typename Arc::StateId StateId
Definition: replace.h:512
LabelStatePair GetLabelStatePair(StateId os) const
Definition: replace.h:176
PdtLeftParser(const std::vector< LabelFstPair > &fst_array, const PdtReplaceOptions< Arc > &opts)
Definition: replace.h:429
bool AddAuxiliarySymbols(const std::string &prefix, int64_t start_label, int64_t nlabels, SymbolTable *syms)
size_t SCC(Label label) const
Definition: replace.h:549