FST  openfst-1.7.5
OpenFst Library
replace.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions and classes for the recursive replacement of FSTs.
5 
6 #ifndef FST_REPLACE_H_
7 #define FST_REPLACE_H_
8 
9 #include <set>
10 #include <string>
11 #include <unordered_map>
12 #include <utility>
13 #include <vector>
14 
15 #include <fst/types.h>
16 #include <fst/log.h>
17 
18 #include <fst/cache.h>
19 #include <fst/expanded-fst.h>
20 #include <fst/fst-decl.h> // For optional argument declarations.
21 #include <fst/fst.h>
22 #include <fst/matcher.h>
23 #include <fst/replace-util.h>
24 #include <fst/state-table.h>
25 #include <fst/test-properties.h>
26 
27 namespace fst {
28 
29 // Replace state tables have the form:
30 //
31 // template <class Arc, class P>
32 // class ReplaceStateTable {
33 // public:
34 // using Label = typename Arc::Label;
35 // using StateId = typename Arc::StateId;
36 //
37 // using PrefixId = P;
38 // using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
39 // using StackPrefix = ReplaceStackPrefix<Label, StateId>;
40 //
41 // // Required constructor.
42 // ReplaceStateTable(
43 // const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_list,
44 // Label root);
45 //
46 // // Required copy constructor that does not copy state.
47 // ReplaceStateTable(const ReplaceStateTable<Arc, PrefixId> &table);
48 //
49 // // Looks up state ID by tuple, adding it if it doesn't exist.
50 // StateId FindState(const StateTuple &tuple);
51 //
52 // // Looks up state tuple by ID.
53 // const StateTuple &Tuple(StateId id) const;
54 //
55 // // Looks up prefix ID by stack prefix, adding it if it doesn't exist.
56 // PrefixId FindPrefixId(const StackPrefix &stack_prefix);
57 //
58 // // Looks up stack prefix by ID.
59 // const StackPrefix &GetStackPrefix(PrefixId id) const;
60 // };
61 
62 // Tuple that uniquely defines a state in replace.
63 template <class S, class P>
65  using StateId = S;
66  using PrefixId = P;
67 
71 
72  PrefixId prefix_id; // Index in prefix table.
73  StateId fst_id; // Current FST being walked.
74  StateId fst_state; // Current state in FST being walked (not to be
75  // confused with the StateId of the combined FST).
76 };
77 
78 // Equality of replace state tuples.
79 template <class StateId, class PrefixId>
82  return x.prefix_id == y.prefix_id && x.fst_id == y.fst_id &&
83  x.fst_state == y.fst_state;
84 }
85 
86 // Functor returning true for tuples corresponding to states in the root FST.
87 template <class StateId, class PrefixId>
89  public:
91  return tuple.prefix_id == 0;
92  }
93 };
94 
95 // Functor for fingerprinting replace state tuples.
96 template <class StateId, class PrefixId>
98  public:
99  explicit ReplaceFingerprint(const std::vector<uint64> *size_array)
100  : size_array_(size_array) {}
101 
103  return tuple.prefix_id * size_array_->back() +
104  size_array_->at(tuple.fst_id - 1) + tuple.fst_state;
105  }
106 
107  private:
108  const std::vector<uint64> *size_array_;
109 };
110 
111 // Useful when the fst_state uniquely defines the tuple.
112 template <class StateId, class PrefixId>
114  public:
116  return tuple.fst_state;
117  }
118 };
119 
120 // A generic hash function for replace state tuples.
121 template <typename S, typename P>
122 class ReplaceHash {
123  public:
124  size_t operator()(const ReplaceStateTuple<S, P>& t) const {
125  static constexpr size_t prime0 = 7853;
126  static constexpr size_t prime1 = 7867;
127  return t.prefix_id + t.fst_id * prime0 + t.fst_state * prime1;
128  }
129 };
130 
131 // Container for stack prefix.
132 template <class Label, class StateId>
134  public:
135  struct PrefixTuple {
137  : fst_id(fst_id), nextstate(nextstate) {}
138 
139  Label fst_id;
141  };
142 
144 
146  : prefix_(other.prefix_) {}
147 
148  void Push(StateId fst_id, StateId nextstate) {
149  prefix_.push_back(PrefixTuple(fst_id, nextstate));
150  }
151 
152  void Pop() { prefix_.pop_back(); }
153 
154  const PrefixTuple &Top() const { return prefix_[prefix_.size() - 1]; }
155 
156  size_t Depth() const { return prefix_.size(); }
157 
158  public:
159  std::vector<PrefixTuple> prefix_;
160 };
161 
162 // Equality stack prefix classes.
163 template <class Label, class StateId>
166  if (x.prefix_.size() != y.prefix_.size()) return false;
167  for (size_t i = 0; i < x.prefix_.size(); ++i) {
168  if (x.prefix_[i].fst_id != y.prefix_[i].fst_id ||
169  x.prefix_[i].nextstate != y.prefix_[i].nextstate) {
170  return false;
171  }
172  }
173  return true;
174 }
175 
176 // Hash function for stack prefix to prefix id.
177 template <class Label, class StateId>
179  public:
180  size_t operator()(const ReplaceStackPrefix<Label, StateId> &prefix) const {
181  size_t sum = 0;
182  for (const auto &pair : prefix.prefix_) {
183  static constexpr size_t prime = 7863;
184  sum += pair.fst_id + pair.nextstate * prime;
185  }
186  return sum;
187  }
188 };
189 
190 // Replace state tables.
191 
192 // A two-level state table for replace. Warning: calls CountStates to compute
193 // the number of states of each component FST.
194 template <class Arc, class P = ssize_t>
196  public:
197  using Label = typename Arc::Label;
198  using StateId = typename Arc::StateId;
199 
200  using PrefixId = P;
201 
203  using StateTable =
209  using StackPrefixTable =
212 
214  const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_list,
215  Label root)
216  : root_size_(0) {
217  size_array_.push_back(0);
218  for (const auto &fst_pair : fst_list) {
219  if (fst_pair.first == root) {
220  root_size_ = CountStates(*(fst_pair.second));
221  size_array_.push_back(size_array_.back());
222  } else {
223  size_array_.push_back(size_array_.back() +
224  CountStates(*(fst_pair.second)));
225  }
226  }
227  state_table_.reset(
228  new StateTable(new ReplaceRootSelector<StateId, PrefixId>,
229  new ReplaceFstStateFingerprint<StateId, PrefixId>,
230  new ReplaceFingerprint<StateId, PrefixId>(&size_array_),
231  root_size_, root_size_ + size_array_.back()));
232  }
233 
236  : root_size_(table.root_size_),
237  size_array_(table.size_array_),
238  prefix_table_(table.prefix_table_) {
239  state_table_.reset(
240  new StateTable(new ReplaceRootSelector<StateId, PrefixId>,
241  new ReplaceFstStateFingerprint<StateId, PrefixId>,
242  new ReplaceFingerprint<StateId, PrefixId>(&size_array_),
243  root_size_, root_size_ + size_array_.back()));
244  }
245 
246  StateId FindState(const StateTuple &tuple) {
247  return state_table_->FindState(tuple);
248  }
249 
250  const StateTuple &Tuple(StateId id) const { return state_table_->Tuple(id); }
251 
252  PrefixId FindPrefixId(const StackPrefix &prefix) {
253  return prefix_table_.FindId(prefix);
254  }
255 
256  const StackPrefix& GetStackPrefix(PrefixId id) const {
257  return prefix_table_.FindEntry(id);
258  }
259 
260  private:
261  StateId root_size_;
262  std::vector<uint64> size_array_;
263  std::unique_ptr<StateTable> state_table_;
264  StackPrefixTable prefix_table_;
265 };
266 
267 // Default replace state table.
268 template <class Arc, class P /* = size_t */>
270  : public CompactHashStateTable<ReplaceStateTuple<typename Arc::StateId, P>,
271  ReplaceHash<typename Arc::StateId, P>> {
272  public:
273  using Label = typename Arc::Label;
274  using StateId = typename Arc::StateId;
275 
276  using PrefixId = P;
278  using StateTable =
281  using StackPrefixTable =
284 
285  using StateTable::FindState;
286  using StateTable::Tuple;
287 
289  const std::vector<std::pair<Label, const Fst<Arc> *>> &, Label) {}
290 
292  : StateTable(), prefix_table_(table.prefix_table_) {}
293 
294  PrefixId FindPrefixId(const StackPrefix &prefix) {
295  return prefix_table_.FindId(prefix);
296  }
297 
298  const StackPrefix &GetStackPrefix(PrefixId id) const {
299  return prefix_table_.FindEntry(id);
300  }
301 
302  private:
303  StackPrefixTable prefix_table_;
304 };
305 
306 // By default ReplaceFst will copy the input label of the replace arc.
307 // The call_label_type and return_label_type options specify how to manage
308 // the labels of the call arc and the return arc of the replace FST.
309 template <class Arc, class StateTable = DefaultReplaceStateTable<Arc>,
310  class CacheStore = DefaultCacheStore<Arc>>
311 struct ReplaceFstOptions : CacheImplOptions<CacheStore> {
312  using Label = typename Arc::Label;
313 
314  // Index of root rule for expansion.
316  // How to label call arc.
318  // How to label return arc.
320  // Specifies output label to put on call arc; if kNoLabel, use existing label
321  // on call arc. Otherwise, use this field as the output label.
322  Label call_output_label = kNoLabel;
323  // Specifies label to put on return arc.
324  Label return_label = 0;
325  // Take ownership of input FSTs?
326  bool take_ownership = false;
327  // Pointer to optional pre-constructed state table.
328  StateTable *state_table = nullptr;
329 
331  Label root = kNoLabel)
332  : CacheImplOptions<CacheStore>(opts), root(root) {}
333 
334  explicit ReplaceFstOptions(const CacheOptions &opts, Label root = kNoLabel)
335  : CacheImplOptions<CacheStore>(opts), root(root) {}
336 
337  // FIXME(kbg): There are too many constructors here. Come up with a consistent
338  // position for call_output_label (probably the very end) so that it is
339  // possible to express all the remaining constructors with a single
340  // default-argument constructor. Also move clients off of the "backwards
341  // compatibility" constructor, for good.
342 
343  explicit ReplaceFstOptions(Label root) : root(root) {}
344 
345  explicit ReplaceFstOptions(Label root, ReplaceLabelType call_label_type,
346  ReplaceLabelType return_label_type,
347  Label return_label)
348  : root(root),
349  call_label_type(call_label_type),
350  return_label_type(return_label_type),
351  return_label(return_label) {}
352 
353  explicit ReplaceFstOptions(Label root, ReplaceLabelType call_label_type,
354  ReplaceLabelType return_label_type,
355  Label call_output_label, Label return_label)
356  : root(root),
357  call_label_type(call_label_type),
358  return_label_type(return_label_type),
359  call_output_label(call_output_label),
360  return_label(return_label) {}
361 
362  explicit ReplaceFstOptions(const ReplaceUtilOptions &opts)
363  : ReplaceFstOptions(opts.root, opts.call_label_type,
364  opts.return_label_type, opts.return_label) {}
365 
367 
368  // For backwards compatibility.
369  ReplaceFstOptions(int64 root, bool epsilon_replace_arc)
370  : root(root),
371  call_label_type(epsilon_replace_arc ? REPLACE_LABEL_NEITHER
373  call_output_label(epsilon_replace_arc ? 0 : kNoLabel) {}
374 };
375 
376 
377 // Forward declaration.
378 template <class Arc, class StateTable, class CacheStore>
380 
381 template <class Arc>
382 using FstList = std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>;
383 
384 // Returns true if label type on arc results in epsilon input label.
385 inline bool EpsilonOnInput(ReplaceLabelType label_type) {
386  return label_type == REPLACE_LABEL_NEITHER ||
387  label_type == REPLACE_LABEL_OUTPUT;
388 }
389 
390 // Returns true if label type on arc results in epsilon input label.
391 inline bool EpsilonOnOutput(ReplaceLabelType label_type) {
392  return label_type == REPLACE_LABEL_NEITHER ||
393  label_type == REPLACE_LABEL_INPUT;
394 }
395 
396 // Returns true if for either the call or return arc ilabel != olabel.
397 template <class Label>
398 bool ReplaceTransducer(ReplaceLabelType call_label_type,
399  ReplaceLabelType return_label_type,
400  Label call_output_label) {
401  return call_label_type == REPLACE_LABEL_INPUT ||
402  call_label_type == REPLACE_LABEL_OUTPUT ||
403  (call_label_type == REPLACE_LABEL_BOTH &&
404  call_output_label != kNoLabel) ||
405  return_label_type == REPLACE_LABEL_INPUT ||
406  return_label_type == REPLACE_LABEL_OUTPUT;
407 }
408 
409 template <class Arc>
410 uint64 ReplaceFstProperties(typename Arc::Label root_label,
411  const FstList<Arc> &fst_list,
412  ReplaceLabelType call_label_type,
413  ReplaceLabelType return_label_type,
414  typename Arc::Label call_output_label,
415  bool *sorted_and_non_empty) {
416  using Label = typename Arc::Label;
417  std::vector<uint64> inprops;
418  bool all_ilabel_sorted = true;
419  bool all_olabel_sorted = true;
420  bool all_non_empty = true;
421  // All nonterminals are negative?
422  bool all_negative = true;
423  // All nonterminals are positive and form a dense range containing 1?
424  bool dense_range = true;
425  Label root_fst_idx = 0;
426  for (Label i = 0; i < fst_list.size(); ++i) {
427  const auto label = fst_list[i].first;
428  if (label >= 0) all_negative = false;
429  if (label > fst_list.size() || label <= 0) dense_range = false;
430  if (label == root_label) root_fst_idx = i;
431  const auto *fst = fst_list[i].second;
432  if (fst->Start() == kNoStateId) all_non_empty = false;
433  if (!fst->Properties(kILabelSorted, false)) all_ilabel_sorted = false;
434  if (!fst->Properties(kOLabelSorted, false)) all_olabel_sorted = false;
435  inprops.push_back(fst->Properties(kCopyProperties, false));
436  }
437  const auto props = ReplaceProperties(
438  inprops, root_fst_idx, EpsilonOnInput(call_label_type),
439  EpsilonOnInput(return_label_type), EpsilonOnOutput(call_label_type),
440  EpsilonOnOutput(return_label_type),
441  ReplaceTransducer(call_label_type, return_label_type, call_output_label),
442  all_non_empty, all_ilabel_sorted, all_olabel_sorted,
443  all_negative || dense_range);
444  const bool sorted = props & (kILabelSorted | kOLabelSorted);
445  *sorted_and_non_empty = all_non_empty && sorted;
446  return props;
447 }
448 
449 namespace internal {
450 
451 // The replace implementation class supports a dynamic expansion of a recursive
452 // transition network represented as label/FST pairs with dynamic replaceable
453 // arcs.
454 template <class Arc, class StateTable, class CacheStore>
456  : public CacheBaseImpl<typename CacheStore::State, CacheStore> {
457  public:
458  using Label = typename Arc::Label;
459  using StateId = typename Arc::StateId;
460  using Weight = typename Arc::Weight;
461 
462  using State = typename CacheStore::State;
464  using PrefixId = typename StateTable::PrefixId;
467  using NonTerminalHash = std::unordered_map<Label, Label>;
468 
469  using FstImpl<Arc>::SetType;
476 
477  using CacheImpl::PushArc;
478  using CacheImpl::HasArcs;
479  using CacheImpl::HasFinal;
480  using CacheImpl::HasStart;
481  using CacheImpl::SetArcs;
482  using CacheImpl::SetFinal;
483  using CacheImpl::SetStart;
484 
485  friend class ReplaceFstMatcher<Arc, StateTable, CacheStore>;
486 
487  ReplaceFstImpl(const FstList<Arc> &fst_list,
489  : CacheImpl(opts),
490  call_label_type_(opts.call_label_type),
491  return_label_type_(opts.return_label_type),
492  call_output_label_(opts.call_output_label),
493  return_label_(opts.return_label),
494  state_table_(opts.state_table ? opts.state_table
495  : new StateTable(fst_list, opts.root)) {
496  SetType("replace");
497  // If the label is epsilon, then all replace label options are equivalent,
498  // so we set the label types to NEITHER for simplicity.
499  if (call_output_label_ == 0) call_label_type_ = REPLACE_LABEL_NEITHER;
500  if (return_label_ == 0) return_label_type_ = REPLACE_LABEL_NEITHER;
501  if (!fst_list.empty()) {
502  SetInputSymbols(fst_list[0].second->InputSymbols());
503  SetOutputSymbols(fst_list[0].second->OutputSymbols());
504  }
505  fst_array_.push_back(nullptr);
506  for (Label i = 0; i < fst_list.size(); ++i) {
507  const auto label = fst_list[i].first;
508  const auto *fst = fst_list[i].second;
509  nonterminal_hash_[label] = fst_array_.size();
510  nonterminal_set_.insert(label);
511  fst_array_.emplace_back(opts.take_ownership ? fst : fst->Copy());
512  if (i) {
513  if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) {
514  FSTERROR() << "ReplaceFstImpl: Input symbols of FST " << i
515  << " do not match input symbols of base FST (0th FST)";
516  SetProperties(kError, kError);
517  }
518  if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) {
519  FSTERROR() << "ReplaceFstImpl: Output symbols of FST " << i
520  << " do not match output symbols of base FST (0th FST)";
521  SetProperties(kError, kError);
522  }
523  }
524  }
525  const auto nonterminal = nonterminal_hash_[opts.root];
526  if ((nonterminal == 0) && (fst_array_.size() > 1)) {
527  FSTERROR() << "ReplaceFstImpl: No FST corresponding to root label "
528  << opts.root << " in the input tuple vector";
529  SetProperties(kError, kError);
530  }
531  root_ = (nonterminal > 0) ? nonterminal : 1;
532  bool all_non_empty_and_sorted = false;
533  SetProperties(ReplaceFstProperties(opts.root, fst_list, call_label_type_,
534  return_label_type_, call_output_label_,
535  &all_non_empty_and_sorted));
536  // Enables optional caching as long as sorted and all non-empty.
537  always_cache_ = !all_non_empty_and_sorted;
538  VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = "
539  << (always_cache_ ? "true" : "false");
540  }
541 
543  : CacheImpl(impl),
544  call_label_type_(impl.call_label_type_),
545  return_label_type_(impl.return_label_type_),
546  call_output_label_(impl.call_output_label_),
547  return_label_(impl.return_label_),
548  always_cache_(impl.always_cache_),
549  state_table_(new StateTable(*(impl.state_table_))),
550  nonterminal_set_(impl.nonterminal_set_),
551  nonterminal_hash_(impl.nonterminal_hash_),
552  root_(impl.root_) {
553  SetType("replace");
554  SetProperties(impl.Properties(), kCopyProperties);
555  SetInputSymbols(impl.InputSymbols());
556  SetOutputSymbols(impl.OutputSymbols());
557  fst_array_.reserve(impl.fst_array_.size());
558  fst_array_.emplace_back(nullptr);
559  for (Label i = 1; i < impl.fst_array_.size(); ++i) {
560  fst_array_.emplace_back(impl.fst_array_[i]->Copy(true));
561  }
562  }
563 
564  // Computes the dependency graph of the replace class and returns
565  // true if the dependencies are cyclic. Cyclic dependencies will result
566  // in an un-expandable FST.
567  bool CyclicDependencies() const {
568  const ReplaceUtilOptions opts(root_);
569  ReplaceUtil<Arc> replace_util(fst_array_, nonterminal_hash_, opts);
570  return replace_util.CyclicDependencies();
571  }
572 
574  if (!HasStart()) {
575  if (fst_array_.size() == 1) {
576  SetStart(kNoStateId);
577  return kNoStateId;
578  } else {
579  const auto fst_start = fst_array_[root_]->Start();
580  if (fst_start == kNoStateId) return kNoStateId;
581  const auto prefix = GetPrefixId(StackPrefix());
582  const auto start =
583  state_table_->FindState(StateTuple(prefix, root_, fst_start));
584  SetStart(start);
585  return start;
586  }
587  } else {
588  return CacheImpl::Start();
589  }
590  }
591 
593  if (HasFinal(s)) return CacheImpl::Final(s);
594  const auto &tuple = state_table_->Tuple(s);
595  auto weight = Weight::Zero();
596  if (tuple.prefix_id == 0) {
597  const auto fst_state = tuple.fst_state;
598  weight = fst_array_[tuple.fst_id]->Final(fst_state);
599  }
600  if (always_cache_ || HasArcs(s)) SetFinal(s, weight);
601  return weight;
602  }
603 
604  size_t NumArcs(StateId s) {
605  if (HasArcs(s)) {
606  return CacheImpl::NumArcs(s);
607  } else if (always_cache_) { // If always caching, expands and caches state.
608  Expand(s);
609  return CacheImpl::NumArcs(s);
610  } else { // Otherwise computes the number of arcs without expanding.
611  const auto tuple = state_table_->Tuple(s);
612  if (tuple.fst_state == kNoStateId) return 0;
613  auto num_arcs = fst_array_[tuple.fst_id]->NumArcs(tuple.fst_state);
614  if (ComputeFinalArc(tuple, nullptr)) ++num_arcs;
615  return num_arcs;
616  }
617  }
618 
619  // Returns whether a given label is a non-terminal.
620  bool IsNonTerminal(Label label) const {
621  if (label < *nonterminal_set_.begin() ||
622  label > *nonterminal_set_.rbegin()) {
623  return false;
624  } else {
625  return nonterminal_hash_.count(label);
626  }
627  // TODO(allauzen): be smarter and take advantage of all_dense or
628  // all_negative. Also use this in ComputeArc. This would require changes to
629  // Replace so that recursing into an empty FST lead to a non co-accessible
630  // state instead of deleting the arc as done currently. The current use
631  // correct, since labels are sorted if all_non_empty is true.
632  }
633 
635  if (HasArcs(s)) {
636  return CacheImpl::NumInputEpsilons(s);
637  } else if (always_cache_ || !Properties(kILabelSorted)) {
638  // If always caching or if the number of input epsilons is too expensive
639  // to compute without caching (i.e., not ilabel-sorted), then expands and
640  // caches state.
641  Expand(s);
642  return CacheImpl::NumInputEpsilons(s);
643  } else {
644  // Otherwise, computes the number of input epsilons without caching.
645  const auto tuple = state_table_->Tuple(s);
646  if (tuple.fst_state == kNoStateId) return 0;
647  size_t num = 0;
648  if (!EpsilonOnInput(call_label_type_)) {
649  // If EpsilonOnInput(c) is false, all input epsilon arcs
650  // are also input epsilons arcs in the underlying machine.
651  num = fst_array_[tuple.fst_id]->NumInputEpsilons(tuple.fst_state);
652  } else {
653  // Otherwise, one need to consider that all non-terminal arcs
654  // in the underlying machine also become input epsilon arc.
655  ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
656  for (; !aiter.Done() && ((aiter.Value().ilabel == 0) ||
657  IsNonTerminal(aiter.Value().olabel));
658  aiter.Next()) {
659  ++num;
660  }
661  }
662  if (EpsilonOnInput(return_label_type_) &&
663  ComputeFinalArc(tuple, nullptr)) {
664  ++num;
665  }
666  return num;
667  }
668  }
669 
671  if (HasArcs(s)) {
673  } else if (always_cache_ || !Properties(kOLabelSorted)) {
674  // If always caching or if the number of output epsilons is too expensive
675  // to compute without caching (i.e., not olabel-sorted), then expands and
676  // caches state.
677  Expand(s);
679  } else {
680  // Otherwise, computes the number of output epsilons without caching.
681  const auto tuple = state_table_->Tuple(s);
682  if (tuple.fst_state == kNoStateId) return 0;
683  size_t num = 0;
684  if (!EpsilonOnOutput(call_label_type_)) {
685  // If EpsilonOnOutput(c) is false, all output epsilon arcs are also
686  // output epsilons arcs in the underlying machine.
687  num = fst_array_[tuple.fst_id]->NumOutputEpsilons(tuple.fst_state);
688  } else {
689  // Otherwise, one need to consider that all non-terminal arcs in the
690  // underlying machine also become output epsilon arc.
691  ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
692  for (; !aiter.Done() && ((aiter.Value().olabel == 0) ||
693  IsNonTerminal(aiter.Value().olabel));
694  aiter.Next()) {
695  ++num;
696  }
697  }
698  if (EpsilonOnOutput(return_label_type_) &&
699  ComputeFinalArc(tuple, nullptr)) {
700  ++num;
701  }
702  return num;
703  }
704  }
705 
706  uint64 Properties() const override { return Properties(kFstProperties); }
707 
708  // Sets error if found, and returns other FST impl properties.
709  uint64 Properties(uint64 mask) const override {
710  if (mask & kError) {
711  for (Label i = 1; i < fst_array_.size(); ++i) {
712  if (fst_array_[i]->Properties(kError, false)) {
713  SetProperties(kError, kError);
714  }
715  }
716  }
717  return FstImpl<Arc>::Properties(mask);
718  }
719 
720  // Returns the base arc iterator, and if arcs have not been computed yet,
721  // extends and recurses for new arcs.
723  if (!HasArcs(s)) Expand(s);
724  CacheImpl::InitArcIterator(s, data);
725  // TODO(allauzen): Set behaviour of generic iterator.
726  // Warning: ArcIterator<ReplaceFst<A>>::InitCache() relies on current
727  // behaviour.
728  }
729 
730  // Extends current state (walk arcs one level deep).
731  void Expand(StateId s) {
732  const auto tuple = state_table_->Tuple(s);
733  if (tuple.fst_state == kNoStateId) { // Local FST is empty.
734  SetArcs(s);
735  return;
736  }
737  ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
738  Arc arc;
739  // Creates a final arc when needed.
740  if (ComputeFinalArc(tuple, &arc)) PushArc(s, std::move(arc));
741  // Expands all arcs leaving the state.
742  for (; !aiter.Done(); aiter.Next()) {
743  if (ComputeArc(tuple, aiter.Value(), &arc)) PushArc(s, std::move(arc));
744  }
745  SetArcs(s);
746  }
747 
748  void Expand(StateId s, const StateTuple &tuple,
749  const ArcIteratorData<Arc> &data) {
750  if (tuple.fst_state == kNoStateId) { // Local FST is empty.
751  SetArcs(s);
752  return;
753  }
754  ArcIterator<Fst<Arc>> aiter(data);
755  Arc arc;
756  // Creates a final arc when needed.
757  if (ComputeFinalArc(tuple, &arc)) AddArc(s, arc);
758  // Expands all arcs leaving the state.
759  for (; !aiter.Done(); aiter.Next()) {
760  if (ComputeArc(tuple, aiter.Value(), &arc)) AddArc(s, arc);
761  }
762  SetArcs(s);
763  }
764 
765  // If acpp is null, only returns true if a final arcp is required, but does
766  // not actually compute it.
767  bool ComputeFinalArc(const StateTuple &tuple, Arc *arcp,
768  uint8 flags = kArcValueFlags) {
769  const auto fst_state = tuple.fst_state;
770  if (fst_state == kNoStateId) return false;
771  // If state is final, pops the stack.
772  if (fst_array_[tuple.fst_id]->Final(fst_state) != Weight::Zero() &&
773  tuple.prefix_id) {
774  if (arcp) {
775  arcp->ilabel = (EpsilonOnInput(return_label_type_)) ? 0 : return_label_;
776  arcp->olabel =
777  (EpsilonOnOutput(return_label_type_)) ? 0 : return_label_;
778  if (flags & kArcNextStateValue) {
779  const auto &stack = state_table_->GetStackPrefix(tuple.prefix_id);
780  const auto prefix_id = PopPrefix(stack);
781  const auto &top = stack.Top();
782  arcp->nextstate = state_table_->FindState(
783  StateTuple(prefix_id, top.fst_id, top.nextstate));
784  }
785  if (flags & kArcWeightValue) {
786  arcp->weight = fst_array_[tuple.fst_id]->Final(fst_state);
787  }
788  }
789  return true;
790  } else {
791  return false;
792  }
793  }
794 
795  // Computes an arc in the FST corresponding to one in the underlying machine.
796  // Returns false if the underlying arc corresponds to no arc in the resulting
797  // FST.
798  bool ComputeArc(const StateTuple &tuple, const Arc &arc, Arc *arcp,
799  uint8 flags = kArcValueFlags) {
800  if (!EpsilonOnInput(call_label_type_) &&
801  (flags == (flags & (kArcILabelValue | kArcWeightValue)))) {
802  *arcp = arc;
803  return true;
804  }
805  if (arc.olabel == 0 || arc.olabel < *nonterminal_set_.begin() ||
806  arc.olabel > *nonterminal_set_.rbegin()) { // Expands local FST.
807  const auto nextstate =
808  flags & kArcNextStateValue
809  ? state_table_->FindState(
810  StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
811  : kNoStateId;
812  *arcp = Arc(arc.ilabel, arc.olabel, arc.weight, nextstate);
813  } else {
814  // Checks for non-terminal.
815  const auto it = nonterminal_hash_.find(arc.olabel);
816  if (it != nonterminal_hash_.end()) { // Recurses into non-terminal.
817  const auto nonterminal = it->second;
818  const auto nt_prefix =
819  PushPrefix(state_table_->GetStackPrefix(tuple.prefix_id),
820  tuple.fst_id, arc.nextstate);
821  // If the start state is valid, replace; othewise, the arc is implicitly
822  // deleted.
823  const auto nt_start = fst_array_[nonterminal]->Start();
824  if (nt_start != kNoStateId) {
825  const auto nt_nextstate = flags & kArcNextStateValue
826  ? state_table_->FindState(StateTuple(
827  nt_prefix, nonterminal, nt_start))
828  : kNoStateId;
829  const auto ilabel =
830  (EpsilonOnInput(call_label_type_)) ? 0 : arc.ilabel;
831  const auto olabel =
832  (EpsilonOnOutput(call_label_type_))
833  ? 0
834  : ((call_output_label_ == kNoLabel) ? arc.olabel
835  : call_output_label_);
836  *arcp = Arc(ilabel, olabel, arc.weight, nt_nextstate);
837  } else {
838  return false;
839  }
840  } else {
841  const auto nextstate =
842  flags & kArcNextStateValue
843  ? state_table_->FindState(
844  StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
845  : kNoStateId;
846  *arcp = Arc(arc.ilabel, arc.olabel, arc.weight, nextstate);
847  }
848  }
849  return true;
850  }
851 
852  // Returns the arc iterator flags supported by this FST.
854  uint8 flags = kArcValueFlags;
855  if (!always_cache_) flags |= kArcNoCache;
856  return flags;
857  }
858 
859  StateTable *GetStateTable() const { return state_table_.get(); }
860 
861  const Fst<Arc> *GetFst(Label fst_id) const {
862  return fst_array_[fst_id].get();
863  }
864 
865  Label GetFstId(Label nonterminal) const {
866  const auto it = nonterminal_hash_.find(nonterminal);
867  if (it == nonterminal_hash_.end()) {
868  FSTERROR() << "ReplaceFstImpl::GetFstId: Nonterminal not found: "
869  << nonterminal;
870  }
871  return it->second;
872  }
873 
874  // Returns true if label type on call arc results in epsilon input label.
875  bool EpsilonOnCallInput() { return EpsilonOnInput(call_label_type_); }
876 
877  private:
878  // The unique index into stack prefix table.
879  PrefixId GetPrefixId(const StackPrefix &prefix) {
880  return state_table_->FindPrefixId(prefix);
881  }
882 
883  // The prefix ID after a stack pop.
884  PrefixId PopPrefix(StackPrefix prefix) {
885  prefix.Pop();
886  return GetPrefixId(prefix);
887  }
888 
889  // The prefix ID after a stack push.
890  PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
891  prefix.Push(fst_id, nextstate);
892  return GetPrefixId(prefix);
893  }
894 
895  // Runtime options
896  ReplaceLabelType call_label_type_; // How to label call arc.
897  ReplaceLabelType return_label_type_; // How to label return arc.
898  int64 call_output_label_; // Specifies output label to put on call arc
899  int64 return_label_; // Specifies label to put on return arc.
900  bool always_cache_; // Disable optional caching of arc iterator?
901 
902  // State table.
903  std::unique_ptr<StateTable> state_table_;
904 
905  // Replace components.
906  std::set<Label> nonterminal_set_;
907  NonTerminalHash nonterminal_hash_;
908  std::vector<std::unique_ptr<const Fst<Arc>>> fst_array_;
909  Label root_;
910 };
911 
912 } // namespace internal
913 
914 //
915 // ReplaceFst supports dynamic replacement of arcs in one FST with another FST.
916 // This replacement is recursive. ReplaceFst can be used to support a variety of
917 // delayed constructions such as recursive
918 // transition networks, union, or closure. It is constructed with an array of
919 // FST(s). One FST represents the root (or topology) machine. The root FST
920 // refers to other FSTs by recursively replacing arcs labeled as non-terminals
921 // with the matching non-terminal FST. Currently the ReplaceFst uses the output
922 // symbols of the arcs to determine whether the arc is a non-terminal arc or
923 // not. A non-terminal can be any label that is not a non-zero terminal label in
924 // the output alphabet.
925 //
926 // Note that the constructor uses a vector of pairs. These correspond to the
927 // tuple of non-terminal Label and corresponding FST. For example to implement
928 // the closure operation we need 2 FSTs. The first root FST is a single
929 // self-loop arc on the start state.
930 //
931 // The ReplaceFst class supports an optionally caching arc iterator.
932 //
933 // The ReplaceFst needs to be built such that it is known to be ilabel- or
934 // olabel-sorted (see usage below).
935 //
936 // Observe that Matcher<Fst<A>> will use the optionally caching arc iterator
937 // when available (the FST is ilabel-sorted and matching on the input, or the
938 // FST is olabel-sorted and matching on the output). In order to obtain the
939 // most efficient behaviour, it is recommended to set call_label_type to
940 // REPLACE_LABEL_INPUT or REPLACE_LABEL_BOTH and return_label_type to
941 // REPLACE_LABEL_OUTPUT or REPLACE_LABEL_NEITHER (this means that the call arc
942 // does not have epsilon on the input side and the return arc has epsilon on the
943 // input side) and to match on the input side.
944 //
945 // This class attaches interface to implementation and handles reference
946 // counting, delegating most methods to ImplToFst.
// Lazily expanded, recursive replacement FST (see the comment above). A is the
// arc type, T the replace state table and CacheStore the cache store; the
// commented-out defaults are presumably supplied by the forward declaration
// in fst-decl.h.
947 template <class A, class T /* = DefaultReplaceStateTable<A> */,
948  class CacheStore /* = DefaultCacheStore<A> */>
949 class ReplaceFst
950  : public ImplToFst<internal::ReplaceFstImpl<A, T, CacheStore>> {
951  public:
952  using Arc = A;
953  using Label = typename Arc::Label;
954  using StateId = typename Arc::StateId;
955  using Weight = typename Arc::Weight;
956 
957  using StateTable = T;
958  using Store = CacheStore;
959  using State = typename CacheStore::State;
962 
964 
// The specialized iterators and the matcher below call GetImpl() /
// GetMutableImpl() on this FST directly.
965  friend class ArcIterator<ReplaceFst<Arc, StateTable, CacheStore>>;
966  friend class StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>;
967  friend class ReplaceFstMatcher<Arc, StateTable, CacheStore>;
968 
// Builds a ReplaceFst from (non-terminal label, FST) pairs, using
// ReplaceFstOptions constructed from the root label `root` alone.
969  ReplaceFst(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_array,
970  Label root)
971  : ImplToFst<Impl>(std::make_shared<Impl>(
972  fst_array, ReplaceFstOptions<Arc, StateTable, CacheStore>(root))) {}
973 
// Builds a ReplaceFst from (non-terminal label, FST) pairs with the explicit
// options `opts`.
974  ReplaceFst(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_array,
976  : ImplToFst<Impl>(std::make_shared<Impl>(fst_array, opts)) {}
977 
978  // See Fst<>::Copy() for doc.
979  ReplaceFst(const ReplaceFst &fst, bool safe = false)
980  : ImplToFst<Impl>(fst, safe) {}
981 
982  // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc.
983  ReplaceFst *Copy(bool safe = false) const override {
984  return new ReplaceFst(*this, safe);
985  }
986 
// Defined out of line further below.
987  inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
988 
// Delegates to the mutable implementation even from this const method
// (presumably so that state s can be expanded into the cache on demand).
989  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
990  GetMutableImpl()->InitArcIterator(s, data);
991  }
992 
// Returns a ReplaceFstMatcher only when the implementation's arc iterators
// support kArcNoCache and the FST is already known to be sorted on the matched
// side. Properties() is called with test == false (presumably: only known
// property bits are consulted, nothing is computed). Otherwise logs at
// VLOG(2) and returns nullptr.
993  MatcherBase<Arc> *InitMatcher(MatchType match_type) const override {
994  if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
995  ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) ||
996  (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) {
998  (this, match_type);
999  } else {
1000  VLOG(2) << "Not using replace matcher";
1001  return nullptr;
1002  }
1003  }
1004 
// Delegates to the implementation's CyclicDependencies() (presumably: whether
// the non-terminal dependency graph contains a cycle).
1005  bool CyclicDependencies() const { return GetImpl()->CyclicDependencies(); }
1006 
// Returns the replace state table held by the implementation.
1007  const StateTable &GetStateTable() const {
1008  return *GetImpl()->GetStateTable();
1009  }
1010 
// Returns the component FST registered for `nonterminal`; the label is first
// mapped to an FST index with GetFstId(). NOTE(review): the returned pointer
// is dereferenced unchecked; confirm callers only pass registered labels.
1011  const Fst<Arc> &GetFst(Label nonterminal) const {
1012  return *GetImpl()->GetFst(GetImpl()->GetFstId(nonterminal));
1013  }
1014 
1015  private:
1018 
1019  ReplaceFst &operator=(const ReplaceFst &) = delete;
1020 };
1021 
1022 // StateIterator specialization for ReplaceFst.
// Delegates to CacheStateIterator, passing both the FST and its mutable
// implementation (presumably so that states can be expanded while iterating).
1023 template <class Arc, class StateTable, class CacheStore>
1024 class StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>
1025  : public CacheStateIterator<ReplaceFst<Arc, StateTable, CacheStore>> {
1026  public:
1028  : CacheStateIterator<ReplaceFst<Arc, StateTable, CacheStore>>(
1029  fst, fst.GetMutableImpl()) {}
1030 };
1031 
1032 // Specialization for ReplaceFst, implementing optional caching. It is used
1033 // as follows:
1034 //
1035 // ReplaceFst<A> replace;
1036 // ArcIterator<ReplaceFst<A>> aiter(replace, s);
1037 // // Note: ArcIterator< Fst<A>> is always a caching arc iterator.
1038 // aiter.SetFlags(kArcNoCache, kArcNoCache);
1039 // // Uses the arc iterator, no arc will be cached, no state will be expanded.
1040 // // Arc flags can be used to decide which components of the arc need to be
1041 // // computed.
1042 // aiter.SetFlags(kArcILabelValue, kArcValueFlags);
1043 // // Wants the ilabel for this arc.
1044 // aiter.Value(); // Does not compute the destination state.
1045 // aiter.Next();
1046 // aiter.SetFlags(kArcNextStateValue, kArcNextStateValue);
1047 // // Wants the ilabel and next state for this arc.
1048 // aiter.Value(); // Does compute the destination state and inserts it
1049 // // in the replace state table.
1050 // // No additional arcs have been cached at this point.
// Arc iterator specialization for ReplaceFst with optional caching (usage is
// described in the comment above).
1051 template <class Arc, class StateTable, class CacheStore>
1052 class ArcIterator<ReplaceFst<Arc, StateTable, CacheStore>> {
1053  public:
1054  using StateId = typename Arc::StateId;
1055 
1056  using StateTuple = typename StateTable::StateTuple;
1057 
// Constructor over state `s` of `fst`. If `s` ends up cached, iterates over the
// cached arcs with every arc field valid. Otherwise it prepares non-caching
// iteration over the component machine's arcs: a precomputed final (exit)
// arc, if any, occupies position 0, and the decision to cache is deferred
// until SetFlags() or Value() is called.
1059  : fst_(fst),
1060  s_(s),
1061  pos_(0),
1062  offset_(0),
1063  flags_(kArcValueFlags),
1064  arcs_(nullptr),
1065  data_flags_(0),
1066  final_flags_(0) {
1067  cache_data_.ref_count = nullptr;
1068  local_data_.ref_count = nullptr;
1069  // If FST does not support optional caching, forces caching.
1070  if (!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
1071  !(fst_.GetImpl()->HasArcs(s_))) {
1072  fst_.GetMutableImpl()->Expand(s_);
1073  }
1074  // If state is already cached, use cached arcs array.
1075  if (fst_.GetImpl()->HasArcs(s_)) {
// Explicitly qualified call to the cache base class's InitArcIterator(),
// which exposes the arcs already stored in the cache.
1076  (fst_.GetImpl())
1077  ->internal::template CacheBaseImpl<
1078  typename CacheStore::State,
1079  CacheStore>::InitArcIterator(s_, &cache_data_);
1080  num_arcs_ = cache_data_.narcs;
1081  arcs_ = cache_data_.arcs; // arcs_ is a pointer to the cached arcs.
1082  data_flags_ = kArcValueFlags; // All the arc member values are valid.
1083  } else { // Otherwise delay decision until SetFlags() or Value() is called.
// Finds which component FST and component state s_ corresponds to.
1084  tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(s_);
1085  if (tuple_.fst_state == kNoStateId) {
// No component state: the iterator is empty.
1086  num_arcs_ = 0;
1087  } else {
1088  // The decision to cache or not to cache has been deferred until Value()
1089  // or
1090  // SetFlags() is called. However, the arc iterator is set up now to be
1091  // ready for non-caching in order to keep the Value() method simple and
1092  // efficient.
1093  const auto *rfst = fst_.GetImpl()->GetFst(tuple_.fst_id);
1094  rfst->InitArcIterator(tuple_.fst_state, &local_data_);
1095  // arcs_ is a pointer to the arcs in the underlying machine.
1096  arcs_ = local_data_.arcs;
1097  // Computes the final arc (but not its destination state) if a final arc
1098  // is required.
1099  bool has_final_arc = fst_.GetMutableImpl()->ComputeFinalArc(
1100  tuple_, &final_arc_, kArcValueFlags & ~kArcNextStateValue);
1101  // Sets the arc value flags that hold for final_arc_.
1102  final_flags_ = kArcValueFlags & ~kArcNextStateValue;
1103  // Computes the number of arcs.
1104  num_arcs_ = local_data_.narcs;
1105  if (has_final_arc) ++num_arcs_;
1106  // Sets the offset between the underlying arc positions and the
1107  // positions
1108  // in the arc iterator (1 when a final arc exists, else 0).
1109  offset_ = num_arcs_ - local_data_.narcs;
1110  // Defers the decision to cache or not until Value() or SetFlags() is
1111  // called.
1112  data_flags_ = 0;
1113  }
1114  }
1115  }
1116 
// Destructor body: decrements the reference counters recorded in the arc
// iterator data, when set (presumably releasing references taken by the
// InitArcIterator() calls above).
1118  if (cache_data_.ref_count) --(*cache_data_.ref_count);
1119  if (local_data_.ref_count) --(*local_data_.ref_count);
1120  }
1121 
1122  void ExpandAndCache() const {
1123  // TODO(allauzen): revisit this.
1124  // fst_.GetImpl()->Expand(s_, tuple_, local_data_);
1125  // (fst_.GetImpl())->CacheImpl<A>*>::InitArcIterator(s_,
1126  // &cache_data_);
1127  //
1128  fst_.InitArcIterator(s_, &cache_data_); // Expand and cache state.
1129  arcs_ = cache_data_.arcs; // arcs_ is a pointer to the cached arcs.
1130  data_flags_ = kArcValueFlags; // All the arc member values are valid.
1131  offset_ = 0; // No offset.
1132  }
1133 
1134  void Init() {
1135  if (flags_ & kArcNoCache) { // If caching is disabled
1136  // arcs_ is a pointer to the arcs in the underlying machine.
1137  arcs_ = local_data_.arcs;
1138  // Sets the arcs value flags that hold for arcs_.
1139  data_flags_ = kArcWeightValue;
1140  if (!fst_.GetMutableImpl()->EpsilonOnCallInput()) {
1141  data_flags_ |= kArcILabelValue;
1142  }
1143  // Sets the offset between the underlying arc positions and the positions
1144  // in the arc iterator.
1145  offset_ = num_arcs_ - local_data_.narcs;
1146  } else {
1147  ExpandAndCache();
1148  }
1149  }
1150 
// True once every position, including the final (exit) arc if any, has been
// visited.
1151  bool Done() const { return pos_ >= num_arcs_; }
1152 
// Returns the arc at the current position. Positions at or after offset_ read
// arcs_ (cached arcs or component-machine arcs). A position before offset_
// denotes the precomputed final (exit) arc. When the flags valid for the
// stored data do not cover the requested value flags (flags_ &
// kArcValueFlags), only the requested fields are recomputed into the mutable
// scratch members arc_ or final_arc_.
1153  const Arc &Value() const {
1154  // data_flags_ == 0: the caching decision is still open (deferred by the
// constructor or flushed by SetFlags()), so expand and cache the state now.
1155  if (!data_flags_) {
1156  // TODO(allauzen): Revisit this.
1157  if (flags_ & kArcNoCache) {
1158  // Should never happen.
1159  FSTERROR() << "ReplaceFst: Inconsistent arc iterator flags";
1160  }
1161  ExpandAndCache();
1162  }
1163  if (pos_ - offset_ >= 0) { // The requested arc is not the final arc.
1164  const auto &arc = arcs_[pos_ - offset_];
1165  if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) {
1166  // If the value flags match the required value flags then returns the
1167  // arc.
1168  return arc;
1169  } else {
1170  // Otherwise, compute the corresponding arc on-the-fly.
1171  fst_.GetMutableImpl()->ComputeArc(tuple_, arc, &arc_,
1172  flags_ & kArcValueFlags);
1173  return arc_;
1174  }
1175  } else { // The requested arc is the final arc.
1176  if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) {
1177  // If the arc value flags that hold for the final arc do not match the
1178  // requested value flags, then
1179  // final_arc_ needs to be updated.
1180  fst_.GetMutableImpl()->ComputeFinalArc(tuple_, &final_arc_,
1181  flags_ & kArcValueFlags);
1182  final_flags_ = flags_ & kArcValueFlags;
1183  }
1184  return final_arc_;
1185  }
1186  }
1187 
1188  void Next() { ++pos_; }
1189 
1190  size_t Position() const { return pos_; }
1191 
1192  void Reset() { pos_ = 0; }
1193 
1194  void Seek(size_t pos) { pos_ = pos; }
1195 
1196  uint8 Flags() const { return flags_; }
1197 
// Updates the behavioral flags. The bits selected by `mask` are cleared, then
// `flags` is merged in, restricted to what the FST's ArcIteratorFlags()
// supports. NOTE(review): `flags` is not ANDed with `mask` before merging, so
// supported bits of `flags` outside `mask` are also turned on; confirm that
// callers always pass flags that are a subset of mask.
1198  void SetFlags(uint8 flags, uint8 mask) {
1199  // Updates the flags taking into account what flags are supported
1200  // by the FST.
1201  flags_ &= ~mask;
1202  flags_ |= (flags & fst_.GetImpl()->ArcIteratorFlags());
1203  // If non-caching is not requested (and caching has not already been
1204  // performed), then flush data_flags_ to request caching during the next
1205  // call to Value().
1206  if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) {
1207  if (!fst_.GetImpl()->HasArcs(s_)) data_flags_ = 0;
1208  }
1209  // If data_flags_ has been flushed but non-caching is requested before
1210  // calling Value(), then set up the iterator for non-caching.
// This tests the raw `flags` argument. If the FST's ArcIteratorFlags() lack
// kArcNoCache, flags_ never gains that bit, so Init() takes its
// ExpandAndCache() branch.
1211  if ((flags & kArcNoCache) && (!data_flags_)) Init();
1212  }
1213 
1214  private:
1215  const ReplaceFst<Arc, StateTable, CacheStore> &fst_; // Reference to the FST.
1216  StateId s_; // State in the FST.
// Assigned only in the constructor's non-cached path; read when arcs or the
// final arc must be recomputed.
1217  mutable StateTuple tuple_; // Tuple corresponding to s_.
1218 
1219  ssize_t pos_; // Current position.
1220  mutable ssize_t offset_; // Offset between position in iterator and in arcs_.
1221  ssize_t num_arcs_; // Number of arcs at s_.
1222  uint8 flags_; // Behavioral flags for the arc iterator.
1223  mutable Arc arc_; // Memory to temporarily store computed arcs.
1224 
1225  mutable ArcIteratorData<Arc> cache_data_; // Arc iterator data in cache.
1226  mutable ArcIteratorData<Arc> local_data_; // Arc iterator data in local FST.
1227 
1228  mutable const Arc *arcs_; // Array of arcs.
1229  mutable uint8 data_flags_; // Arc value flags valid for data in arcs_.
1230  mutable Arc final_arc_; // Final arc (when required).
1231  mutable uint8 final_flags_; // Arc value flags valid for final_arc_.
1232 
1233  ArcIterator(const ArcIterator &) = delete;
1234  ArcIterator &operator=(const ArcIterator &) = delete;
1235 };
1236 
// Matcher for ReplaceFst that works directly on the component FSTs. It uses a
// multi-epsilon LocalMatcher per component FST, an implicit epsilon self-loop,
// and the final (exit) arc of the current component state.
1237 template <class Arc, class StateTable, class CacheStore>
1238 class ReplaceFstMatcher : public MatcherBase<Arc> {
1239  public:
1240  using Label = typename Arc::Label;
1241  using StateId = typename Arc::StateId;
1242  using Weight = typename Arc::Weight;
1243 
1246 
1247  using StateTuple = typename StateTable::StateTuple;
1248 
// In every constructor loop_ is the implicit epsilon self-loop with weight
// One(): labels (kNoLabel, 0) for input matching, swapped for output matching.
// Its nextstate is filled in by SetState(). NOTE(review): current_matcher_,
// label_ and done_ are not initialized here. current_matcher_ is only set by
// SetState(), so calling Find()/Done() before SetState() would dereference an
// uninitialized pointer.
1249  // This makes a copy of the FST.
1251  MatchType match_type)
1252  : owned_fst_(fst.Copy()),
1253  fst_(*owned_fst_),
1254  impl_(fst_.GetMutableImpl()),
1255  s_(fst::kNoStateId),
1256  match_type_(match_type),
1257  current_loop_(false),
1258  final_arc_(false),
1259  loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
1260  if (match_type_ == fst::MATCH_OUTPUT) {
1261  std::swap(loop_.ilabel, loop_.olabel);
1262  }
1263  InitMatchers();
1264  }
1265 
1266  // This doesn't copy the FST; the caller keeps it alive.
1268  MatchType match_type)
1269  : fst_(*fst),
1270  impl_(fst_.GetMutableImpl()),
1271  s_(fst::kNoStateId),
1272  match_type_(match_type),
1273  current_loop_(false),
1274  final_arc_(false),
1275  loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
1276  if (match_type_ == fst::MATCH_OUTPUT) {
1277  std::swap(loop_.ilabel, loop_.olabel);
1278  }
1279  InitMatchers();
1280  }
1281 
1282  // This makes a copy of the FST (thread-safe copy when `safe` is true, per
1283  // the Fst<>::Copy() convention).
1283  ReplaceFstMatcher(const ReplaceFstMatcher &matcher, bool safe = false)
1284  : owned_fst_(matcher.fst_.Copy(safe)),
1285  fst_(*owned_fst_),
1286  impl_(fst_.GetMutableImpl()),
1287  s_(fst::kNoStateId),
1288  match_type_(matcher.match_type_),
1289  current_loop_(false),
1290  final_arc_(false),
1291  loop_(fst::kNoLabel, 0, Weight::One(), fst::kNoStateId) {
1292  if (match_type_ == fst::MATCH_OUTPUT) {
1293  std::swap(loop_.ilabel, loop_.olabel);
1294  }
1295  InitMatchers();
1296  }
1297 
1298  // Creates a local matcher for each component FST in the RTN. LocalMatcher is
1299  // a multi-epsilon wrapper matcher. MultiEpsilonMatcher is used to match each
1300  // non-terminal arc, since these non-terminal
1301  // turn into epsilons on recursion.
1302  void InitMatchers() {
1303  const auto &fst_array = impl_->fst_array_;
1304  matcher_.resize(fst_array.size());
1305  for (Label i = 0; i < fst_array.size(); ++i) {
1306  if (fst_array[i]) {
1307  matcher_[i].reset(
1308  new LocalMatcher(*fst_array[i], match_type_, kMultiEpsList));
1309  auto it = impl_->nonterminal_set_.begin();
1310  for (; it != impl_->nonterminal_set_.end(); ++it) {
1311  matcher_[i]->AddMultiEpsLabel(*it);
1312  }
1313  }
1314  }
1315  }
1316 
1317  ReplaceFstMatcher *Copy(bool safe = false) const override {
1318  return new ReplaceFstMatcher(*this, safe);
1319  }
1320 
1321  MatchType Type(bool test) const override {
1322  if (match_type_ == MATCH_NONE) return match_type_;
1323  const auto true_prop =
1324  match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted;
1325  const auto false_prop =
1326  match_type_ == MATCH_INPUT ? kNotILabelSorted : kNotOLabelSorted;
1327  const auto props = fst_.Properties(true_prop | false_prop, test);
1328  if (props & true_prop) {
1329  return match_type_;
1330  } else if (props & false_prop) {
1331  return MATCH_NONE;
1332  } else {
1333  return MATCH_UNKNOWN;
1334  }
1335  }
1336 
1337  const Fst<Arc> &GetFst() const override { return fst_; }
1338 
1339  uint64 Properties(uint64 props) const override { return props; }
1340 
1341  // Sets the state from which our matching happens.
// Re-setting the same state is a no-op. Otherwise looks up the state's tuple
// and points current_matcher_ at the component FST's matcher, positioned on
// the component state.
1342  void SetState(StateId s) final {
1343  if (s_ == s) return;
1344  s_ = s;
1345  tuple_ = impl_->GetStateTable()->Tuple(s_);
1346  if (tuple_.fst_state == kNoStateId) {
// NOTE(review): no method of this class reads done_ (Done() does not consult
// it), and current_matcher_ keeps pointing at the previous state's matcher
// here. Confirm whether such tuples can ever reach a matcher.
1347  done_ = true;
1348  return;
1349  }
1350  // Gets current matcher, used for label and multi-epsilon matching.
1351  current_matcher_ = matcher_[tuple_.fst_id].get();
1352  current_matcher_->SetState(tuple_.fst_state);
// The implicit epsilon loop returns to this state.
1353  loop_.nextstate = s_;
// Discards a pending final-arc match left over from the previous state.
1354  final_arc_ = false;
1355  }
1356 
1357  // Searches for label from previous set state. If label == 0, first
1358  // hallucinate an epsilon loop; otherwise use the underlying matcher to
1359  // search for the label or epsilons. Note since the ReplaceFst recursion
1360  // on non-terminal arcs causes epsilon transitions to be created we use
1361  // MultiEpsilonMatcher to search for possible matches of non-terminals. If the
1362  // component FST
1363  // reaches a final state we also need to add the exiting final arc.
1364  bool Find(Label label) final {
1365  bool found = false;
1366  label_ = label;
1367  if (label_ == 0 || label_ == kNoLabel) {
1368  // Computes loop directly, avoiding Replace::ComputeArc.
1369  if (label_ == 0) {
1370  current_loop_ = true;
1371  found = true;
1372  }
1373  // Searches for matching multi-epsilons.
1374  final_arc_ = impl_->ComputeFinalArc(tuple_, nullptr);
1375  found = current_matcher_->Find(kNoLabel) || final_arc_ || found;
1376  } else {
1377  // Searches on a sub machine directly using sub machine matcher.
1378  found = current_matcher_->Find(label_);
1379  }
1380  return found;
1381  }
1382 
1383  bool Done() const final {
1384  return !current_loop_ && !final_arc_ && current_matcher_->Done();
1385  }
1386 
1387  const Arc &Value() const final {
1388  if (current_loop_) return loop_;
1389  if (final_arc_) {
1390  impl_->ComputeFinalArc(tuple_, &arc_);
1391  return arc_;
1392  }
1393  const auto &component_arc = current_matcher_->Value();
1394  impl_->ComputeArc(tuple_, component_arc, &arc_);
1395  return arc_;
1396  }
1397 
1398  void Next() final {
1399  if (current_loop_) {
1400  current_loop_ = false;
1401  return;
1402  }
1403  if (final_arc_) {
1404  final_arc_ = false;
1405  return;
1406  }
1407  current_matcher_->Next();
1408  }
1409 
1410  ssize_t Priority(StateId s) final { return fst_.NumArcs(s); }
1411 
1412  private:
// Set only by the copying constructors; null when the matcher borrows the FST.
1413  std::unique_ptr<const ReplaceFst<Arc, StateTable, CacheStore>> owned_fst_;
// Matcher of the current state's component FST (set in SetState();
// NOTE(review): not initialized by the constructors).
1416  LocalMatcher *current_matcher_;
// One matcher per component FST index; null where the component slot is null.
1417  std::vector<std::unique_ptr<LocalMatcher>> matcher_;
1418  StateId s_; // Current state.
1419  Label label_; // Current label.
1420  MatchType match_type_; // Supplied by caller.
// NOTE(review): written by SetState() but never read in this class.
1421  mutable bool done_;
1422  mutable bool current_loop_; // Current arc is the implicit loop.
1423  mutable bool final_arc_; // Current arc for exiting recursion.
1424  mutable StateTuple tuple_; // Tuple corresponding to s_.
// Scratch storage for arcs returned by Value().
1425  mutable Arc arc_;
// Implicit epsilon self-loop arc returned when current_loop_ is set.
1426  Arc loop_;
1427 
1428  ReplaceFstMatcher &operator=(const ReplaceFstMatcher &) = delete;
1429 };
1430 
// Out-of-line definition of ReplaceFst::InitStateIterator(), which is declared
// inline in the class. Fills data->base (NOTE(review): presumably with a new
// StateIterator specialization over *this; confirm).
1431 template <class Arc, class StateTable, class CacheStore>
1433  StateIteratorData<Arc> *data) const {
1434  data->base =
1436 }
1437 
1439 
1440 // Recursively replaces arcs in the root FSTs with other FSTs.
1441 // This version writes the result of replacement to an output MutableFst.
1442 //
1443 // Replace supports replacement of arcs in one FST with another FST. This
1444 // replacement is recursive. Replace takes an array of FST(s). One FST
1445 // represents the root (or topology) machine. The root FST refers to other FSTs
1446 // by recursively replacing arcs labeled as non-terminals with the matching
1447 // non-terminal FST. Currently Replace uses the output symbols of the arcs to
1448 // determine whether the arc is a non-terminal arc or not. A non-terminal can be
1449 // any label that is not a non-zero terminal label in the output alphabet.
1450 //
1451 // Note that the input argument is a vector of pairs. These correspond to the tuple
1452 // of non-terminal Label and corresponding FST.
// Materializes the replacement into *ofst. Sets opts.gc = true and
// opts.gc_limit = 0 on the options, so that only the last state stays cached
// (for the fastest copy). Then assigns a ReplaceFst built from ifst_array and
// those options to *ofst.
1453 template <class Arc>
1454 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1455  &ifst_array,
1456  MutableFst<Arc> *ofst,
1458  opts.gc = true;
1459  opts.gc_limit = 0; // Caches only the last state for fastest copy.
1460  *ofst = ReplaceFst<Arc>(ifst_array, opts);
1461 }
1462 
1463 template <class Arc>
1464 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1465  &ifst_array,
1466  MutableFst<Arc> *ofst, const ReplaceUtilOptions &opts) {
1467  Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(opts));
1468 }
1469 
1470 // For backwards compatibility.
1471 template <class Arc>
1472 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1473  &ifst_array,
1474  MutableFst<Arc> *ofst, typename Arc::Label root,
1475  bool epsilon_on_replace) {
1476  Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root, epsilon_on_replace));
1477 }
1478 
1479 template <class Arc>
1480 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1481  &ifst_array,
1482  MutableFst<Arc> *ofst, typename Arc::Label root) {
1483  Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root));
1484 }
1485 
1486 } // namespace fst
1487 
1488 #endif // FST_REPLACE_H_
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:100
ReplaceFstOptions(const CacheOptions &opts, Label root=kNoLabel)
Definition: replace.h:334
typename CacheStore::State State
Definition: replace.h:462
void SetState(StateId s) final
Definition: replace.h:1342
ArcIterator(const ReplaceFst< Arc, StateTable, CacheStore > &fst, StateId s)
Definition: replace.h:1058
const StateTuple & Tuple(StateId id) const
Definition: replace.h:250
constexpr int kNoLabel
Definition: fst.h:179
std::vector< PrefixTuple > prefix_
Definition: replace.h:159
MatcherBase< Arc > * InitMatcher(MatchType match_type) const override
Definition: replace.h:993
void Expand(StateId s, const StateTuple &tuple, const ArcIteratorData< Arc > &data)
Definition: replace.h:748
const StateTable & GetStateTable() const
Definition: replace.h:1007
const StackPrefix & GetStackPrefix(PrefixId id) const
Definition: replace.h:298
constexpr uint64 kNotOLabelSorted
Definition: properties.h:84
uint64_t uint64
Definition: types.h:32
void Next() final
Definition: replace.h:1398
const Fst< Arc > * GetFst(Label fst_id) const
Definition: replace.h:861
ReplaceStateTuple(PrefixId prefix_id=-1, StateId fst_id=kNoStateId, StateId fst_state=kNoStateId)
Definition: replace.h:68
typename StateTable::PrefixId PrefixId
Definition: replace.h:464
PrefixId FindPrefixId(const StackPrefix &prefix)
Definition: replace.h:294
ReplaceFstOptions(int64 root, bool epsilon_replace_arc)
Definition: replace.h:369
ReplaceFstOptions(const ReplaceUtilOptions &opts)
Definition: replace.h:362
PrefixId prefix_id
Definition: replace.h:72
size_t NumOutputEpsilons(StateId s)
Definition: replace.h:670
bool EpsilonOnOutput(ReplaceLabelType label_type)
Definition: replace.h:391
ReplaceFstMatcher(const ReplaceFst< Arc, StateTable, CacheStore > *fst, MatchType match_type)
Definition: replace.h:1267
typename CacheStore::State State
Definition: replace.h:959
ReplaceLabelType
Definition: replace-util.h:29
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:83
const SymbolTable * OutputSymbols() const
Definition: fst.h:695
MatchType
Definition: fst.h:171
uint64 ReplaceProperties(const std::vector< uint64 > &inprops, size_t root, bool epsilon_on_call, bool epsilon_on_return, bool out_epsilon_on_call, bool out_epsilon_on_return, bool replace_transducer, bool no_empty_fst, bool all_ilabel_sorted, bool all_olabel_sorted, bool all_negative_or_dense)
Definition: properties.cc:236
constexpr uint64 kILabelSorted
Definition: properties.h:77
ReplaceFst(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_array, const ReplaceFstOptions< Arc, StateTable, CacheStore > &opts)
Definition: replace.h:974
uint64 Properties(uint64 mask) const override
Definition: replace.h:709
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
Definition: replace.h:989
bool ReplaceTransducer(ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, Label call_output_label)
Definition: replace.h:398
void Push(StateId fst_id, StateId nextstate)
Definition: replace.h:148
VectorHashReplaceStateTable(const VectorHashReplaceStateTable< Arc, PrefixId > &table)
Definition: replace.h:234
SetType
Definition: set-weight.h:38
MatchType Type(bool test) const override
Definition: replace.h:1321
typename Arc::Weight Weight
Definition: replace.h:1242
typename ReplaceFst< Arc, StateTable, CacheStore >::Arc Arc
Definition: cache.h:1151
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:89
constexpr uint64 kNotILabelSorted
Definition: properties.h:79
bool EpsilonOnInput(ReplaceLabelType label_type)
Definition: replace.h:385
void Replace(const std::vector< std::pair< typename Arc::Label, const Fst< Arc > * >> &ifst_array, MutableFst< Arc > *ofst, std::vector< std::pair< typename Arc::Label, typename Arc::Label >> *parens, const PdtReplaceOptions< Arc > &opts)
Definition: replace.h:786
bool CyclicDependencies() const
Definition: replace.h:1005
const uint32 kMultiEpsList
Definition: matcher.h:1226
VectorHashReplaceStateTable(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_list, Label root)
Definition: replace.h:213
ReplaceFstOptions(const CacheImplOptions< CacheStore > &opts, Label root=kNoLabel)
Definition: replace.h:330
bool Find(Label label) final
Definition: replace.h:1364
const PrefixTuple & Top() const
Definition: replace.h:154
const Fst< Arc > & GetFst(Label nonterminal) const
Definition: replace.h:1011
constexpr uint64 kFstProperties
Definition: properties.h:309
constexpr uint64 kCopyProperties
Definition: properties.h:146
constexpr int kNoStateId
Definition: fst.h:180
uint64 operator()(const ReplaceStateTuple< StateId, PrefixId > &tuple) const
Definition: replace.h:115
const Arc & Value() const
Definition: fst.h:512
DefaultReplaceStateTable(const DefaultReplaceStateTable< Arc, PrefixId > &table)
Definition: replace.h:291
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
Definition: replace.h:722
std::vector< std::pair< typename Arc::Label, const Fst< Arc > * >> FstList
Definition: replace.h:382
int64_t int64
Definition: types.h:27
#define FSTERROR()
Definition: util.h:36
ReplaceFstImpl(const ReplaceFstImpl &impl)
Definition: replace.h:542
size_t operator()(const ReplaceStateTuple< S, P > &t) const
Definition: replace.h:124
StateIteratorBase< Arc > * base
Definition: fst.h:359
typename Arc::Weight Weight
Definition: replace.h:460
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:94
void InitStateIterator(StateIteratorData< Arc > *data) const override
Definition: replace.h:1432
typename Arc::StateId StateId
Definition: replace.h:198
bool operator==(const PdtStateTuple< S, K > &x, const PdtStateTuple< S, K > &y)
Definition: pdt.h:133
uint8_t uint8
Definition: types.h:29
const Fst< Arc > & GetFst() const override
Definition: replace.h:1337
bool ComputeArc(const StateTuple &tuple, const Arc &arc, Arc *arcp, uint8 flags=kArcValueFlags)
Definition: replace.h:798
Weight Final(StateId s)
Definition: replace.h:592
constexpr uint64 kOLabelSorted
Definition: properties.h:82
#define VLOG(level)
Definition: log.h:47
size_t NumArcs(StateId s)
Definition: replace.h:604
typename Arc::Label Label
Definition: replace.h:197
bool Done() const
Definition: fst.h:508
bool ComputeFinalArc(const StateTuple &tuple, Arc *arcp, uint8 flags=kArcValueFlags)
Definition: replace.h:767
ReplaceFst * Copy(bool safe=false) const override
Definition: replace.h:983
uint64 Properties(uint64 props) const override
Definition: replace.h:1339
PrefixId FindPrefixId(const StackPrefix &prefix)
Definition: replace.h:252
const StackPrefix & GetStackPrefix(PrefixId id) const
Definition: replace.h:256
ReplaceFst(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_array, Label root)
Definition: replace.h:969
ReplaceFstOptions(Label root, ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, Label return_label)
Definition: replace.h:345
uint64 operator()(const ReplaceStateTuple< StateId, PrefixId > &tuple) const
Definition: replace.h:102
typename Arc::Label Label
Definition: replace.h:312
ReplaceFstMatcher * Copy(bool safe=false) const override
Definition: replace.h:1317
void Expand(StateId s)
Definition: replace.h:731
ReplaceFst(const ReplaceFst &fst, bool safe=false)
Definition: replace.h:979
const Arc & Value() const final
Definition: replace.h:1387
uint8 ArcIteratorFlags() const
Definition: replace.h:853
Label GetFstId(Label nonterminal) const
Definition: replace.h:865
typename internal::ReplaceFstImpl< Arc, StateTable, CacheStore >::Arc Arc
Definition: fst.h:188
size_t Depth() const
Definition: replace.h:156
const SymbolTable * InputSymbols() const
Definition: fst.h:693
Arc::StateId CountStates(const Fst< Arc > &fst)
Definition: expanded-fst.h:155
bool CyclicDependencies() const
Definition: replace-util.h:120
typename Arc::Label Label
Definition: replace.h:1240
size_t operator()(const ReplaceStackPrefix< Label, StateId > &prefix) const
Definition: replace.h:180
size_t NumInputEpsilons(StateId s)
Definition: replace.h:634
ssize_t Priority(StateId s) final
Definition: replace.h:1410
ReplaceFstOptions(Label root, ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, Label call_output_label, Label return_label)
Definition: replace.h:353
ReplaceFstMatcher(const ReplaceFstMatcher &matcher, bool safe=false)
Definition: replace.h:1283
ReplaceFstOptions(Label root)
Definition: replace.h:343
constexpr uint64 kError
Definition: properties.h:35
bool CyclicDependencies() const
Definition: replace.h:567
typename StateTable::StateTuple StateTuple
Definition: replace.h:1247
ReplaceFstImpl(const FstList< Arc > &fst_list, const ReplaceFstOptions< Arc, StateTable, CacheStore > &opts)
Definition: replace.h:487
uint64 Properties() const override
Definition: replace.h:706
ReplaceFstMatcher(const ReplaceFst< Arc, StateTable, CacheStore > &fst, MatchType match_type)
Definition: replace.h:1250
typename Arc::StateId StateId
Definition: replace.h:274
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
Definition: expand.h:303
ReplaceFingerprint(const std::vector< uint64 > *size_array)
Definition: replace.h:99
uint64 ReplaceFstProperties(typename Arc::Label root_label, const FstList< Arc > &fst_list, ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, typename Arc::Label call_output_label, bool *sorted_and_non_empty)
Definition: replace.h:410
ReplaceStackPrefix(const ReplaceStackPrefix &other)
Definition: replace.h:145
typename Arc::StateId StateId
Definition: replace.h:459
bool Done() const final
Definition: replace.h:1383
StateId FindState(const StateTuple &tuple)
Definition: replace.h:246
bool IsNonTerminal(Label label) const
Definition: replace.h:620
typename Arc::StateId StateId
Definition: replace.h:1241
DefaultReplaceStateTable(const std::vector< std::pair< Label, const Fst< Arc > * >> &, Label)
Definition: replace.h:288
StateIterator(const ReplaceFst< Arc, StateTable, CacheStore > &fst)
Definition: replace.h:1027
typename Arc::Label Label
Definition: replace.h:458
PrefixTuple(Label fst_id=kNoLabel, StateId nextstate=kNoStateId)
Definition: replace.h:136
bool operator()(const ReplaceStateTuple< StateId, PrefixId > &tuple) const
Definition: replace.h:90
StateTable * GetStateTable() const
Definition: replace.h:859
std::unordered_map< Label, Label > NonTerminalHash
Definition: replace.h:467
void Next()
Definition: fst.h:516