FST  openfst-1.7.3
OpenFst Library
replace.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions and classes for the recursive replacement of FSTs.
5 
6 #ifndef FST_REPLACE_H_
7 #define FST_REPLACE_H_
8 
9 #include <set>
10 #include <string>
11 #include <unordered_map>
12 #include <utility>
13 #include <vector>
14 
15 #include <fst/log.h>
16 
17 #include <fst/cache.h>
18 #include <fst/expanded-fst.h>
19 #include <fst/fst-decl.h> // For optional argument declarations.
20 #include <fst/fst.h>
21 #include <fst/matcher.h>
22 #include <fst/replace-util.h>
23 #include <fst/state-table.h>
24 #include <fst/test-properties.h>
25 
26 namespace fst {
27 
28 // Replace state tables have the form:
29 //
30 // template <class Arc, class P>
31 // class ReplaceStateTable {
32 // public:
33 // using Label = typename Arc::Label;
34 // using StateId = typename Arc::StateId;
35 //
36 // using PrefixId = P;
37 // using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
38 // using StackPrefix = ReplaceStackPrefix<Label, StateId>;
39 //
40 // // Required constructor.
41 // ReplaceStateTable(
42 // const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_list,
43 // Label root);
44 //
45 // // Required copy constructor that does not copy state.
46 // ReplaceStateTable(const ReplaceStateTable<Arc, PrefixId> &table);
47 //
48 // // Looks up state ID by tuple, adding it if it doesn't exist.
49 // StateId FindState(const StateTuple &tuple);
50 //
51 // // Looks up state tuple by ID.
52 // const StateTuple &Tuple(StateId id) const;
53 //
54 // // Looks up prefix ID by stack prefix, adding it if it doesn't exist.
55 // PrefixId FindPrefixId(const StackPrefix &stack_prefix);
56 //
57 // // Looks up stack prefix by ID.
58 // const StackPrefix &GetStackPrefix(PrefixId id) const;
59 // };
60 
61 // Tuple that uniquely defines a state in replace.
62 template <class S, class P>
64  using StateId = S;
65  using PrefixId = P;
66 
70 
71  PrefixId prefix_id; // Index in prefix table.
72  StateId fst_id; // Current FST being walked.
73  StateId fst_state; // Current state in FST being walked (not to be
74  // confused with the StateId of the combined FST).
75 };
76 
77 // Equality of replace state tuples.
78 template <class StateId, class PrefixId>
81  return x.prefix_id == y.prefix_id && x.fst_id == y.fst_id &&
82  x.fst_state == y.fst_state;
83 }
84 
85 // Functor returning true for tuples corresponding to states in the root FST.
86 template <class StateId, class PrefixId>
88  public:
90  return tuple.prefix_id == 0;
91  }
92 };
93 
94 // Functor for fingerprinting replace state tuples.
95 template <class StateId, class PrefixId>
97  public:
98  explicit ReplaceFingerprint(const std::vector<uint64> *size_array)
99  : size_array_(size_array) {}
100 
102  return tuple.prefix_id * size_array_->back() +
103  size_array_->at(tuple.fst_id - 1) + tuple.fst_state;
104  }
105 
106  private:
107  const std::vector<uint64> *size_array_;
108 };
109 
110 // Useful when the fst_state uniquely defines the tuple.
111 template <class StateId, class PrefixId>
113  public:
115  return tuple.fst_state;
116  }
117 };
118 
119 // A generic hash function for replace state tuples.
120 template <typename S, typename P>
121 class ReplaceHash {
122  public:
123  size_t operator()(const ReplaceStateTuple<S, P>& t) const {
124  static constexpr size_t prime0 = 7853;
125  static constexpr size_t prime1 = 7867;
126  return t.prefix_id + t.fst_id * prime0 + t.fst_state * prime1;
127  }
128 };
129 
130 // Container for stack prefix.
131 template <class Label, class StateId>
133  public:
134  struct PrefixTuple {
136  : fst_id(fst_id), nextstate(nextstate) {}
137 
138  Label fst_id;
140  };
141 
143 
145  : prefix_(other.prefix_) {}
146 
147  void Push(StateId fst_id, StateId nextstate) {
148  prefix_.push_back(PrefixTuple(fst_id, nextstate));
149  }
150 
151  void Pop() { prefix_.pop_back(); }
152 
153  const PrefixTuple &Top() const { return prefix_[prefix_.size() - 1]; }
154 
155  size_t Depth() const { return prefix_.size(); }
156 
157  public:
158  std::vector<PrefixTuple> prefix_;
159 };
160 
161 // Equality of stack prefix classes.
162 template <class Label, class StateId>
165  if (x.prefix_.size() != y.prefix_.size()) return false;
166  for (size_t i = 0; i < x.prefix_.size(); ++i) {
167  if (x.prefix_[i].fst_id != y.prefix_[i].fst_id ||
168  x.prefix_[i].nextstate != y.prefix_[i].nextstate) {
169  return false;
170  }
171  }
172  return true;
173 }
174 
175 // Hash function for stack prefix to prefix id.
176 template <class Label, class StateId>
178  public:
179  size_t operator()(const ReplaceStackPrefix<Label, StateId> &prefix) const {
180  size_t sum = 0;
181  for (const auto &pair : prefix.prefix_) {
182  static constexpr size_t prime = 7863;
183  sum += pair.fst_id + pair.nextstate * prime;
184  }
185  return sum;
186  }
187 };
188 
189 // Replace state tables.
190 
191 // A two-level state table for replace. Warning: calls CountStates to compute
192 // the number of states of each component FST.
193 template <class Arc, class P = ssize_t>
195  public:
196  using Label = typename Arc::Label;
197  using StateId = typename Arc::StateId;
198 
199  using PrefixId = P;
200 
202  using StateTable =
208  using StackPrefixTable =
211 
213  const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_list,
214  Label root)
215  : root_size_(0) {
216  size_array_.push_back(0);
217  for (const auto &fst_pair : fst_list) {
218  if (fst_pair.first == root) {
219  root_size_ = CountStates(*(fst_pair.second));
220  size_array_.push_back(size_array_.back());
221  } else {
222  size_array_.push_back(size_array_.back() +
223  CountStates(*(fst_pair.second)));
224  }
225  }
226  state_table_.reset(
227  new StateTable(new ReplaceRootSelector<StateId, PrefixId>,
228  new ReplaceFstStateFingerprint<StateId, PrefixId>,
229  new ReplaceFingerprint<StateId, PrefixId>(&size_array_),
230  root_size_, root_size_ + size_array_.back()));
231  }
232 
235  : root_size_(table.root_size_),
236  size_array_(table.size_array_),
237  prefix_table_(table.prefix_table_) {
238  state_table_.reset(
239  new StateTable(new ReplaceRootSelector<StateId, PrefixId>,
240  new ReplaceFstStateFingerprint<StateId, PrefixId>,
241  new ReplaceFingerprint<StateId, PrefixId>(&size_array_),
242  root_size_, root_size_ + size_array_.back()));
243  }
244 
245  StateId FindState(const StateTuple &tuple) {
246  return state_table_->FindState(tuple);
247  }
248 
249  const StateTuple &Tuple(StateId id) const { return state_table_->Tuple(id); }
250 
251  PrefixId FindPrefixId(const StackPrefix &prefix) {
252  return prefix_table_.FindId(prefix);
253  }
254 
255  const StackPrefix& GetStackPrefix(PrefixId id) const {
256  return prefix_table_.FindEntry(id);
257  }
258 
259  private:
260  StateId root_size_;
261  std::vector<uint64> size_array_;
262  std::unique_ptr<StateTable> state_table_;
263  StackPrefixTable prefix_table_;
264 };
265 
266 // Default replace state table.
267 template <class Arc, class P /* = size_t */>
269  : public CompactHashStateTable<ReplaceStateTuple<typename Arc::StateId, P>,
270  ReplaceHash<typename Arc::StateId, P>> {
271  public:
272  using Label = typename Arc::Label;
273  using StateId = typename Arc::StateId;
274 
275  using PrefixId = P;
277  using StateTable =
280  using StackPrefixTable =
283 
284  using StateTable::FindState;
285  using StateTable::Tuple;
286 
288  const std::vector<std::pair<Label, const Fst<Arc> *>> &, Label) {}
289 
291  : StateTable(), prefix_table_(table.prefix_table_) {}
292 
293  PrefixId FindPrefixId(const StackPrefix &prefix) {
294  return prefix_table_.FindId(prefix);
295  }
296 
297  const StackPrefix &GetStackPrefix(PrefixId id) const {
298  return prefix_table_.FindEntry(id);
299  }
300 
301  private:
302  StackPrefixTable prefix_table_;
303 };
304 
305 // By default ReplaceFst will copy the input label of the replace arc.
306 // The call_label_type and return_label_type options specify how to manage
307 // the labels of the call arc and the return arc of the replace FST.
308 template <class Arc, class StateTable = DefaultReplaceStateTable<Arc>,
309  class CacheStore = DefaultCacheStore<Arc>>
310 struct ReplaceFstOptions : CacheImplOptions<CacheStore> {
311  using Label = typename Arc::Label;
312 
313  // Index of root rule for expansion.
315  // How to label call arc.
317  // How to label return arc.
319  // Specifies output label to put on call arc; if kNoLabel, use existing label
320  // on call arc. Otherwise, use this field as the output label.
321  Label call_output_label = kNoLabel;
322  // Specifies label to put on return arc.
323  Label return_label = 0;
324  // Take ownership of input FSTs?
325  bool take_ownership = false;
326  // Pointer to optional pre-constructed state table.
327  StateTable *state_table = nullptr;
328 
330  Label root = kNoLabel)
331  : CacheImplOptions<CacheStore>(opts), root(root) {}
332 
333  explicit ReplaceFstOptions(const CacheOptions &opts, Label root = kNoLabel)
334  : CacheImplOptions<CacheStore>(opts), root(root) {}
335 
336  // FIXME(kbg): There are too many constructors here. Come up with a consistent
337  // position for call_output_label (probably the very end) so that it is
338  // possible to express all the remaining constructors with a single
339  // default-argument constructor. Also move clients off of the "backwards
340  // compatibility" constructor, for good.
341 
342  explicit ReplaceFstOptions(Label root) : root(root) {}
343 
344  explicit ReplaceFstOptions(Label root, ReplaceLabelType call_label_type,
345  ReplaceLabelType return_label_type,
346  Label return_label)
347  : root(root),
348  call_label_type(call_label_type),
349  return_label_type(return_label_type),
350  return_label(return_label) {}
351 
352  explicit ReplaceFstOptions(Label root, ReplaceLabelType call_label_type,
353  ReplaceLabelType return_label_type,
354  Label call_output_label, Label return_label)
355  : root(root),
356  call_label_type(call_label_type),
357  return_label_type(return_label_type),
358  call_output_label(call_output_label),
359  return_label(return_label) {}
360 
361  explicit ReplaceFstOptions(const ReplaceUtilOptions &opts)
362  : ReplaceFstOptions(opts.root, opts.call_label_type,
363  opts.return_label_type, opts.return_label) {}
364 
366 
367  // For backwards compatibility.
368  ReplaceFstOptions(int64 root, bool epsilon_replace_arc)
369  : root(root),
370  call_label_type(epsilon_replace_arc ? REPLACE_LABEL_NEITHER
372  call_output_label(epsilon_replace_arc ? 0 : kNoLabel) {}
373 };
374 
375 
376 // Forward declaration.
377 template <class Arc, class StateTable, class CacheStore>
379 
380 template <class Arc>
381 using FstList = std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>;
382 
383 // Returns true if label type on arc results in epsilon input label.
384 inline bool EpsilonOnInput(ReplaceLabelType label_type) {
385  return label_type == REPLACE_LABEL_NEITHER ||
386  label_type == REPLACE_LABEL_OUTPUT;
387 }
388 
389 // Returns true if label type on arc results in epsilon input label.
390 inline bool EpsilonOnOutput(ReplaceLabelType label_type) {
391  return label_type == REPLACE_LABEL_NEITHER ||
392  label_type == REPLACE_LABEL_INPUT;
393 }
394 
395 // Returns true if for either the call or return arc ilabel != olabel.
396 template <class Label>
397 bool ReplaceTransducer(ReplaceLabelType call_label_type,
398  ReplaceLabelType return_label_type,
399  Label call_output_label) {
400  return call_label_type == REPLACE_LABEL_INPUT ||
401  call_label_type == REPLACE_LABEL_OUTPUT ||
402  (call_label_type == REPLACE_LABEL_BOTH &&
403  call_output_label != kNoLabel) ||
404  return_label_type == REPLACE_LABEL_INPUT ||
405  return_label_type == REPLACE_LABEL_OUTPUT;
406 }
407 
408 template <class Arc>
409 uint64 ReplaceFstProperties(typename Arc::Label root_label,
410  const FstList<Arc> &fst_list,
411  ReplaceLabelType call_label_type,
412  ReplaceLabelType return_label_type,
413  typename Arc::Label call_output_label,
414  bool *sorted_and_non_empty) {
415  using Label = typename Arc::Label;
416  std::vector<uint64> inprops;
417  bool all_ilabel_sorted = true;
418  bool all_olabel_sorted = true;
419  bool all_non_empty = true;
420  // All nonterminals are negative?
421  bool all_negative = true;
422  // All nonterminals are positive and form a dense range containing 1?
423  bool dense_range = true;
424  Label root_fst_idx = 0;
425  for (Label i = 0; i < fst_list.size(); ++i) {
426  const auto label = fst_list[i].first;
427  if (label >= 0) all_negative = false;
428  if (label > fst_list.size() || label <= 0) dense_range = false;
429  if (label == root_label) root_fst_idx = i;
430  const auto *fst = fst_list[i].second;
431  if (fst->Start() == kNoStateId) all_non_empty = false;
432  if (!fst->Properties(kILabelSorted, false)) all_ilabel_sorted = false;
433  if (!fst->Properties(kOLabelSorted, false)) all_olabel_sorted = false;
434  inprops.push_back(fst->Properties(kCopyProperties, false));
435  }
436  const auto props = ReplaceProperties(
437  inprops, root_fst_idx, EpsilonOnInput(call_label_type),
438  EpsilonOnInput(return_label_type), EpsilonOnOutput(call_label_type),
439  EpsilonOnOutput(return_label_type),
440  ReplaceTransducer(call_label_type, return_label_type, call_output_label),
441  all_non_empty, all_ilabel_sorted, all_olabel_sorted,
442  all_negative || dense_range);
443  const bool sorted = props & (kILabelSorted | kOLabelSorted);
444  *sorted_and_non_empty = all_non_empty && sorted;
445  return props;
446 }
447 
448 namespace internal {
449 
450 // The replace implementation class supports a dynamic expansion of a recursive
451 // transition network represented as label/FST pairs with dynamic replaceable
452 // arcs.
453 template <class Arc, class StateTable, class CacheStore>
455  : public CacheBaseImpl<typename CacheStore::State, CacheStore> {
456  public:
457  using Label = typename Arc::Label;
458  using StateId = typename Arc::StateId;
459  using Weight = typename Arc::Weight;
460 
461  using State = typename CacheStore::State;
463  using PrefixId = typename StateTable::PrefixId;
466  using NonTerminalHash = std::unordered_map<Label, Label>;
467 
468  using FstImpl<Arc>::SetType;
475 
476  using CacheImpl::PushArc;
477  using CacheImpl::HasArcs;
478  using CacheImpl::HasFinal;
479  using CacheImpl::HasStart;
480  using CacheImpl::SetArcs;
481  using CacheImpl::SetFinal;
482  using CacheImpl::SetStart;
483 
484  friend class ReplaceFstMatcher<Arc, StateTable, CacheStore>;
485 
486  ReplaceFstImpl(const FstList<Arc> &fst_list,
488  : CacheImpl(opts),
489  call_label_type_(opts.call_label_type),
490  return_label_type_(opts.return_label_type),
491  call_output_label_(opts.call_output_label),
492  return_label_(opts.return_label),
493  state_table_(opts.state_table ? opts.state_table
494  : new StateTable(fst_list, opts.root)) {
495  SetType("replace");
496  // If the label is epsilon, then all replace label options are equivalent,
497  // so we set the label types to NEITHER for simplicity.
498  if (call_output_label_ == 0) call_label_type_ = REPLACE_LABEL_NEITHER;
499  if (return_label_ == 0) return_label_type_ = REPLACE_LABEL_NEITHER;
500  if (!fst_list.empty()) {
501  SetInputSymbols(fst_list[0].second->InputSymbols());
502  SetOutputSymbols(fst_list[0].second->OutputSymbols());
503  }
504  fst_array_.push_back(nullptr);
505  for (Label i = 0; i < fst_list.size(); ++i) {
506  const auto label = fst_list[i].first;
507  const auto *fst = fst_list[i].second;
508  nonterminal_hash_[label] = fst_array_.size();
509  nonterminal_set_.insert(label);
510  fst_array_.emplace_back(opts.take_ownership ? fst : fst->Copy());
511  if (i) {
512  if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) {
513  FSTERROR() << "ReplaceFstImpl: Input symbols of FST " << i
514  << " do not match input symbols of base FST (0th FST)";
515  SetProperties(kError, kError);
516  }
517  if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) {
518  FSTERROR() << "ReplaceFstImpl: Output symbols of FST " << i
519  << " do not match output symbols of base FST (0th FST)";
520  SetProperties(kError, kError);
521  }
522  }
523  }
524  const auto nonterminal = nonterminal_hash_[opts.root];
525  if ((nonterminal == 0) && (fst_array_.size() > 1)) {
526  FSTERROR() << "ReplaceFstImpl: No FST corresponding to root label "
527  << opts.root << " in the input tuple vector";
528  SetProperties(kError, kError);
529  }
530  root_ = (nonterminal > 0) ? nonterminal : 1;
531  bool all_non_empty_and_sorted = false;
532  SetProperties(ReplaceFstProperties(opts.root, fst_list, call_label_type_,
533  return_label_type_, call_output_label_,
534  &all_non_empty_and_sorted));
535  // Enables optional caching as long as sorted and all non-empty.
536  always_cache_ = !all_non_empty_and_sorted;
537  VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = "
538  << (always_cache_ ? "true" : "false");
539  }
540 
542  : CacheImpl(impl),
543  call_label_type_(impl.call_label_type_),
544  return_label_type_(impl.return_label_type_),
545  call_output_label_(impl.call_output_label_),
546  return_label_(impl.return_label_),
547  always_cache_(impl.always_cache_),
548  state_table_(new StateTable(*(impl.state_table_))),
549  nonterminal_set_(impl.nonterminal_set_),
550  nonterminal_hash_(impl.nonterminal_hash_),
551  root_(impl.root_) {
552  SetType("replace");
553  SetProperties(impl.Properties(), kCopyProperties);
554  SetInputSymbols(impl.InputSymbols());
555  SetOutputSymbols(impl.OutputSymbols());
556  fst_array_.reserve(impl.fst_array_.size());
557  fst_array_.emplace_back(nullptr);
558  for (Label i = 1; i < impl.fst_array_.size(); ++i) {
559  fst_array_.emplace_back(impl.fst_array_[i]->Copy(true));
560  }
561  }
562 
563  // Computes the dependency graph of the replace class and returns
564  // true if the dependencies are cyclic. Cyclic dependencies will result
565  // in an un-expandable FST.
566  bool CyclicDependencies() const {
567  const ReplaceUtilOptions opts(root_);
568  ReplaceUtil<Arc> replace_util(fst_array_, nonterminal_hash_, opts);
569  return replace_util.CyclicDependencies();
570  }
571 
573  if (!HasStart()) {
574  if (fst_array_.size() == 1) {
575  SetStart(kNoStateId);
576  return kNoStateId;
577  } else {
578  const auto fst_start = fst_array_[root_]->Start();
579  if (fst_start == kNoStateId) return kNoStateId;
580  const auto prefix = GetPrefixId(StackPrefix());
581  const auto start =
582  state_table_->FindState(StateTuple(prefix, root_, fst_start));
583  SetStart(start);
584  return start;
585  }
586  } else {
587  return CacheImpl::Start();
588  }
589  }
590 
592  if (HasFinal(s)) return CacheImpl::Final(s);
593  const auto &tuple = state_table_->Tuple(s);
594  auto weight = Weight::Zero();
595  if (tuple.prefix_id == 0) {
596  const auto fst_state = tuple.fst_state;
597  weight = fst_array_[tuple.fst_id]->Final(fst_state);
598  }
599  if (always_cache_ || HasArcs(s)) SetFinal(s, weight);
600  return weight;
601  }
602 
603  size_t NumArcs(StateId s) {
604  if (HasArcs(s)) {
605  return CacheImpl::NumArcs(s);
606  } else if (always_cache_) { // If always caching, expands and caches state.
607  Expand(s);
608  return CacheImpl::NumArcs(s);
609  } else { // Otherwise computes the number of arcs without expanding.
610  const auto tuple = state_table_->Tuple(s);
611  if (tuple.fst_state == kNoStateId) return 0;
612  auto num_arcs = fst_array_[tuple.fst_id]->NumArcs(tuple.fst_state);
613  if (ComputeFinalArc(tuple, nullptr)) ++num_arcs;
614  return num_arcs;
615  }
616  }
617 
618  // Returns whether a given label is a non-terminal.
619  bool IsNonTerminal(Label label) const {
620  if (label < *nonterminal_set_.begin() ||
621  label > *nonterminal_set_.rbegin()) {
622  return false;
623  } else {
624  return nonterminal_hash_.count(label);
625  }
626  // TODO(allauzen): be smarter and take advantage of all_dense or
627  // all_negative. Also use this in ComputeArc. This would require changes to
628  // Replace so that recursing into an empty FST lead to a non co-accessible
629  // state instead of deleting the arc as done currently. The current use
630  // correct, since labels are sorted if all_non_empty is true.
631  }
632 
634  if (HasArcs(s)) {
635  return CacheImpl::NumInputEpsilons(s);
636  } else if (always_cache_ || !Properties(kILabelSorted)) {
637  // If always caching or if the number of input epsilons is too expensive
638  // to compute without caching (i.e., not ilabel-sorted), then expands and
639  // caches state.
640  Expand(s);
641  return CacheImpl::NumInputEpsilons(s);
642  } else {
643  // Otherwise, computes the number of input epsilons without caching.
644  const auto tuple = state_table_->Tuple(s);
645  if (tuple.fst_state == kNoStateId) return 0;
646  size_t num = 0;
647  if (!EpsilonOnInput(call_label_type_)) {
648  // If EpsilonOnInput(c) is false, all input epsilon arcs
649  // are also input epsilons arcs in the underlying machine.
650  num = fst_array_[tuple.fst_id]->NumInputEpsilons(tuple.fst_state);
651  } else {
652  // Otherwise, one need to consider that all non-terminal arcs
653  // in the underlying machine also become input epsilon arc.
654  ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
655  for (; !aiter.Done() && ((aiter.Value().ilabel == 0) ||
656  IsNonTerminal(aiter.Value().olabel));
657  aiter.Next()) {
658  ++num;
659  }
660  }
661  if (EpsilonOnInput(return_label_type_) &&
662  ComputeFinalArc(tuple, nullptr)) {
663  ++num;
664  }
665  return num;
666  }
667  }
668 
670  if (HasArcs(s)) {
672  } else if (always_cache_ || !Properties(kOLabelSorted)) {
673  // If always caching or if the number of output epsilons is too expensive
674  // to compute without caching (i.e., not olabel-sorted), then expands and
675  // caches state.
676  Expand(s);
678  } else {
679  // Otherwise, computes the number of output epsilons without caching.
680  const auto tuple = state_table_->Tuple(s);
681  if (tuple.fst_state == kNoStateId) return 0;
682  size_t num = 0;
683  if (!EpsilonOnOutput(call_label_type_)) {
684  // If EpsilonOnOutput(c) is false, all output epsilon arcs are also
685  // output epsilons arcs in the underlying machine.
686  num = fst_array_[tuple.fst_id]->NumOutputEpsilons(tuple.fst_state);
687  } else {
688  // Otherwise, one need to consider that all non-terminal arcs in the
689  // underlying machine also become output epsilon arc.
690  ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
691  for (; !aiter.Done() && ((aiter.Value().olabel == 0) ||
692  IsNonTerminal(aiter.Value().olabel));
693  aiter.Next()) {
694  ++num;
695  }
696  }
697  if (EpsilonOnOutput(return_label_type_) &&
698  ComputeFinalArc(tuple, nullptr)) {
699  ++num;
700  }
701  return num;
702  }
703  }
704 
705  uint64 Properties() const override { return Properties(kFstProperties); }
706 
707  // Sets error if found, and returns other FST impl properties.
708  uint64 Properties(uint64 mask) const override {
709  if (mask & kError) {
710  for (Label i = 1; i < fst_array_.size(); ++i) {
711  if (fst_array_[i]->Properties(kError, false)) {
712  SetProperties(kError, kError);
713  }
714  }
715  }
716  return FstImpl<Arc>::Properties(mask);
717  }
718 
719  // Returns the base arc iterator, and if arcs have not been computed yet,
720  // extends and recurses for new arcs.
722  if (!HasArcs(s)) Expand(s);
723  CacheImpl::InitArcIterator(s, data);
724  // TODO(allauzen): Set behaviour of generic iterator.
725  // Warning: ArcIterator<ReplaceFst<A>>::InitCache() relies on current
726  // behaviour.
727  }
728 
729  // Extends current state (walk arcs one level deep).
730  void Expand(StateId s) {
731  const auto tuple = state_table_->Tuple(s);
732  if (tuple.fst_state == kNoStateId) { // Local FST is empty.
733  SetArcs(s);
734  return;
735  }
736  ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
737  Arc arc;
738  // Creates a final arc when needed.
739  if (ComputeFinalArc(tuple, &arc)) PushArc(s, std::move(arc));
740  // Expands all arcs leaving the state.
741  for (; !aiter.Done(); aiter.Next()) {
742  if (ComputeArc(tuple, aiter.Value(), &arc)) PushArc(s, std::move(arc));
743  }
744  SetArcs(s);
745  }
746 
747  void Expand(StateId s, const StateTuple &tuple,
748  const ArcIteratorData<Arc> &data) {
749  if (tuple.fst_state == kNoStateId) { // Local FST is empty.
750  SetArcs(s);
751  return;
752  }
753  ArcIterator<Fst<Arc>> aiter(data);
754  Arc arc;
755  // Creates a final arc when needed.
756  if (ComputeFinalArc(tuple, &arc)) AddArc(s, arc);
757  // Expands all arcs leaving the state.
758  for (; !aiter.Done(); aiter.Next()) {
759  if (ComputeArc(tuple, aiter.Value(), &arc)) AddArc(s, arc);
760  }
761  SetArcs(s);
762  }
763 
764  // If acpp is null, only returns true if a final arcp is required, but does
765  // not actually compute it.
766  bool ComputeFinalArc(const StateTuple &tuple, Arc *arcp,
767  uint32 flags = kArcValueFlags) {
768  const auto fst_state = tuple.fst_state;
769  if (fst_state == kNoStateId) return false;
770  // If state is final, pops the stack.
771  if (fst_array_[tuple.fst_id]->Final(fst_state) != Weight::Zero() &&
772  tuple.prefix_id) {
773  if (arcp) {
774  arcp->ilabel = (EpsilonOnInput(return_label_type_)) ? 0 : return_label_;
775  arcp->olabel =
776  (EpsilonOnOutput(return_label_type_)) ? 0 : return_label_;
777  if (flags & kArcNextStateValue) {
778  const auto &stack = state_table_->GetStackPrefix(tuple.prefix_id);
779  const auto prefix_id = PopPrefix(stack);
780  const auto &top = stack.Top();
781  arcp->nextstate = state_table_->FindState(
782  StateTuple(prefix_id, top.fst_id, top.nextstate));
783  }
784  if (flags & kArcWeightValue) {
785  arcp->weight = fst_array_[tuple.fst_id]->Final(fst_state);
786  }
787  }
788  return true;
789  } else {
790  return false;
791  }
792  }
793 
794  // Computes an arc in the FST corresponding to one in the underlying machine.
795  // Returns false if the underlying arc corresponds to no arc in the resulting
796  // FST.
797  bool ComputeArc(const StateTuple &tuple, const Arc &arc, Arc *arcp,
798  uint32 flags = kArcValueFlags) {
799  if (!EpsilonOnInput(call_label_type_) &&
800  (flags == (flags & (kArcILabelValue | kArcWeightValue)))) {
801  *arcp = arc;
802  return true;
803  }
804  if (arc.olabel == 0 || arc.olabel < *nonterminal_set_.begin() ||
805  arc.olabel > *nonterminal_set_.rbegin()) { // Expands local FST.
806  const auto nextstate =
807  flags & kArcNextStateValue
808  ? state_table_->FindState(
809  StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
810  : kNoStateId;
811  *arcp = Arc(arc.ilabel, arc.olabel, arc.weight, nextstate);
812  } else {
813  // Checks for non-terminal.
814  const auto it = nonterminal_hash_.find(arc.olabel);
815  if (it != nonterminal_hash_.end()) { // Recurses into non-terminal.
816  const auto nonterminal = it->second;
817  const auto nt_prefix =
818  PushPrefix(state_table_->GetStackPrefix(tuple.prefix_id),
819  tuple.fst_id, arc.nextstate);
820  // If the start state is valid, replace; othewise, the arc is implicitly
821  // deleted.
822  const auto nt_start = fst_array_[nonterminal]->Start();
823  if (nt_start != kNoStateId) {
824  const auto nt_nextstate = flags & kArcNextStateValue
825  ? state_table_->FindState(StateTuple(
826  nt_prefix, nonterminal, nt_start))
827  : kNoStateId;
828  const auto ilabel =
829  (EpsilonOnInput(call_label_type_)) ? 0 : arc.ilabel;
830  const auto olabel =
831  (EpsilonOnOutput(call_label_type_))
832  ? 0
833  : ((call_output_label_ == kNoLabel) ? arc.olabel
834  : call_output_label_);
835  *arcp = Arc(ilabel, olabel, arc.weight, nt_nextstate);
836  } else {
837  return false;
838  }
839  } else {
840  const auto nextstate =
841  flags & kArcNextStateValue
842  ? state_table_->FindState(
843  StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
844  : kNoStateId;
845  *arcp = Arc(arc.ilabel, arc.olabel, arc.weight, nextstate);
846  }
847  }
848  return true;
849  }
850 
851  // Returns the arc iterator flags supported by this FST.
853  uint32 flags = kArcValueFlags;
854  if (!always_cache_) flags |= kArcNoCache;
855  return flags;
856  }
857 
858  StateTable *GetStateTable() const { return state_table_.get(); }
859 
860  const Fst<Arc> *GetFst(Label fst_id) const {
861  return fst_array_[fst_id].get();
862  }
863 
864  Label GetFstId(Label nonterminal) const {
865  const auto it = nonterminal_hash_.find(nonterminal);
866  if (it == nonterminal_hash_.end()) {
867  FSTERROR() << "ReplaceFstImpl::GetFstId: Nonterminal not found: "
868  << nonterminal;
869  }
870  return it->second;
871  }
872 
873  // Returns true if label type on call arc results in epsilon input label.
874  bool EpsilonOnCallInput() { return EpsilonOnInput(call_label_type_); }
875 
876  private:
877  // The unique index into stack prefix table.
878  PrefixId GetPrefixId(const StackPrefix &prefix) {
879  return state_table_->FindPrefixId(prefix);
880  }
881 
882  // The prefix ID after a stack pop.
883  PrefixId PopPrefix(StackPrefix prefix) {
884  prefix.Pop();
885  return GetPrefixId(prefix);
886  }
887 
888  // The prefix ID after a stack push.
889  PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
890  prefix.Push(fst_id, nextstate);
891  return GetPrefixId(prefix);
892  }
893 
894  // Runtime options
895  ReplaceLabelType call_label_type_; // How to label call arc.
896  ReplaceLabelType return_label_type_; // How to label return arc.
897  int64 call_output_label_; // Specifies output label to put on call arc
898  int64 return_label_; // Specifies label to put on return arc.
899  bool always_cache_; // Disable optional caching of arc iterator?
900 
901  // State table.
902  std::unique_ptr<StateTable> state_table_;
903 
904  // Replace components.
905  std::set<Label> nonterminal_set_;
906  NonTerminalHash nonterminal_hash_;
907  std::vector<std::unique_ptr<const Fst<Arc>>> fst_array_;
908  Label root_;
909 };
910 
911 } // namespace internal
912 
913 //
914 // ReplaceFst supports dynamic replacement of arcs in one FST with another FST.
915 // This replacement is recursive. ReplaceFst can be used to support a variety of
916 // delayed constructions such as recursive
917 // transition networks, union, or closure. It is constructed with an array of
918 // FST(s). One FST represents the root (or topology) machine. The root FST
919 // refers to other FSTs by recursively replacing arcs labeled as non-terminals
920 // with the matching non-terminal FST. Currently the ReplaceFst uses the output
921 // symbols of the arcs to determine whether the arc is a non-terminal arc or
922 // not. A non-terminal can be any label that is not a non-zero terminal label in
923 // the output alphabet.
924 //
925 // Note that the constructor uses a vector of pairs. These correspond to the
926 // tuple of non-terminal Label and corresponding FST. For example to implement
927 // the closure operation we need 2 FSTs. The first root FST is a single
928 // self-loop arc on the start state.
929 //
930 // The ReplaceFst class supports an optionally caching arc iterator.
931 //
932 // The ReplaceFst needs to be built such that it is known to be ilabel- or
933 // olabel-sorted (see usage below).
934 //
935 // Observe that Matcher<Fst<A>> will use the optionally caching arc iterator
936 // when available (the FST is ilabel-sorted and matching on the input, or the
937 // FST is olabel-sorted and matching on the output). In order to obtain the
938 // most efficient behaviour, it is recommended to set call_label_type to
939 // REPLACE_LABEL_INPUT or REPLACE_LABEL_BOTH and return_label_type to
940 // REPLACE_LABEL_OUTPUT or REPLACE_LABEL_NEITHER (this means that the call arc
941 // does not have epsilon on the input side and the return arc has epsilon on the
942 // input side) and to match on the input side.
943 //
944 // This class attaches interface to implementation and handles reference
945 // counting, delegating most methods to ImplToFst.
946 template <class A, class T /* = DefaultReplaceStateTable<A> */,
947  class CacheStore /* = DefaultCacheStore<A> */>
948 class ReplaceFst
949  : public ImplToFst<internal::ReplaceFstImpl<A, T, CacheStore>> {
950  public:
951  using Arc = A;
952  using Label = typename Arc::Label;
953  using StateId = typename Arc::StateId;
954  using Weight = typename Arc::Weight;
955 
956  using StateTable = T;
957  using Store = CacheStore;
958  using State = typename CacheStore::State;
961 
963 
964  friend class ArcIterator<ReplaceFst<Arc, StateTable, CacheStore>>;
965  friend class StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>;
966  friend class ReplaceFstMatcher<Arc, StateTable, CacheStore>;
967 
968  ReplaceFst(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_array,
969  Label root)
970  : ImplToFst<Impl>(std::make_shared<Impl>(
971  fst_array, ReplaceFstOptions<Arc, StateTable, CacheStore>(root))) {}
972 
973  ReplaceFst(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_array,
975  : ImplToFst<Impl>(std::make_shared<Impl>(fst_array, opts)) {}
976 
977  // See Fst<>::Copy() for doc.
979  bool safe = false)
980  : ImplToFst<Impl>(fst, safe) {}
981 
982  // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc.
984  bool safe = false) const override {
985  return new ReplaceFst<Arc, StateTable, CacheStore>(*this, safe);
986  }
987 
988  inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
989 
990  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
991  GetMutableImpl()->InitArcIterator(s, data);
992  }
993 
994  MatcherBase<Arc> *InitMatcher(MatchType match_type) const override {
995  if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
996  ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) ||
997  (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) {
999  (this, match_type);
1000  } else {
1001  VLOG(2) << "Not using replace matcher";
1002  return nullptr;
1003  }
1004  }
1005 
1006  bool CyclicDependencies() const { return GetImpl()->CyclicDependencies(); }
1007 
1008  const StateTable &GetStateTable() const {
1009  return *GetImpl()->GetStateTable();
1010  }
1011 
1012  const Fst<Arc> &GetFst(Label nonterminal) const {
1013  return *GetImpl()->GetFst(GetImpl()->GetFstId(nonterminal));
1014  }
1015 
1016  private:
1019 
1020  ReplaceFst &operator=(const ReplaceFst &) = delete;
1021 };
1022 
1023 // Specialization for ReplaceFst.
1024 template <class Arc, class StateTable, class CacheStore>
1025 class StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>
1026  : public CacheStateIterator<ReplaceFst<Arc, StateTable, CacheStore>> {
1027  public:
1029  : CacheStateIterator<ReplaceFst<Arc, StateTable, CacheStore>>(
1030  fst, fst.GetMutableImpl()) {}
1031 };
1032 
1033 // Specialization for ReplaceFst, implementing optional caching. It is used
1034 // as follows:
1035 //
1036 // ReplaceFst<A> replace;
1037 // ArcIterator<ReplaceFst<A>> aiter(replace, s);
1038 // // Note: ArcIterator< Fst<A>> is always a caching arc iterator.
1039 // aiter.SetFlags(kArcNoCache, kArcNoCache);
1040 // // Uses the arc iterator, no arc will be cached, no state will be expanded.
1041 // // Arc flags can be used to decide which components of the arc need to be
1042 // computed.
1043 // aiter.SetFlags(kArcILabelValue, kArcValueFlags);
1044 // // Wants the ilabel for this arc.
1045 // aiter.Value(); // Does not compute the destination state.
1046 // aiter.Next();
1047 // aiter.SetFlags(kArcNextStateValue, kArcNextStateValue);
1048 // // Wants the ilabel and next state for this arc.
1049 // aiter.Value(); // Does compute the destination state and inserts it
1050 // // in the replace state table.
1051 // // No additional arcs have been cached at this point.
1052 template <class Arc, class StateTable, class CacheStore>
1053 class ArcIterator<ReplaceFst<Arc, StateTable, CacheStore>> {
1054  public:
1055  using StateId = typename Arc::StateId;
1056 
1057  using StateTuple = typename StateTable::StateTuple;
1058 
1060  : fst_(fst),
1061  s_(s),
1062  pos_(0),
1063  offset_(0),
1064  flags_(kArcValueFlags),
1065  arcs_(nullptr),
1066  data_flags_(0),
1067  final_flags_(0) {
1068  cache_data_.ref_count = nullptr;
1069  local_data_.ref_count = nullptr;
1070  // If FST does not support optional caching, forces caching.
1071  if (!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
1072  !(fst_.GetImpl()->HasArcs(s_))) {
1073  fst_.GetMutableImpl()->Expand(s_);
1074  }
1075  // If state is already cached, use cached arcs array.
1076  if (fst_.GetImpl()->HasArcs(s_)) {
1077  (fst_.GetImpl())
1078  ->internal::template CacheBaseImpl<
1079  typename CacheStore::State,
1080  CacheStore>::InitArcIterator(s_, &cache_data_);
1081  num_arcs_ = cache_data_.narcs;
1082  arcs_ = cache_data_.arcs; // arcs_ is a pointer to the cached arcs.
1083  data_flags_ = kArcValueFlags; // All the arc member values are valid.
1084  } else { // Otherwise delay decision until Value() is called.
1085  tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(s_);
1086  if (tuple_.fst_state == kNoStateId) {
1087  num_arcs_ = 0;
1088  } else {
1089  // The decision to cache or not to cache has been defered until Value()
1090  // or
1091  // SetFlags() is called. However, the arc iterator is set up now to be
1092  // ready for non-caching in order to keep the Value() method simple and
1093  // efficient.
1094  const auto *rfst = fst_.GetImpl()->GetFst(tuple_.fst_id);
1095  rfst->InitArcIterator(tuple_.fst_state, &local_data_);
1096  // arcs_ is a pointer to the arcs in the underlying machine.
1097  arcs_ = local_data_.arcs;
1098  // Computes the final arc (but not its destination state) if a final arc
1099  // is required.
1100  bool has_final_arc = fst_.GetMutableImpl()->ComputeFinalArc(
1101  tuple_, &final_arc_, kArcValueFlags & ~kArcNextStateValue);
1102  // Sets the arc value flags that hold for final_arc_.
1103  final_flags_ = kArcValueFlags & ~kArcNextStateValue;
1104  // Computes the number of arcs.
1105  num_arcs_ = local_data_.narcs;
1106  if (has_final_arc) ++num_arcs_;
1107  // Sets the offset between the underlying arc positions and the
1108  // positions
1109  // in the arc iterator.
1110  offset_ = num_arcs_ - local_data_.narcs;
1111  // Defers the decision to cache or not until Value() or SetFlags() is
1112  // called.
1113  data_flags_ = 0;
1114  }
1115  }
1116  }
1117 
1119  if (cache_data_.ref_count) --(*cache_data_.ref_count);
1120  if (local_data_.ref_count) --(*local_data_.ref_count);
1121  }
1122 
1123  void ExpandAndCache() const {
1124  // TODO(allauzen): revisit this.
1125  // fst_.GetImpl()->Expand(s_, tuple_, local_data_);
1126  // (fst_.GetImpl())->CacheImpl<A>*>::InitArcIterator(s_,
1127  // &cache_data_);
1128  //
1129  fst_.InitArcIterator(s_, &cache_data_); // Expand and cache state.
1130  arcs_ = cache_data_.arcs; // arcs_ is a pointer to the cached arcs.
1131  data_flags_ = kArcValueFlags; // All the arc member values are valid.
1132  offset_ = 0; // No offset.
1133  }
1134 
1135  void Init() {
1136  if (flags_ & kArcNoCache) { // If caching is disabled
1137  // arcs_ is a pointer to the arcs in the underlying machine.
1138  arcs_ = local_data_.arcs;
1139  // Sets the arcs value flags that hold for arcs_.
1140  data_flags_ = kArcWeightValue;
1141  if (!fst_.GetMutableImpl()->EpsilonOnCallInput()) {
1142  data_flags_ |= kArcILabelValue;
1143  }
1144  // Sets the offset between the underlying arc positions and the positions
1145  // in the arc iterator.
1146  offset_ = num_arcs_ - local_data_.narcs;
1147  } else {
1148  ExpandAndCache();
1149  }
1150  }
1151 
1152  bool Done() const { return pos_ >= num_arcs_; }
1153 
1154  const Arc &Value() const {
1155  // If data_flags_ is 0, non-caching was not requested.
1156  if (!data_flags_) {
1157  // TODO(allauzen): Revisit this.
1158  if (flags_ & kArcNoCache) {
1159  // Should never happen.
1160  FSTERROR() << "ReplaceFst: Inconsistent arc iterator flags";
1161  }
1162  ExpandAndCache();
1163  }
1164  if (pos_ - offset_ >= 0) { // The requested arc is not the final arc.
1165  const auto &arc = arcs_[pos_ - offset_];
1166  if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) {
1167  // If the value flags match the recquired value flags then returns the
1168  // arc.
1169  return arc;
1170  } else {
1171  // Otherwise, compute the corresponding arc on-the-fly.
1172  fst_.GetMutableImpl()->ComputeArc(tuple_, arc, &arc_,
1173  flags_ & kArcValueFlags);
1174  return arc_;
1175  }
1176  } else { // The requested arc is the final arc.
1177  if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) {
1178  // If the arc value flags that hold for the final arc do not match the
1179  // requested value flags, then
1180  // final_arc_ needs to be updated.
1181  fst_.GetMutableImpl()->ComputeFinalArc(tuple_, &final_arc_,
1182  flags_ & kArcValueFlags);
1183  final_flags_ = flags_ & kArcValueFlags;
1184  }
1185  return final_arc_;
1186  }
1187  }
1188 
1189  void Next() { ++pos_; }
1190 
1191  size_t Position() const { return pos_; }
1192 
1193  void Reset() { pos_ = 0; }
1194 
1195  void Seek(size_t pos) { pos_ = pos; }
1196 
1197  uint32 Flags() const { return flags_; }
1198 
1199  void SetFlags(uint32 flags, uint32 mask) {
1200  // Updates the flags taking into account what flags are supported
1201  // by the FST.
1202  flags_ &= ~mask;
1203  flags_ |= (flags & fst_.GetImpl()->ArcIteratorFlags());
1204  // If non-caching is not requested (and caching has not already been
1205  // performed), then flush data_flags_ to request caching during the next
1206  // call to Value().
1207  if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) {
1208  if (!fst_.GetImpl()->HasArcs(s_)) data_flags_ = 0;
1209  }
1210  // If data_flags_ has been flushed but non-caching is requested before
1211  // calling Value(), then set up the iterator for non-caching.
1212  if ((flags & kArcNoCache) && (!data_flags_)) Init();
1213  }
1214 
1215  private:
1216  const ReplaceFst<Arc, StateTable, CacheStore> &fst_; // Reference to the FST.
1217  StateId s_; // State in the FST.
1218  mutable StateTuple tuple_; // Tuple corresponding to state_.
1219 
1220  ssize_t pos_; // Current position.
1221  mutable ssize_t offset_; // Offset between position in iterator and in arcs_.
1222  ssize_t num_arcs_; // Number of arcs at state_.
1223  uint32 flags_; // Behavorial flags for the arc iterator
1224  mutable Arc arc_; // Memory to temporarily store computed arcs.
1225 
1226  mutable ArcIteratorData<Arc> cache_data_; // Arc iterator data in cache.
1227  mutable ArcIteratorData<Arc> local_data_; // Arc iterator data in local FST.
1228 
1229  mutable const Arc *arcs_; // Array of arcs.
1230  mutable uint32 data_flags_; // Arc value flags valid for data in arcs_.
1231  mutable Arc final_arc_; // Final arc (when required).
1232  mutable uint32 final_flags_; // Arc value flags valid for final_arc_.
1233 
1234  ArcIterator(const ArcIterator &) = delete;
1235  ArcIterator &operator=(const ArcIterator &) = delete;
1236 };
1237 
1238 template <class Arc, class StateTable, class CacheStore>
1239 class ReplaceFstMatcher : public MatcherBase<Arc> {
1240  public:
1241  using Label = typename Arc::Label;
1242  using StateId = typename Arc::StateId;
1243  using Weight = typename Arc::Weight;
1244 
1247 
1248  using StateTuple = typename StateTable::StateTuple;
1249 
1250  // This makes a copy of the FST.
1252  MatchType match_type)
1253  : owned_fst_(fst.Copy()),
1254  fst_(*owned_fst_),
1255  impl_(fst_.GetMutableImpl()),
1256  s_(fst::kNoStateId),
1257  match_type_(match_type),
1258  current_loop_(false),
1259  final_arc_(false),
1260  loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
1261  if (match_type_ == fst::MATCH_OUTPUT) {
1262  std::swap(loop_.ilabel, loop_.olabel);
1263  }
1264  InitMatchers();
1265  }
1266 
1267  // This doesn't copy the FST.
1269  MatchType match_type)
1270  : fst_(*fst),
1271  impl_(fst_.GetMutableImpl()),
1272  s_(fst::kNoStateId),
1273  match_type_(match_type),
1274  current_loop_(false),
1275  final_arc_(false),
1276  loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
1277  if (match_type_ == fst::MATCH_OUTPUT) {
1278  std::swap(loop_.ilabel, loop_.olabel);
1279  }
1280  InitMatchers();
1281  }
1282 
1283  // This makes a copy of the FST.
1286  bool safe = false)
1287  : owned_fst_(matcher.fst_.Copy(safe)),
1288  fst_(*owned_fst_),
1289  impl_(fst_.GetMutableImpl()),
1290  s_(fst::kNoStateId),
1291  match_type_(matcher.match_type_),
1292  current_loop_(false),
1293  final_arc_(false),
1294  loop_(fst::kNoLabel, 0, Weight::One(), fst::kNoStateId) {
1295  if (match_type_ == fst::MATCH_OUTPUT) {
1296  std::swap(loop_.ilabel, loop_.olabel);
1297  }
1298  InitMatchers();
1299  }
1300 
1301  // Creates a local matcher for each component FST in the RTN. LocalMatcher is
1302  // a multi-epsilon wrapper matcher. MultiEpsilonMatcher is used to match each
1303  // non-terminal arc, since these non-terminal
1304  // turn into epsilons on recursion.
1305  void InitMatchers() {
1306  const auto &fst_array = impl_->fst_array_;
1307  matcher_.resize(fst_array.size());
1308  for (Label i = 0; i < fst_array.size(); ++i) {
1309  if (fst_array[i]) {
1310  matcher_[i].reset(
1311  new LocalMatcher(*fst_array[i], match_type_, kMultiEpsList));
1312  auto it = impl_->nonterminal_set_.begin();
1313  for (; it != impl_->nonterminal_set_.end(); ++it) {
1314  matcher_[i]->AddMultiEpsLabel(*it);
1315  }
1316  }
1317  }
1318  }
1319 
1321  bool safe = false) const override {
1322  return new ReplaceFstMatcher<Arc, StateTable, CacheStore>(*this, safe);
1323  }
1324 
1325  MatchType Type(bool test) const override {
1326  if (match_type_ == MATCH_NONE) return match_type_;
1327  const auto true_prop =
1328  match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted;
1329  const auto false_prop =
1330  match_type_ == MATCH_INPUT ? kNotILabelSorted : kNotOLabelSorted;
1331  const auto props = fst_.Properties(true_prop | false_prop, test);
1332  if (props & true_prop) {
1333  return match_type_;
1334  } else if (props & false_prop) {
1335  return MATCH_NONE;
1336  } else {
1337  return MATCH_UNKNOWN;
1338  }
1339  }
1340 
1341  const Fst<Arc> &GetFst() const override { return fst_; }
1342 
1343  uint64 Properties(uint64 props) const override { return props; }
1344 
1345  // Sets the state from which our matching happens.
1346  void SetState(StateId s) final {
1347  if (s_ == s) return;
1348  s_ = s;
1349  tuple_ = impl_->GetStateTable()->Tuple(s_);
1350  if (tuple_.fst_state == kNoStateId) {
1351  done_ = true;
1352  return;
1353  }
1354  // Gets current matcher, used for non-epsilon matching.
1355  current_matcher_ = matcher_[tuple_.fst_id].get();
1356  current_matcher_->SetState(tuple_.fst_state);
1357  loop_.nextstate = s_;
1358  final_arc_ = false;
1359  }
1360 
1361  // Searches for label from previous set state. If label == 0, first
1362  // hallucinate an epsilon loop; otherwise use the underlying matcher to
1363  // search for the label or epsilons. Note since the ReplaceFst recursion
1364  // on non-terminal arcs causes epsilon transitions to be created we use
1365  // MultiEpsilonMatcher to search for possible matches of non-terminals. If the
1366  // component FST
1367  // reaches a final state we also need to add the exiting final arc.
1368  bool Find(Label label) final {
1369  bool found = false;
1370  label_ = label;
1371  if (label_ == 0 || label_ == kNoLabel) {
1372  // Computes loop directly, avoiding Replace::ComputeArc.
1373  if (label_ == 0) {
1374  current_loop_ = true;
1375  found = true;
1376  }
1377  // Searches for matching multi-epsilons.
1378  final_arc_ = impl_->ComputeFinalArc(tuple_, nullptr);
1379  found = current_matcher_->Find(kNoLabel) || final_arc_ || found;
1380  } else {
1381  // Searches on a sub machine directly using sub machine matcher.
1382  found = current_matcher_->Find(label_);
1383  }
1384  return found;
1385  }
1386 
1387  bool Done() const final {
1388  return !current_loop_ && !final_arc_ && current_matcher_->Done();
1389  }
1390 
1391  const Arc &Value() const final {
1392  if (current_loop_) return loop_;
1393  if (final_arc_) {
1394  impl_->ComputeFinalArc(tuple_, &arc_);
1395  return arc_;
1396  }
1397  const auto &component_arc = current_matcher_->Value();
1398  impl_->ComputeArc(tuple_, component_arc, &arc_);
1399  return arc_;
1400  }
1401 
1402  void Next() final {
1403  if (current_loop_) {
1404  current_loop_ = false;
1405  return;
1406  }
1407  if (final_arc_) {
1408  final_arc_ = false;
1409  return;
1410  }
1411  current_matcher_->Next();
1412  }
1413 
1414  ssize_t Priority(StateId s) final { return fst_.NumArcs(s); }
1415 
1416  private:
1417  std::unique_ptr<const ReplaceFst<Arc, StateTable, CacheStore>> owned_fst_;
1420  LocalMatcher *current_matcher_;
1421  std::vector<std::unique_ptr<LocalMatcher>> matcher_;
1422  StateId s_; // Current state.
1423  Label label_; // Current label.
1424  MatchType match_type_; // Supplied by caller.
1425  mutable bool done_;
1426  mutable bool current_loop_; // Current arc is the implicit loop.
1427  mutable bool final_arc_; // Current arc for exiting recursion.
1428  mutable StateTuple tuple_; // Tuple corresponding to state_.
1429  mutable Arc arc_;
1430  Arc loop_;
1431 
1432  ReplaceFstMatcher &operator=(const ReplaceFstMatcher &) = delete;
1433 };
1434 
1435 template <class Arc, class StateTable, class CacheStore>
1437  StateIteratorData<Arc> *data) const {
1438  data->base =
1440 }
1441 
1443 
1444 // Recursively replaces arcs in the root FSTs with other FSTs.
1445 // This version writes the result of replacement to an output MutableFst.
1446 //
1447 // Replace supports replacement of arcs in one Fst with another FST. This
1448 // replacement is recursive. Replace takes an array of FST(s). One FST
1449 // represents the root (or topology) machine. The root FST refers to other FSTs
1450 // by recursively replacing arcs labeled as non-terminals with the matching
1451 // non-terminal FST. Currently Replace uses the output symbols of the arcs to
1452 // determine whether the arc is a non-terminal arc or not. A non-terminal can be
1453 // any label that is not a non-zero terminal label in the output alphabet.
1454 //
1455 // Note that the input argument is a vector of pairs. These correspond to the tuple
1456 // of non-terminal Label and corresponding FST.
1457 template <class Arc>
1458 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1459  &ifst_array,
1460  MutableFst<Arc> *ofst,
1462  opts.gc = true;
1463  opts.gc_limit = 0; // Caches only the last state for fastest copy.
1464  *ofst = ReplaceFst<Arc>(ifst_array, opts);
1465 }
1466 
1467 template <class Arc>
1468 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1469  &ifst_array,
1470  MutableFst<Arc> *ofst, const ReplaceUtilOptions &opts) {
1471  Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(opts));
1472 }
1473 
1474 // For backwards compatibility.
1475 template <class Arc>
1476 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1477  &ifst_array,
1478  MutableFst<Arc> *ofst, typename Arc::Label root,
1479  bool epsilon_on_replace) {
1480  Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root, epsilon_on_replace));
1481 }
1482 
1483 template <class Arc>
1484 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1485  &ifst_array,
1486  MutableFst<Arc> *ofst, typename Arc::Label root) {
1487  Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root));
1488 }
1489 
1490 } // namespace fst
1491 
1492 #endif // FST_REPLACE_H_
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:100
ReplaceFstOptions(const CacheOptions &opts, Label root=kNoLabel)
Definition: replace.h:333
typename CacheStore::State State
Definition: replace.h:461
void SetState(StateId s) final
Definition: replace.h:1346
ArcIterator(const ReplaceFst< Arc, StateTable, CacheStore > &fst, StateId s)
Definition: replace.h:1059
const StateTuple & Tuple(StateId id) const
Definition: replace.h:249
constexpr int kNoLabel
Definition: fst.h:178
std::vector< PrefixTuple > prefix_
Definition: replace.h:158
MatcherBase< Arc > * InitMatcher(MatchType match_type) const override
Definition: replace.h:994
void Expand(StateId s, const StateTuple &tuple, const ArcIteratorData< Arc > &data)
Definition: replace.h:747
const StateTable & GetStateTable() const
Definition: replace.h:1008
const StackPrefix & GetStackPrefix(PrefixId id) const
Definition: replace.h:297
constexpr uint64 kNotOLabelSorted
Definition: properties.h:83
uint64_t uint64
Definition: types.h:32
void Next() final
Definition: replace.h:1402
const Fst< Arc > * GetFst(Label fst_id) const
Definition: replace.h:860
ReplaceStateTuple(PrefixId prefix_id=-1, StateId fst_id=kNoStateId, StateId fst_state=kNoStateId)
Definition: replace.h:67
typename StateTable::PrefixId PrefixId
Definition: replace.h:463
PrefixId FindPrefixId(const StackPrefix &prefix)
Definition: replace.h:293
ReplaceFstOptions(int64 root, bool epsilon_replace_arc)
Definition: replace.h:368
ReplaceFstOptions(const ReplaceUtilOptions &opts)
Definition: replace.h:361
PrefixId prefix_id
Definition: replace.h:71
size_t NumOutputEpsilons(StateId s)
Definition: replace.h:669
bool EpsilonOnOutput(ReplaceLabelType label_type)
Definition: replace.h:390
ReplaceFstMatcher(const ReplaceFst< Arc, StateTable, CacheStore > *fst, MatchType match_type)
Definition: replace.h:1268
typename CacheStore::State State
Definition: replace.h:958
bool ComputeArc(const StateTuple &tuple, const Arc &arc, Arc *arcp, uint32 flags=kArcValueFlags)
Definition: replace.h:797
ReplaceLabelType
Definition: replace-util.h:28
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:83
const SymbolTable * OutputSymbols() const
Definition: fst.h:691
MatchType
Definition: fst.h:170
uint64 ReplaceProperties(const std::vector< uint64 > &inprops, size_t root, bool epsilon_on_call, bool epsilon_on_return, bool out_epsilon_on_call, bool out_epsilon_on_return, bool replace_transducer, bool no_empty_fst, bool all_ilabel_sorted, bool all_olabel_sorted, bool all_negative_or_dense)
Definition: properties.cc:234
constexpr uint64 kILabelSorted
Definition: properties.h:76
ReplaceFst(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_array, const ReplaceFstOptions< Arc, StateTable, CacheStore > &opts)
Definition: replace.h:973
uint64 Properties(uint64 mask) const override
Definition: replace.h:708
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
Definition: replace.h:990
bool ReplaceTransducer(ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, Label call_output_label)
Definition: replace.h:397
void Push(StateId fst_id, StateId nextstate)
Definition: replace.h:147
VectorHashReplaceStateTable(const VectorHashReplaceStateTable< Arc, PrefixId > &table)
Definition: replace.h:233
SetType
Definition: set-weight.h:36
MatchType Type(bool test) const override
Definition: replace.h:1325
typename Arc::Weight Weight
Definition: replace.h:1243
ReplaceFst< Arc, StateTable, CacheStore > * Copy(bool safe=false) const override
Definition: replace.h:983
typename ReplaceFst< Arc, StateTable, CacheStore >::Arc Arc
Definition: cache.h:1148
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:89
constexpr uint64 kNotILabelSorted
Definition: properties.h:78
bool EpsilonOnInput(ReplaceLabelType label_type)
Definition: replace.h:384
void Replace(const std::vector< std::pair< typename Arc::Label, const Fst< Arc > * >> &ifst_array, MutableFst< Arc > *ofst, std::vector< std::pair< typename Arc::Label, typename Arc::Label >> *parens, const PdtReplaceOptions< Arc > &opts)
Definition: replace.h:786
bool CyclicDependencies() const
Definition: replace.h:1006
const uint32 kMultiEpsList
Definition: matcher.h:1224
VectorHashReplaceStateTable(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_list, Label root)
Definition: replace.h:212
ReplaceFstOptions(const CacheImplOptions< CacheStore > &opts, Label root=kNoLabel)
Definition: replace.h:329
bool Find(Label label) final
Definition: replace.h:1368
const PrefixTuple & Top() const
Definition: replace.h:153
const Fst< Arc > & GetFst(Label nonterminal) const
Definition: replace.h:1012
constexpr uint64 kFstProperties
Definition: properties.h:308
constexpr uint64 kCopyProperties
Definition: properties.h:145
constexpr int kNoStateId
Definition: fst.h:179
uint64 operator()(const ReplaceStateTuple< StateId, PrefixId > &tuple) const
Definition: replace.h:114
const Arc & Value() const
Definition: fst.h:504
DefaultReplaceStateTable(const DefaultReplaceStateTable< Arc, PrefixId > &table)
Definition: replace.h:290
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
Definition: replace.h:721
std::vector< std::pair< typename Arc::Label, const Fst< Arc > * >> FstList
Definition: replace.h:381
int64_t int64
Definition: types.h:27
#define FSTERROR()
Definition: util.h:35
ReplaceFstImpl(const ReplaceFstImpl &impl)
Definition: replace.h:541
size_t operator()(const ReplaceStateTuple< S, P > &t) const
Definition: replace.h:123
StateIteratorBase< Arc > * base
Definition: fst.h:352
typename Arc::Weight Weight
Definition: replace.h:459
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:94
void InitStateIterator(StateIteratorData< Arc > *data) const override
Definition: replace.h:1436
typename Arc::StateId StateId
Definition: replace.h:197
bool operator==(const PdtStateTuple< S, K > &x, const PdtStateTuple< S, K > &y)
Definition: pdt.h:133
const Fst< Arc > & GetFst() const override
Definition: replace.h:1341
Weight Final(StateId s)
Definition: replace.h:591
constexpr uint64 kOLabelSorted
Definition: properties.h:81
#define VLOG(level)
Definition: log.h:47
size_t NumArcs(StateId s)
Definition: replace.h:603
typename Arc::Label Label
Definition: replace.h:196
uint32 ArcIteratorFlags() const
Definition: replace.h:852
bool Done() const
Definition: fst.h:500
uint64 Properties(uint64 props) const override
Definition: replace.h:1343
PrefixId FindPrefixId(const StackPrefix &prefix)
Definition: replace.h:251
const StackPrefix & GetStackPrefix(PrefixId id) const
Definition: replace.h:255
ReplaceFst(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_array, Label root)
Definition: replace.h:968
ReplaceFstOptions(Label root, ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, Label return_label)
Definition: replace.h:344
bool ComputeFinalArc(const StateTuple &tuple, Arc *arcp, uint32 flags=kArcValueFlags)
Definition: replace.h:766
uint64 operator()(const ReplaceStateTuple< StateId, PrefixId > &tuple) const
Definition: replace.h:101
typename Arc::Label Label
Definition: replace.h:311
ReplaceFstMatcher< Arc, StateTable, CacheStore > * Copy(bool safe=false) const override
Definition: replace.h:1320
void Expand(StateId s)
Definition: replace.h:730
ReplaceFst(const ReplaceFst< Arc, StateTable, CacheStore > &fst, bool safe=false)
Definition: replace.h:978
const Arc & Value() const final
Definition: replace.h:1391
uint32_t uint32
Definition: types.h:31
Label GetFstId(Label nonterminal) const
Definition: replace.h:864
typename internal::ReplaceFstImpl< Arc, StateTable, CacheStore >::Arc Arc
Definition: fst.h:187
size_t Depth() const
Definition: replace.h:155
const SymbolTable * InputSymbols() const
Definition: fst.h:689
Arc::StateId CountStates(const Fst< Arc > &fst)
Definition: expanded-fst.h:155
bool CyclicDependencies() const
Definition: replace-util.h:119
typename Arc::Label Label
Definition: replace.h:1241
size_t operator()(const ReplaceStackPrefix< Label, StateId > &prefix) const
Definition: replace.h:179
ReplaceFstMatcher(const ReplaceFstMatcher< Arc, StateTable, CacheStore > &matcher, bool safe=false)
Definition: replace.h:1284
size_t NumInputEpsilons(StateId s)
Definition: replace.h:633
ssize_t Priority(StateId s) final
Definition: replace.h:1414
ReplaceFstOptions(Label root, ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, Label call_output_label, Label return_label)
Definition: replace.h:352
ReplaceFstOptions(Label root)
Definition: replace.h:342
constexpr uint64 kError
Definition: properties.h:34
bool CyclicDependencies() const
Definition: replace.h:566
typename StateTable::StateTuple StateTuple
Definition: replace.h:1248
ReplaceFstImpl(const FstList< Arc > &fst_list, const ReplaceFstOptions< Arc, StateTable, CacheStore > &opts)
Definition: replace.h:486
uint64 Properties() const override
Definition: replace.h:705
ReplaceFstMatcher(const ReplaceFst< Arc, StateTable, CacheStore > &fst, MatchType match_type)
Definition: replace.h:1251
typename Arc::StateId StateId
Definition: replace.h:273
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
Definition: expand.h:302
ReplaceFingerprint(const std::vector< uint64 > *size_array)
Definition: replace.h:98
uint64 ReplaceFstProperties(typename Arc::Label root_label, const FstList< Arc > &fst_list, ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, typename Arc::Label call_output_label, bool *sorted_and_non_empty)
Definition: replace.h:409
ReplaceStackPrefix(const ReplaceStackPrefix &other)
Definition: replace.h:144
typename Arc::StateId StateId
Definition: replace.h:458
bool Done() const final
Definition: replace.h:1387
StateId FindState(const StateTuple &tuple)
Definition: replace.h:245
bool IsNonTerminal(Label label) const
Definition: replace.h:619
typename Arc::StateId StateId
Definition: replace.h:1242
DefaultReplaceStateTable(const std::vector< std::pair< Label, const Fst< Arc > * >> &, Label)
Definition: replace.h:287
StateIterator(const ReplaceFst< Arc, StateTable, CacheStore > &fst)
Definition: replace.h:1028
typename Arc::Label Label
Definition: replace.h:457
PrefixTuple(Label fst_id=kNoLabel, StateId nextstate=kNoStateId)
Definition: replace.h:135
bool operator()(const ReplaceStateTuple< StateId, PrefixId > &tuple) const
Definition: replace.h:89
StateTable * GetStateTable() const
Definition: replace.h:858
std::unordered_map< Label, Label > NonTerminalHash
Definition: replace.h:466
void Next()
Definition: fst.h:508