FST  openfst-1.8.2
OpenFst Library
replace.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Functions and classes for the recursive replacement of FSTs.
19 
20 #ifndef FST_REPLACE_H_
21 #define FST_REPLACE_H_
22 
23 #include <cstdint>
24 #include <set>
25 #include <string>
26 #include <utility>
27 #include <vector>
28 
29 #include <fst/log.h>
30 
31 #include <fst/cache.h>
32 #include <fst/expanded-fst.h>
33 #include <fst/fst-decl.h> // For optional argument declarations.
34 #include <fst/fst.h>
35 #include <fst/matcher.h>
36 #include <fst/replace-util.h>
37 #include <fst/state-table.h>
38 #include <fst/test-properties.h>
39 
40 #include <unordered_map>
41 
42 namespace fst {
43 
44 // Replace state tables have the form:
45 //
46 // template <class Arc, class P>
47 // class ReplaceStateTable {
48 // public:
49 // using Label = typename Arc::Label;
50 // using StateId = typename Arc::StateId;
51 //
52 // using PrefixId = P;
53 // using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
54 // using StackPrefix = ReplaceStackPrefix<Label, StateId>;
55 //
56 // // Required constructor.
57 // ReplaceStateTable(
58 // const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_list,
59 // Label root);
60 //
61 // // Required copy constructor that does not copy state.
62 // ReplaceStateTable(const ReplaceStateTable<Arc, PrefixId> &table);
63 //
64 // // Looks up state ID by tuple, adding it if it doesn't exist.
65 // StateId FindState(const StateTuple &tuple);
66 //
67 // // Looks up state tuple by ID.
68 // const StateTuple &Tuple(StateId id) const;
69 //
70 // Looks up prefix ID by stack prefix, adding it if it doesn't exist.
71 // PrefixId FindPrefixId(const StackPrefix &stack_prefix);
72 //
73 // // Looks up stack prefix by ID.
74 // const StackPrefix &GetStackPrefix(PrefixId id) const;
75 // };
76 
77 // Tuple that uniquely defines a state in replace.
78 template <class S, class P>
80  using StateId = S;
81  using PrefixId = P;
82 
86 
87  PrefixId prefix_id; // Index in prefix table.
88  StateId fst_id; // Current FST being walked.
89  StateId fst_state; // Current state in FST being walked (not to be
90  // confused with the StateId of the combined FST).
91 };
92 
93 // Equality of replace state tuples.
94 template <class StateId, class PrefixId>
97  return x.prefix_id == y.prefix_id && x.fst_id == y.fst_id &&
98  x.fst_state == y.fst_state;
99 }
100 
101 // Functor returning true for tuples corresponding to states in the root FST.
102 template <class StateId, class PrefixId>
104  public:
106  return tuple.prefix_id == 0;
107  }
108 };
109 
110 // Functor for fingerprinting replace state tuples.
111 template <class StateId, class PrefixId>
113  public:
114  explicit ReplaceFingerprint(const std::vector<uint64_t> *size_array)
115  : size_array_(size_array) {}
116 
117  uint64_t operator()(const ReplaceStateTuple<StateId, PrefixId> &tuple) const {
118  return tuple.prefix_id * size_array_->back() +
119  size_array_->at(tuple.fst_id - 1) + tuple.fst_state;
120  }
121 
122  private:
123  const std::vector<uint64_t> *size_array_;
124 };
125 
126 // Useful when the fst_state uniquely defines the tuple.
127 template <class StateId, class PrefixId>
129  public:
130  uint64_t operator()(const ReplaceStateTuple<StateId, PrefixId> &tuple) const {
131  return tuple.fst_state;
132  }
133 };
134 
135 // A generic hash function for replace state tuples.
136 template <typename S, typename P>
137 class ReplaceHash {
138  public:
139  size_t operator()(const ReplaceStateTuple<S, P> &t) const {
140  // TODO(189578616): Use the AbslHashValue framework. See
141  // b/189578616#comment8 for a proposed patch and a discussion of the
142  // concerns for making such a change.
143  //
144  // We want three prime numbers that are all reasonably large and whose
145  // differences are far from each other. (E.g., we want prime1-prime0 to be
146  // far from prime2-prime1). It would be safer still to use large prime
147  // numbers (i.e., prime numbers that use all 64 bits on a 64-bit machine and
148  // all 32-bits on a 32-bit machine), so that all 64-bits (respectively
149  // 32-bits) of the resulting hash would look random. However, these
150  // modest-sized prime numbers are good enough for hash tables (such as
151  // std::unordered_set and std::unordered_set) that use the low-order bits
152  // of the hash.
153  //
154  // It is important that all three components are multiplied by a prime
155  // number. E.g., don't compute
156  // t.prefix_id + t.fst_id * prime1 + t.fst_state * prime2
157  // which is just the identity on t.prefix_id. Using the identity will
158  // result in long probe sequences in open-addressed hash tables (such as
159  // std::unordered_map).
160  static constexpr size_t prime0 = 7853;
161  static constexpr size_t prime1 = 9001;
162  static constexpr size_t prime2 = 100003;
163  return t.prefix_id * prime0 + t.fst_id * prime1 + t.fst_state * prime2;
164  }
165 };
166 
167 // Container for stack prefix.
168 template <class Label, class StateId>
170  public:
171  struct PrefixTuple {
173  : fst_id(fst_id), nextstate(nextstate) {}
174 
175  Label fst_id;
177  };
178 
180 
182  : prefix_(other.prefix_) {}
183 
184  void Push(StateId fst_id, StateId nextstate) {
185  prefix_.push_back(PrefixTuple(fst_id, nextstate));
186  }
187 
188  void Pop() { prefix_.pop_back(); }
189 
190  const PrefixTuple &Top() const { return prefix_[prefix_.size() - 1]; }
191 
192  size_t Depth() const { return prefix_.size(); }
193 
194  public:
195  std::vector<PrefixTuple> prefix_;
196 };
197 
198 // Equality of stack prefixes.
199 template <class Label, class StateId>
202  if (x.prefix_.size() != y.prefix_.size()) return false;
203  for (size_t i = 0; i < x.prefix_.size(); ++i) {
204  if (x.prefix_[i].fst_id != y.prefix_[i].fst_id ||
205  x.prefix_[i].nextstate != y.prefix_[i].nextstate) {
206  return false;
207  }
208  }
209  return true;
210 }
211 
212 // Hash function for stack prefix to prefix id.
213 template <class Label, class StateId>
215  public:
216  size_t operator()(const ReplaceStackPrefix<Label, StateId> &prefix) const {
217  size_t sum = 0;
218  for (const auto &pair : prefix.prefix_) {
219  static constexpr size_t prime = 7863;
220  sum += pair.fst_id + pair.nextstate * prime;
221  }
222  return sum;
223  }
224 };
225 
226 // Replace state tables.
227 
228 // A two-level state table for replace. Warning: calls CountStates to compute
229 // the number of states of each component FST.
230 template <class Arc, class P = ssize_t>
232  public:
233  using Label = typename Arc::Label;
234  using StateId = typename Arc::StateId;
235 
236  using PrefixId = P;
237 
239  using StateTable =
245  using StackPrefixTable =
248 
250  const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_list,
251  Label root)
252  : root_size_(0) {
253  size_array_.push_back(0);
254  for (const auto &[label, fst] : fst_list) {
255  if (label == root) {
256  root_size_ = CountStates(*fst);
257  size_array_.push_back(size_array_.back());
258  } else {
259  size_array_.push_back(size_array_.back() +
260  CountStates(*fst));
261  }
262  }
263  state_table_ = std::make_unique<StateTable>(
264  ReplaceRootSelector<StateId, PrefixId>(),
265  ReplaceFstStateFingerprint<StateId, PrefixId>(),
266  ReplaceFingerprint<StateId, PrefixId>(&size_array_), root_size_,
267  root_size_ + size_array_.back());
268  }
269 
272  : root_size_(table.root_size_),
273  size_array_(table.size_array_),
274  prefix_table_(table.prefix_table_) {
275  state_table_ = std::make_unique<StateTable>(
276  ReplaceRootSelector<StateId, PrefixId>(),
277  ReplaceFstStateFingerprint<StateId, PrefixId>(),
278  ReplaceFingerprint<StateId, PrefixId>(&size_array_), root_size_,
279  root_size_ + size_array_.back());
280  }
281 
282  StateId FindState(const StateTuple &tuple) {
283  return state_table_->FindState(tuple);
284  }
285 
286  const StateTuple &Tuple(StateId id) const { return state_table_->Tuple(id); }
287 
288  PrefixId FindPrefixId(const StackPrefix &prefix) {
289  return prefix_table_.FindId(prefix);
290  }
291 
292  const StackPrefix &GetStackPrefix(PrefixId id) const {
293  return prefix_table_.FindEntry(id);
294  }
295 
296  private:
297  StateId root_size_;
298  std::vector<uint64_t> size_array_;
299  std::unique_ptr<StateTable> state_table_;
300  StackPrefixTable prefix_table_;
301 };
302 
303 // Default replace state table.
304 template <class Arc, class P /* = size_t */>
306  : public CompactHashStateTable<ReplaceStateTuple<typename Arc::StateId, P>,
307  ReplaceHash<typename Arc::StateId, P>> {
308  public:
309  using Label = typename Arc::Label;
310  using StateId = typename Arc::StateId;
311 
312  using PrefixId = P;
314  using StateTable =
317  using StackPrefixTable =
320 
321  using StateTable::FindState;
322  using StateTable::Tuple;
323 
325  const std::vector<std::pair<Label, const Fst<Arc> *>> &, Label) {}
326 
328  : StateTable(), prefix_table_(table.prefix_table_) {}
329 
330  PrefixId FindPrefixId(const StackPrefix &prefix) {
331  return prefix_table_.FindId(prefix);
332  }
333 
334  const StackPrefix &GetStackPrefix(PrefixId id) const {
335  return prefix_table_.FindEntry(id);
336  }
337 
338  private:
339  StackPrefixTable prefix_table_;
340 };
341 
342 // By default ReplaceFst will copy the input label of the replace arc.
343 // The call_label_type and return_label_type options specify how to manage
344 // the labels of the call arc and the return arc of the replace FST
345 template <class Arc, class StateTable = DefaultReplaceStateTable<Arc>,
346  class CacheStore = DefaultCacheStore<Arc>>
347 struct ReplaceFstOptions : CacheImplOptions<CacheStore> {
348  using Label = typename Arc::Label;
349 
350  // Index of root rule for expansion.
352  // How to label call arc.
354  // How to label return arc.
356  // Specifies output label to put on call arc; if kNoLabel, use existing label
357  // on call arc. Otherwise, use this field as the output label.
358  Label call_output_label = kNoLabel;
359  // Specifies label to put on return arc.
360  Label return_label = 0;
361  // Take ownership of input FSTs?
362  bool take_ownership = false;
363  // Pointer to optional pre-constructed state table.
364  StateTable *state_table = nullptr;
365 
367  Label root = kNoLabel)
368  : CacheImplOptions<CacheStore>(opts), root(root) {}
369 
370  explicit ReplaceFstOptions(const CacheOptions &opts, Label root = kNoLabel)
371  : CacheImplOptions<CacheStore>(opts), root(root) {}
372 
373  // FIXME(kbg): There are too many constructors here. Come up with a consistent
374  // position for call_output_label (probably the very end) so that it is
375  // possible to express all the remaining constructors with a single
376  // default-argument constructor. Also move clients off of the "backwards
377  // compatibility" constructor, for good.
378 
379  explicit ReplaceFstOptions(Label root) : root(root) {}
380 
381  explicit ReplaceFstOptions(Label root, ReplaceLabelType call_label_type,
382  ReplaceLabelType return_label_type,
383  Label return_label)
384  : root(root),
385  call_label_type(call_label_type),
386  return_label_type(return_label_type),
387  return_label(return_label) {}
388 
389  explicit ReplaceFstOptions(Label root, ReplaceLabelType call_label_type,
390  ReplaceLabelType return_label_type,
391  Label call_output_label, Label return_label)
392  : root(root),
393  call_label_type(call_label_type),
394  return_label_type(return_label_type),
395  call_output_label(call_output_label),
396  return_label(return_label) {}
397 
398  explicit ReplaceFstOptions(const ReplaceUtilOptions &opts)
399  : ReplaceFstOptions(opts.root, opts.call_label_type,
400  opts.return_label_type, opts.return_label) {}
401 
403 
404  // For backwards compatibility.
405  ReplaceFstOptions(int64_t root, bool epsilon_replace_arc)
406  : root(root),
407  call_label_type(epsilon_replace_arc ? REPLACE_LABEL_NEITHER
409  call_output_label(epsilon_replace_arc ? 0 : kNoLabel) {}
410 };
411 
412 // Forward declaration.
413 template <class Arc, class StateTable, class CacheStore>
415 
416 template <class Arc>
417 using FstList = std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>;
418 
419 // Returns true if label type on arc results in epsilon input label.
420 inline bool EpsilonOnInput(ReplaceLabelType label_type) {
421  return label_type == REPLACE_LABEL_NEITHER ||
422  label_type == REPLACE_LABEL_OUTPUT;
423 }
424 
425 // Returns true if label type on arc results in epsilon input label.
426 inline bool EpsilonOnOutput(ReplaceLabelType label_type) {
427  return label_type == REPLACE_LABEL_NEITHER ||
428  label_type == REPLACE_LABEL_INPUT;
429 }
430 
431 // Returns true if for either the call or return arc ilabel != olabel.
432 template <class Label>
433 bool ReplaceTransducer(ReplaceLabelType call_label_type,
434  ReplaceLabelType return_label_type,
435  Label call_output_label) {
436  return call_label_type == REPLACE_LABEL_INPUT ||
437  call_label_type == REPLACE_LABEL_OUTPUT ||
438  (call_label_type == REPLACE_LABEL_BOTH &&
439  call_output_label != kNoLabel) ||
440  return_label_type == REPLACE_LABEL_INPUT ||
441  return_label_type == REPLACE_LABEL_OUTPUT;
442 }
443 
444 template <class Arc>
445 uint64_t ReplaceFstProperties(typename Arc::Label root_label,
446  const FstList<Arc> &fst_list,
447  ReplaceLabelType call_label_type,
448  ReplaceLabelType return_label_type,
449  typename Arc::Label call_output_label,
450  bool *sorted_and_non_empty) {
451  using Label = typename Arc::Label;
452  std::vector<uint64_t> inprops;
453  bool all_ilabel_sorted = true;
454  bool all_olabel_sorted = true;
455  bool all_non_empty = true;
456  // All nonterminals are negative?
457  bool all_negative = true;
458  // All nonterminals are positive and form a dense range containing 1?
459  bool dense_range = true;
460  Label root_fst_idx = 0;
461  for (Label i = 0; i < fst_list.size(); ++i) {
462  const auto label = fst_list[i].first;
463  if (label >= 0) all_negative = false;
464  if (label > fst_list.size() || label <= 0) dense_range = false;
465  if (label == root_label) root_fst_idx = i;
466  const auto *fst = fst_list[i].second;
467  if (fst->Start() == kNoStateId) all_non_empty = false;
468  if (!fst->Properties(kILabelSorted, false)) all_ilabel_sorted = false;
469  if (!fst->Properties(kOLabelSorted, false)) all_olabel_sorted = false;
470  inprops.push_back(fst->Properties(kCopyProperties, false));
471  }
472  const auto props = ReplaceProperties(
473  inprops, root_fst_idx, EpsilonOnInput(call_label_type),
474  EpsilonOnInput(return_label_type), EpsilonOnOutput(call_label_type),
475  EpsilonOnOutput(return_label_type),
476  ReplaceTransducer(call_label_type, return_label_type, call_output_label),
477  all_non_empty, all_ilabel_sorted, all_olabel_sorted,
478  all_negative || dense_range);
479  const bool sorted = props & (kILabelSorted | kOLabelSorted);
480  *sorted_and_non_empty = all_non_empty && sorted;
481  return props;
482 }
483 
484 namespace internal {
485 
486 // The replace implementation class supports a dynamic expansion of a recursive
487 // transition network represented as label/FST pairs with dynamic replaceable
488 // arcs.
489 template <class Arc, class StateTable, class CacheStore>
491  : public CacheBaseImpl<typename CacheStore::State, CacheStore> {
492  public:
493  using Label = typename Arc::Label;
494  using StateId = typename Arc::StateId;
495  using Weight = typename Arc::Weight;
496 
497  using State = typename CacheStore::State;
499  using PrefixId = typename StateTable::PrefixId;
502  using NonTerminalHash = std::unordered_map<Label, Label>;
503 
504  using FstImpl<Arc>::SetType;
511 
512  using CacheImpl::HasArcs;
513  using CacheImpl::HasFinal;
514  using CacheImpl::HasStart;
515  using CacheImpl::PushArc;
516  using CacheImpl::SetArcs;
517  using CacheImpl::SetFinal;
518  using CacheImpl::SetStart;
519 
520  friend class ReplaceFstMatcher<Arc, StateTable, CacheStore>;
521 
522  ReplaceFstImpl(const FstList<Arc> &fst_list,
524  : CacheImpl(opts),
525  call_label_type_(opts.call_label_type),
526  return_label_type_(opts.return_label_type),
527  call_output_label_(opts.call_output_label),
528  return_label_(opts.return_label),
529  state_table_(opts.state_table ? opts.state_table
530  : new StateTable(fst_list, opts.root)) {
531  SetType("replace");
532  // If the label is epsilon, then all replace label options are equivalent,
533  // so we set the label types to NEITHER for simplicity.
534  if (call_output_label_ == 0) call_label_type_ = REPLACE_LABEL_NEITHER;
535  if (return_label_ == 0) return_label_type_ = REPLACE_LABEL_NEITHER;
536  if (!fst_list.empty()) {
537  SetInputSymbols(fst_list[0].second->InputSymbols());
538  SetOutputSymbols(fst_list[0].second->OutputSymbols());
539  }
540  fst_array_.push_back(nullptr);
541  for (Label i = 0; i < fst_list.size(); ++i) {
542  const auto label = fst_list[i].first;
543  const auto *fst = fst_list[i].second;
544  nonterminal_hash_[label] = fst_array_.size();
545  nonterminal_set_.insert(label);
546  fst_array_.emplace_back(opts.take_ownership ? fst : fst->Copy());
547  if (i) {
548  if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) {
549  FSTERROR() << "ReplaceFstImpl: Input symbols of FST " << i
550  << " do not match input symbols of base FST (0th FST)";
551  SetProperties(kError, kError);
552  }
553  if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) {
554  FSTERROR() << "ReplaceFstImpl: Output symbols of FST " << i
555  << " do not match output symbols of base FST (0th FST)";
556  SetProperties(kError, kError);
557  }
558  }
559  }
560  const auto nonterminal = nonterminal_hash_[opts.root];
561  if ((nonterminal == 0) && (fst_array_.size() > 1)) {
562  FSTERROR() << "ReplaceFstImpl: No FST corresponding to root label "
563  << opts.root << " in the input tuple vector";
564  SetProperties(kError, kError);
565  }
566  root_ = (nonterminal > 0) ? nonterminal : 1;
567  bool all_non_empty_and_sorted = false;
568  SetProperties(ReplaceFstProperties(opts.root, fst_list, call_label_type_,
569  return_label_type_, call_output_label_,
570  &all_non_empty_and_sorted));
571  // Enables optional caching as long as sorted and all non-empty.
572  always_cache_ = !all_non_empty_and_sorted;
573  VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = "
574  << (always_cache_ ? "true" : "false");
575  }
576 
578  : CacheImpl(impl),
579  call_label_type_(impl.call_label_type_),
580  return_label_type_(impl.return_label_type_),
581  call_output_label_(impl.call_output_label_),
582  return_label_(impl.return_label_),
583  always_cache_(impl.always_cache_),
584  state_table_(new StateTable(*(impl.state_table_))),
585  nonterminal_set_(impl.nonterminal_set_),
586  nonterminal_hash_(impl.nonterminal_hash_),
587  root_(impl.root_) {
588  SetType("replace");
589  SetProperties(impl.Properties(), kCopyProperties);
590  SetInputSymbols(impl.InputSymbols());
591  SetOutputSymbols(impl.OutputSymbols());
592  fst_array_.reserve(impl.fst_array_.size());
593  fst_array_.emplace_back(nullptr);
594  for (Label i = 1; i < impl.fst_array_.size(); ++i) {
595  fst_array_.emplace_back(impl.fst_array_[i]->Copy(true));
596  }
597  }
598 
599  // Computes the dependency graph of the replace class and returns
600  // true if the dependencies are cyclic. Cyclic dependencies will result
601  // in an un-expandable FST.
602  bool CyclicDependencies() const {
603  const ReplaceUtilOptions opts(root_);
604  ReplaceUtil<Arc> replace_util(fst_array_, nonterminal_hash_, opts);
605  return replace_util.CyclicDependencies();
606  }
607 
609  if (!HasStart()) {
610  if (fst_array_.size() == 1) {
611  SetStart(kNoStateId);
612  return kNoStateId;
613  } else {
614  const auto fst_start = fst_array_[root_]->Start();
615  if (fst_start == kNoStateId) return kNoStateId;
616  const auto prefix = GetPrefixId(StackPrefix());
617  const auto start =
618  state_table_->FindState(StateTuple(prefix, root_, fst_start));
619  SetStart(start);
620  return start;
621  }
622  } else {
623  return CacheImpl::Start();
624  }
625  }
626 
628  if (HasFinal(s)) return CacheImpl::Final(s);
629  const auto &tuple = state_table_->Tuple(s);
630  auto weight = Weight::Zero();
631  if (tuple.prefix_id == 0) {
632  const auto fst_state = tuple.fst_state;
633  weight = fst_array_[tuple.fst_id]->Final(fst_state);
634  }
635  if (always_cache_ || HasArcs(s)) SetFinal(s, weight);
636  return weight;
637  }
638 
639  size_t NumArcs(StateId s) {
640  if (HasArcs(s)) {
641  return CacheImpl::NumArcs(s);
642  } else if (always_cache_) { // If always caching, expands and caches state.
643  Expand(s);
644  return CacheImpl::NumArcs(s);
645  } else { // Otherwise computes the number of arcs without expanding.
646  const auto tuple = state_table_->Tuple(s);
647  if (tuple.fst_state == kNoStateId) return 0;
648  auto num_arcs = fst_array_[tuple.fst_id]->NumArcs(tuple.fst_state);
649  if (ComputeFinalArc(tuple, nullptr)) ++num_arcs;
650  return num_arcs;
651  }
652  }
653 
654  // Returns whether a given label is a non-terminal.
655  bool IsNonTerminal(Label label) const {
656  if (label < *nonterminal_set_.begin() ||
657  label > *nonterminal_set_.rbegin()) {
658  return false;
659  } else {
660  return nonterminal_hash_.count(label);
661  }
662  // TODO(allauzen): be smarter and take advantage of all_dense or
663  // all_negative. Also use this in ComputeArc. This would require changes to
664  // Replace so that recursing into an empty FST lead to a non co-accessible
665  // state instead of deleting the arc as done currently. The current use
666  // correct, since labels are sorted if all_non_empty is true.
667  }
668 
670  if (HasArcs(s)) {
671  return CacheImpl::NumInputEpsilons(s);
672  } else if (always_cache_ || !Properties(kILabelSorted)) {
673  // If always caching or if the number of input epsilons is too expensive
674  // to compute without caching (i.e., not ilabel-sorted), then expands and
675  // caches state.
676  Expand(s);
677  return CacheImpl::NumInputEpsilons(s);
678  } else {
679  // Otherwise, computes the number of input epsilons without caching.
680  const auto tuple = state_table_->Tuple(s);
681  if (tuple.fst_state == kNoStateId) return 0;
682  size_t num = 0;
683  if (!EpsilonOnInput(call_label_type_)) {
684  // If EpsilonOnInput(c) is false, all input epsilon arcs
685  // are also input epsilons arcs in the underlying machine.
686  num = fst_array_[tuple.fst_id]->NumInputEpsilons(tuple.fst_state);
687  } else {
688  // Otherwise, one need to consider that all non-terminal arcs
689  // in the underlying machine also become input epsilon arc.
690  ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
691  for (; !aiter.Done() && ((aiter.Value().ilabel == 0) ||
692  IsNonTerminal(aiter.Value().olabel));
693  aiter.Next()) {
694  ++num;
695  }
696  }
697  if (EpsilonOnInput(return_label_type_) &&
698  ComputeFinalArc(tuple, nullptr)) {
699  ++num;
700  }
701  return num;
702  }
703  }
704 
706  if (HasArcs(s)) {
708  } else if (always_cache_ || !Properties(kOLabelSorted)) {
709  // If always caching or if the number of output epsilons is too expensive
710  // to compute without caching (i.e., not olabel-sorted), then expands and
711  // caches state.
712  Expand(s);
714  } else {
715  // Otherwise, computes the number of output epsilons without caching.
716  const auto tuple = state_table_->Tuple(s);
717  if (tuple.fst_state == kNoStateId) return 0;
718  size_t num = 0;
719  if (!EpsilonOnOutput(call_label_type_)) {
720  // If EpsilonOnOutput(c) is false, all output epsilon arcs are also
721  // output epsilons arcs in the underlying machine.
722  num = fst_array_[tuple.fst_id]->NumOutputEpsilons(tuple.fst_state);
723  } else {
724  // Otherwise, one need to consider that all non-terminal arcs in the
725  // underlying machine also become output epsilon arc.
726  ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
727  for (; !aiter.Done() && ((aiter.Value().olabel == 0) ||
728  IsNonTerminal(aiter.Value().olabel));
729  aiter.Next()) {
730  ++num;
731  }
732  }
733  if (EpsilonOnOutput(return_label_type_) &&
734  ComputeFinalArc(tuple, nullptr)) {
735  ++num;
736  }
737  return num;
738  }
739  }
740 
741  uint64_t Properties() const override { return Properties(kFstProperties); }
742 
743  // Sets error if found, and returns other FST impl properties.
744  uint64_t Properties(uint64_t mask) const override {
745  if (mask & kError) {
746  for (Label i = 1; i < fst_array_.size(); ++i) {
747  if (fst_array_[i]->Properties(kError, false)) {
748  SetProperties(kError, kError);
749  }
750  }
751  }
752  return FstImpl<Arc>::Properties(mask);
753  }
754 
755  // Returns the base arc iterator, and if arcs have not been computed yet,
756  // extends and recurses for new arcs.
758  if (!HasArcs(s)) Expand(s);
759  CacheImpl::InitArcIterator(s, data);
760  // TODO(allauzen): Set behaviour of generic iterator.
761  // Warning: ArcIterator<ReplaceFst<A>>::InitCache() relies on current
762  // behaviour.
763  }
764 
765  // Extends current state (walk arcs one level deep).
766  void Expand(StateId s) {
767  const auto tuple = state_table_->Tuple(s);
768  if (tuple.fst_state == kNoStateId) { // Local FST is empty.
769  SetArcs(s);
770  return;
771  }
772  ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
773  Arc arc;
774  // Creates a final arc when needed.
775  if (ComputeFinalArc(tuple, &arc)) PushArc(s, std::move(arc));
776  // Expands all arcs leaving the state.
777  for (; !aiter.Done(); aiter.Next()) {
778  if (ComputeArc(tuple, aiter.Value(), &arc)) PushArc(s, std::move(arc));
779  }
780  SetArcs(s);
781  }
782 
783  void Expand(StateId s, const StateTuple &tuple,
784  const ArcIteratorData<Arc> &data) {
785  if (tuple.fst_state == kNoStateId) { // Local FST is empty.
786  SetArcs(s);
787  return;
788  }
789  ArcIterator<Fst<Arc>> aiter(data);
790  Arc arc;
791  // Creates a final arc when needed.
792  if (ComputeFinalArc(tuple, &arc)) AddArc(s, arc);
793  // Expands all arcs leaving the state.
794  for (; !aiter.Done(); aiter.Next()) {
795  if (ComputeArc(tuple, aiter.Value(), &arc)) AddArc(s, arc);
796  }
797  SetArcs(s);
798  }
799 
800  // If acpp is null, only returns true if a final arcp is required, but does
801  // not actually compute it.
802  bool ComputeFinalArc(const StateTuple &tuple, Arc *arcp,
803  uint8_t flags = kArcValueFlags) {
804  const auto fst_state = tuple.fst_state;
805  if (fst_state == kNoStateId) return false;
806  // If state is final, pops the stack.
807  if (fst_array_[tuple.fst_id]->Final(fst_state) != Weight::Zero() &&
808  tuple.prefix_id) {
809  if (arcp) {
810  arcp->ilabel = (EpsilonOnInput(return_label_type_)) ? 0 : return_label_;
811  arcp->olabel =
812  (EpsilonOnOutput(return_label_type_)) ? 0 : return_label_;
813  if (flags & kArcNextStateValue) {
814  const auto &stack = state_table_->GetStackPrefix(tuple.prefix_id);
815  const auto prefix_id = PopPrefix(stack);
816  const auto &top = stack.Top();
817  arcp->nextstate = state_table_->FindState(
818  StateTuple(prefix_id, top.fst_id, top.nextstate));
819  }
820  if (flags & kArcWeightValue) {
821  arcp->weight = fst_array_[tuple.fst_id]->Final(fst_state);
822  }
823  }
824  return true;
825  } else {
826  return false;
827  }
828  }
829 
830  // Computes an arc in the FST corresponding to one in the underlying machine.
831  // Returns false if the underlying arc corresponds to no arc in the resulting
832  // FST.
833  bool ComputeArc(const StateTuple &tuple, const Arc &arc, Arc *arcp,
834  uint8_t flags = kArcValueFlags) {
835  if (!EpsilonOnInput(call_label_type_) &&
836  (flags == (flags & (kArcILabelValue | kArcWeightValue)))) {
837  *arcp = arc;
838  return true;
839  }
840  if (arc.olabel == 0 || arc.olabel < *nonterminal_set_.begin() ||
841  arc.olabel > *nonterminal_set_.rbegin()) { // Expands local FST.
842  const auto nextstate =
843  flags & kArcNextStateValue
844  ? state_table_->FindState(
845  StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
846  : kNoStateId;
847  *arcp = Arc(arc.ilabel, arc.olabel, arc.weight, nextstate);
848  } else {
849  // Checks for non-terminal.
850  const auto it = nonterminal_hash_.find(arc.olabel);
851  if (it != nonterminal_hash_.end()) { // Recurses into non-terminal.
852  const auto nonterminal = it->second;
853  const auto nt_prefix =
854  PushPrefix(state_table_->GetStackPrefix(tuple.prefix_id),
855  tuple.fst_id, arc.nextstate);
856  // If the start state is valid, replace; othewise, the arc is implicitly
857  // deleted.
858  const auto nt_start = fst_array_[nonterminal]->Start();
859  if (nt_start != kNoStateId) {
860  const auto nt_nextstate = flags & kArcNextStateValue
861  ? state_table_->FindState(StateTuple(
862  nt_prefix, nonterminal, nt_start))
863  : kNoStateId;
864  const auto ilabel =
865  (EpsilonOnInput(call_label_type_)) ? 0 : arc.ilabel;
866  const auto olabel =
867  (EpsilonOnOutput(call_label_type_))
868  ? 0
869  : ((call_output_label_ == kNoLabel) ? arc.olabel
870  : call_output_label_);
871  *arcp = Arc(ilabel, olabel, arc.weight, nt_nextstate);
872  } else {
873  return false;
874  }
875  } else {
876  const auto nextstate =
877  flags & kArcNextStateValue
878  ? state_table_->FindState(
879  StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
880  : kNoStateId;
881  *arcp = Arc(arc.ilabel, arc.olabel, arc.weight, nextstate);
882  }
883  }
884  return true;
885  }
886 
887  // Returns the arc iterator flags supported by this FST.
888  uint8_t ArcIteratorFlags() const {
889  uint8_t flags = kArcValueFlags;
890  if (!always_cache_) flags |= kArcNoCache;
891  return flags;
892  }
893 
894  StateTable *GetStateTable() const { return state_table_.get(); }
895 
896  const Fst<Arc> *GetFst(Label fst_id) const {
897  return fst_array_[fst_id].get();
898  }
899 
900  Label GetFstId(Label nonterminal) const {
901  const auto it = nonterminal_hash_.find(nonterminal);
902  if (it == nonterminal_hash_.end()) {
903  FSTERROR() << "ReplaceFstImpl::GetFstId: Nonterminal not found: "
904  << nonterminal;
905  }
906  return it->second;
907  }
908 
909  // Returns true if label type on call arc results in epsilon input label.
910  bool EpsilonOnCallInput() { return EpsilonOnInput(call_label_type_); }
911 
912  private:
913  // The unique index into stack prefix table.
914  PrefixId GetPrefixId(const StackPrefix &prefix) {
915  return state_table_->FindPrefixId(prefix);
916  }
917 
918  // The prefix ID after a stack pop.
919  PrefixId PopPrefix(StackPrefix prefix) {
920  prefix.Pop();
921  return GetPrefixId(prefix);
922  }
923 
924  // The prefix ID after a stack push.
925  PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
926  prefix.Push(fst_id, nextstate);
927  return GetPrefixId(prefix);
928  }
929 
930  // Runtime options
931  ReplaceLabelType call_label_type_; // How to label call arc.
932  ReplaceLabelType return_label_type_; // How to label return arc.
933  int64_t call_output_label_; // Specifies output label to put on call arc
934  int64_t return_label_; // Specifies label to put on return arc.
935  bool always_cache_; // Disable optional caching of arc iterator?
936 
937  // State table.
938  std::unique_ptr<StateTable> state_table_;
939 
940  // Replace components.
941  std::set<Label> nonterminal_set_;
942  NonTerminalHash nonterminal_hash_;
943  std::vector<std::unique_ptr<const Fst<Arc>>> fst_array_;
944  Label root_;
945 };
946 
947 } // namespace internal
948 
949 //
950 // ReplaceFst supports dynamic replacement of arcs in one FST with another FST.
951 // This replacement is recursive. ReplaceFst can be used to support a variety of
952 // delayed constructions such as recursive
953 // transition networks, union, or closure. It is constructed with an array of
954 // FST(s). One FST represents the root (or topology) machine. The root FST
955 // refers to other FSTs by recursively replacing arcs labeled as non-terminals
956 // with the matching non-terminal FST. Currently the ReplaceFst uses the output
957 // symbols of the arcs to determine whether the arc is a non-terminal arc or
958 // not. A non-terminal can be any label that is not a non-zero terminal label in
959 // the output alphabet.
960 //
961 // Note that the constructor uses a vector of pairs. These correspond to the
962 // tuple of non-terminal Label and corresponding FST. For example to implement
963 // the closure operation we need 2 FSTs. The first root FST is a single
964 // self-loop arc on the start state.
965 //
966 // The ReplaceFst class supports an optionally caching arc iterator.
967 //
968 // The ReplaceFst needs to be built such that it is known to be ilabel- or
969 // olabel-sorted (see usage below).
970 //
971 // Observe that Matcher<Fst<A>> will use the optionally caching arc iterator
972 // when available (the FST is ilabel-sorted and matching on the input, or the
973 // FST is olabel-sorted and matching on the output). In order to obtain the
974 // most efficient behaviour, it is recommended to match on the input side and
975 // to set call_label_type to REPLACE_LABEL_INPUT or REPLACE_LABEL_BOTH and
976 // return_label_type to REPLACE_LABEL_OUTPUT or REPLACE_LABEL_NEITHER. This
977 // means that the call arc does not have epsilon on the input side and the
978 // return arc has epsilon on the input side.
979 //
980 // This class attaches interface to implementation and handles reference
981 // counting, delegating most methods to ImplToFst.
982 template <class A, class T /* = DefaultReplaceStateTable<A> */,
983  class CacheStore /* = DefaultCacheStore<A> */>
984 class ReplaceFst
985  : public ImplToFst<internal::ReplaceFstImpl<A, T, CacheStore>> {
986  public:
987  using Arc = A;
988  using Label = typename Arc::Label;
989  using StateId = typename Arc::StateId;
990  using Weight = typename Arc::Weight;
991 
992  using StateTable = T;
993  using Store = CacheStore;
994  using State = typename CacheStore::State;
997 
999 
1000  friend class ArcIterator<ReplaceFst<Arc, StateTable, CacheStore>>;
1001  friend class StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>;
1002  friend class ReplaceFstMatcher<Arc, StateTable, CacheStore>;
1003 
1004  ReplaceFst(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_array,
1005  Label root)
1006  : ImplToFst<Impl>(std::make_shared<Impl>(
1007  fst_array, ReplaceFstOptions<Arc, StateTable, CacheStore>(root))) {}
1008 
1009  ReplaceFst(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_array,
1011  : ImplToFst<Impl>(std::make_shared<Impl>(fst_array, opts)) {}
1012 
1013  // See Fst<>::Copy() for doc.
1014  ReplaceFst(const ReplaceFst &fst, bool safe = false)
1015  : ImplToFst<Impl>(fst, safe) {}
1016 
1017  // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc.
1018  ReplaceFst *Copy(bool safe = false) const override {
1019  return new ReplaceFst(*this, safe);
1020  }
1021 
1022  inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
1023 
1024  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
1025  GetMutableImpl()->InitArcIterator(s, data);
1026  }
1027 
1028  MatcherBase<Arc> *InitMatcher(MatchType match_type) const override {
1029  if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
1030  ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) ||
1031  (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) {
1033  match_type);
1034  } else {
1035  VLOG(2) << "Not using replace matcher";
1036  return nullptr;
1037  }
1038  }
1039 
1040  bool CyclicDependencies() const { return GetImpl()->CyclicDependencies(); }
1041 
1042  const StateTable &GetStateTable() const {
1043  return *GetImpl()->GetStateTable();
1044  }
1045 
1046  const Fst<Arc> &GetFst(Label nonterminal) const {
1047  return *GetImpl()->GetFst(GetImpl()->GetFstId(nonterminal));
1048  }
1049 
1050  private:
1053 
1054  ReplaceFst &operator=(const ReplaceFst &) = delete;
1055 };
1056 
1057 // Specialization for ReplaceFst.
1058 template <class Arc, class StateTable, class CacheStore>
1059 class StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>
1060  : public CacheStateIterator<ReplaceFst<Arc, StateTable, CacheStore>> {
1061  public:
1063  : CacheStateIterator<ReplaceFst<Arc, StateTable, CacheStore>>(
1064  fst, fst.GetMutableImpl()) {}
1065 };
1066 
1067 // Specialization for ReplaceFst, implementing optional caching. It is used
1068 // as follows:
1069 //
1070 // ReplaceFst<A> replace;
1071 // ArcIterator<ReplaceFst<A>> aiter(replace, s);
1072 // // Note: ArcIterator< Fst<A>> is always a caching arc iterator.
1073 // aiter.SetFlags(kArcNoCache, kArcNoCache);
1074 // // Uses the arc iterator, no arc will be cached, no state will be expanded.
1075 // // Arc flags can be used to decide which component of the arc need to be
1076 // computed.
1077 // aiter.SetFlags(kArcILabelValue, kArcValueFlags);
1078 // // Wants the ilabel for this arc.
1079 // aiter.Value(); // Does not compute the destination state.
1080 // aiter.Next();
1081 // aiter.SetFlags(kArcNextStateValue, kArcNextStateValue);
1082 // // Wants the ilabel and next state for this arc.
1083 // aiter.Value(); // Does compute the destination state and inserts it
1084 // // in the replace state table.
1085 // // No additional arcs have been cached at this point.
1086 template <class Arc, class StateTable, class CacheStore>
1087 class ArcIterator<ReplaceFst<Arc, StateTable, CacheStore>> {
1088  public:
1089  using StateId = typename Arc::StateId;
1090 
1091  using StateTuple = typename StateTable::StateTuple;
1092 
1094  : fst_(fst),
1095  s_(s),
1096  pos_(0),
1097  offset_(0),
1098  flags_(kArcValueFlags),
1099  arcs_(nullptr),
1100  data_flags_(0),
1101  final_flags_(0) {
1102  cache_data_.ref_count = nullptr;
1103  local_data_.ref_count = nullptr;
1104  // If FST does not support optional caching, forces caching.
1105  if (!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
1106  !(fst_.GetImpl()->HasArcs(s_))) {
1107  fst_.GetMutableImpl()->Expand(s_);
1108  }
1109  // If state is already cached, use cached arcs array.
1110  if (fst_.GetImpl()->HasArcs(s_)) {
1111  (fst_.GetImpl())
1112  ->internal::template CacheBaseImpl<
1113  typename CacheStore::State,
1114  CacheStore>::InitArcIterator(s_, &cache_data_);
1115  num_arcs_ = cache_data_.narcs;
1116  arcs_ = cache_data_.arcs; // arcs_ is a pointer to the cached arcs.
1117  data_flags_ = kArcValueFlags; // All the arc member values are valid.
1118  } else { // Otherwise delay decision until Value() is called.
1119  tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(s_);
1120  if (tuple_.fst_state == kNoStateId) {
1121  num_arcs_ = 0;
1122  } else {
1123  // The decision to cache or not to cache has been defered until Value()
1124  // or
1125  // SetFlags() is called. However, the arc iterator is set up now to be
1126  // ready for non-caching in order to keep the Value() method simple and
1127  // efficient.
1128  const auto *rfst = fst_.GetImpl()->GetFst(tuple_.fst_id);
1129  rfst->InitArcIterator(tuple_.fst_state, &local_data_);
1130  // arcs_ is a pointer to the arcs in the underlying machine.
1131  arcs_ = local_data_.arcs;
1132  // Computes the final arc (but not its destination state) if a final arc
1133  // is required.
1134  bool has_final_arc = fst_.GetMutableImpl()->ComputeFinalArc(
1135  tuple_, &final_arc_, kArcValueFlags & ~kArcNextStateValue);
1136  // Sets the arc value flags that hold for final_arc_.
1137  final_flags_ = kArcValueFlags & ~kArcNextStateValue;
1138  // Computes the number of arcs.
1139  num_arcs_ = local_data_.narcs;
1140  if (has_final_arc) ++num_arcs_;
1141  // Sets the offset between the underlying arc positions and the
1142  // positions
1143  // in the arc iterator.
1144  offset_ = num_arcs_ - local_data_.narcs;
1145  // Defers the decision to cache or not until Value() or SetFlags() is
1146  // called.
1147  data_flags_ = 0;
1148  }
1149  }
1150  }
1151 
1153  if (cache_data_.ref_count) --(*cache_data_.ref_count);
1154  if (local_data_.ref_count) --(*local_data_.ref_count);
1155  }
1156 
1157  void ExpandAndCache() const {
1158  // TODO(allauzen): revisit this.
1159  // fst_.GetImpl()->Expand(s_, tuple_, local_data_);
1160  // (fst_.GetImpl())->CacheImpl<A>*>::InitArcIterator(s_,
1161  // &cache_data_);
1162  //
1163  fst_.InitArcIterator(s_, &cache_data_); // Expand and cache state.
1164  arcs_ = cache_data_.arcs; // arcs_ is a pointer to the cached arcs.
1165  data_flags_ = kArcValueFlags; // All the arc member values are valid.
1166  offset_ = 0; // No offset.
1167  }
1168 
1169  void Init() {
1170  if (flags_ & kArcNoCache) { // If caching is disabled
1171  // arcs_ is a pointer to the arcs in the underlying machine.
1172  arcs_ = local_data_.arcs;
1173  // Sets the arcs value flags that hold for arcs_.
1174  data_flags_ = kArcWeightValue;
1175  if (!fst_.GetMutableImpl()->EpsilonOnCallInput()) {
1176  data_flags_ |= kArcILabelValue;
1177  }
1178  // Sets the offset between the underlying arc positions and the positions
1179  // in the arc iterator.
1180  offset_ = num_arcs_ - local_data_.narcs;
1181  } else {
1182  ExpandAndCache();
1183  }
1184  }
1185 
1186  bool Done() const { return pos_ >= num_arcs_; }
1187 
1188  const Arc &Value() const {
1189  // If data_flags_ is 0, non-caching was not requested.
1190  if (!data_flags_) {
1191  // TODO(allauzen): Revisit this.
1192  if (flags_ & kArcNoCache) {
1193  // Should never happen.
1194  FSTERROR() << "ReplaceFst: Inconsistent arc iterator flags";
1195  }
1196  ExpandAndCache();
1197  }
1198  if (pos_ - offset_ >= 0) { // The requested arc is not the final arc.
1199  const auto &arc = arcs_[pos_ - offset_];
1200  if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) {
1201  // If the value flags match the recquired value flags then returns the
1202  // arc.
1203  return arc;
1204  } else {
1205  // Otherwise, compute the corresponding arc on-the-fly.
1206  fst_.GetMutableImpl()->ComputeArc(tuple_, arc, &arc_,
1207  flags_ & kArcValueFlags);
1208  return arc_;
1209  }
1210  } else { // The requested arc is the final arc.
1211  if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) {
1212  // If the arc value flags that hold for the final arc do not match the
1213  // requested value flags, then
1214  // final_arc_ needs to be updated.
1215  fst_.GetMutableImpl()->ComputeFinalArc(tuple_, &final_arc_,
1216  flags_ & kArcValueFlags);
1217  final_flags_ = flags_ & kArcValueFlags;
1218  }
1219  return final_arc_;
1220  }
1221  }
1222 
1223  void Next() { ++pos_; }
1224 
1225  size_t Position() const { return pos_; }
1226 
1227  void Reset() { pos_ = 0; }
1228 
1229  void Seek(size_t pos) { pos_ = pos; }
1230 
1231  uint8_t Flags() const { return flags_; }
1232 
1233  void SetFlags(uint8_t flags, uint8_t mask) {
1234  // Updates the flags taking into account what flags are supported
1235  // by the FST.
1236  flags_ &= ~mask;
1237  flags_ |= (flags & fst_.GetImpl()->ArcIteratorFlags());
1238  // If non-caching is not requested (and caching has not already been
1239  // performed), then flush data_flags_ to request caching during the next
1240  // call to Value().
1241  if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) {
1242  if (!fst_.GetImpl()->HasArcs(s_)) data_flags_ = 0;
1243  }
1244  // If data_flags_ has been flushed but non-caching is requested before
1245  // calling Value(), then set up the iterator for non-caching.
1246  if ((flags & kArcNoCache) && (!data_flags_)) Init();
1247  }
1248 
1249  private:
1250  const ReplaceFst<Arc, StateTable, CacheStore> &fst_; // Reference to the FST.
1251  StateId s_; // State in the FST.
1252  mutable StateTuple tuple_; // Tuple corresponding to state_.
1253 
1254  ssize_t pos_; // Current position.
1255  mutable ssize_t offset_; // Offset between position in iterator and in arcs_.
1256  ssize_t num_arcs_; // Number of arcs at state_.
1257  uint8_t flags_; // Behavorial flags for the arc iterator
1258  mutable Arc arc_; // Memory to temporarily store computed arcs.
1259 
1260  mutable ArcIteratorData<Arc> cache_data_; // Arc iterator data in cache.
1261  mutable ArcIteratorData<Arc> local_data_; // Arc iterator data in local FST.
1262 
1263  mutable const Arc *arcs_; // Array of arcs.
1264  mutable uint8_t data_flags_; // Arc value flags valid for data in arcs_.
1265  mutable Arc final_arc_; // Final arc (when required).
1266  mutable uint8_t final_flags_; // Arc value flags valid for final_arc_.
1267 
1268  ArcIterator(const ArcIterator &) = delete;
1269  ArcIterator &operator=(const ArcIterator &) = delete;
1270 };
1271 
1272 template <class Arc, class StateTable, class CacheStore>
1273 class ReplaceFstMatcher : public MatcherBase<Arc> {
1274  public:
1275  using Label = typename Arc::Label;
1276  using StateId = typename Arc::StateId;
1277  using Weight = typename Arc::Weight;
1278 
1281 
1282  using StateTuple = typename StateTable::StateTuple;
1283 
1284  // This makes a copy of the FST.
1286  MatchType match_type)
1287  : owned_fst_(fst.Copy()),
1288  fst_(*owned_fst_),
1289  impl_(fst_.GetMutableImpl()),
1290  s_(fst::kNoStateId),
1291  match_type_(match_type),
1292  current_loop_(false),
1293  final_arc_(false),
1294  loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
1295  if (match_type_ == fst::MATCH_OUTPUT) {
1296  std::swap(loop_.ilabel, loop_.olabel);
1297  }
1298  InitMatchers();
1299  }
1300 
1301  // This doesn't copy the FST.
1303  MatchType match_type)
1304  : fst_(*fst),
1305  impl_(fst_.GetMutableImpl()),
1306  s_(fst::kNoStateId),
1307  match_type_(match_type),
1308  current_loop_(false),
1309  final_arc_(false),
1310  loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
1311  if (match_type_ == fst::MATCH_OUTPUT) {
1312  std::swap(loop_.ilabel, loop_.olabel);
1313  }
1314  InitMatchers();
1315  }
1316 
1317  // This makes a copy of the FST.
1318  ReplaceFstMatcher(const ReplaceFstMatcher &matcher, bool safe = false)
1319  : owned_fst_(matcher.fst_.Copy(safe)),
1320  fst_(*owned_fst_),
1321  impl_(fst_.GetMutableImpl()),
1322  s_(fst::kNoStateId),
1323  match_type_(matcher.match_type_),
1324  current_loop_(false),
1325  final_arc_(false),
1326  loop_(fst::kNoLabel, 0, Weight::One(), fst::kNoStateId) {
1327  if (match_type_ == fst::MATCH_OUTPUT) {
1328  std::swap(loop_.ilabel, loop_.olabel);
1329  }
1330  InitMatchers();
1331  }
1332 
1333  // Creates a local matcher for each component FST in the RTN. LocalMatcher is
1334  // a multi-epsilon wrapper matcher. MultiEpsilonMatcher is used to match each
1335  // non-terminal arc, since these non-terminal
1336  // turn into epsilons on recursion.
1337  void InitMatchers() {
1338  const auto &fst_array = impl_->fst_array_;
1339  matcher_.resize(fst_array.size());
1340  for (Label i = 0; i < fst_array.size(); ++i) {
1341  if (fst_array[i]) {
1342  matcher_[i] = std::make_unique<LocalMatcher>(
1343  *fst_array[i], match_type_, kMultiEpsList);
1344  auto it = impl_->nonterminal_set_.begin();
1345  for (; it != impl_->nonterminal_set_.end(); ++it) {
1346  matcher_[i]->AddMultiEpsLabel(*it);
1347  }
1348  }
1349  }
1350  }
1351 
1352  ReplaceFstMatcher *Copy(bool safe = false) const override {
1353  return new ReplaceFstMatcher(*this, safe);
1354  }
1355 
1356  MatchType Type(bool test) const override {
1357  if (match_type_ == MATCH_NONE) return match_type_;
1358  const auto true_prop =
1359  match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted;
1360  const auto false_prop =
1361  match_type_ == MATCH_INPUT ? kNotILabelSorted : kNotOLabelSorted;
1362  const auto props = fst_.Properties(true_prop | false_prop, test);
1363  if (props & true_prop) {
1364  return match_type_;
1365  } else if (props & false_prop) {
1366  return MATCH_NONE;
1367  } else {
1368  return MATCH_UNKNOWN;
1369  }
1370  }
1371 
1372  const Fst<Arc> &GetFst() const override { return fst_; }
1373 
1374  uint64_t Properties(uint64_t props) const override { return props; }
1375 
1376  // Sets the state from which our matching happens.
1377  void SetState(StateId s) final {
1378  if (s_ == s) return;
1379  s_ = s;
1380  tuple_ = impl_->GetStateTable()->Tuple(s_);
1381  if (tuple_.fst_state == kNoStateId) {
1382  done_ = true;
1383  return;
1384  }
1385  // Gets current matcher, used for non-epsilon matching.
1386  current_matcher_ = matcher_[tuple_.fst_id].get();
1387  current_matcher_->SetState(tuple_.fst_state);
1388  loop_.nextstate = s_;
1389  final_arc_ = false;
1390  }
1391 
1392  // Searches for label from previous set state. If label == 0, first
1393  // hallucinate an epsilon loop; otherwise use the underlying matcher to
1394  // search for the label or epsilons. Note since the ReplaceFst recursion
1395  // on non-terminal arcs causes epsilon transitions to be created we use
1396  // MultiEpsilonMatcher to search for possible matches of non-terminals. If the
1397  // component FST
1398  // reaches a final state we also need to add the exiting final arc.
1399  bool Find(Label label) final {
1400  bool found = false;
1401  label_ = label;
1402  if (label_ == 0 || label_ == kNoLabel) {
1403  // Computes loop directly, avoiding Replace::ComputeArc.
1404  if (label_ == 0) {
1405  current_loop_ = true;
1406  found = true;
1407  }
1408  // Searches for matching multi-epsilons.
1409  final_arc_ = impl_->ComputeFinalArc(tuple_, nullptr);
1410  found = current_matcher_->Find(kNoLabel) || final_arc_ || found;
1411  } else {
1412  // Searches on a sub machine directly using sub machine matcher.
1413  found = current_matcher_->Find(label_);
1414  }
1415  return found;
1416  }
1417 
1418  bool Done() const final {
1419  return !current_loop_ && !final_arc_ && current_matcher_->Done();
1420  }
1421 
1422  const Arc &Value() const final {
1423  if (current_loop_) return loop_;
1424  if (final_arc_) {
1425  impl_->ComputeFinalArc(tuple_, &arc_);
1426  return arc_;
1427  }
1428  const auto &component_arc = current_matcher_->Value();
1429  impl_->ComputeArc(tuple_, component_arc, &arc_);
1430  return arc_;
1431  }
1432 
1433  void Next() final {
1434  if (current_loop_) {
1435  current_loop_ = false;
1436  return;
1437  }
1438  if (final_arc_) {
1439  final_arc_ = false;
1440  return;
1441  }
1442  current_matcher_->Next();
1443  }
1444 
1445  ssize_t Priority(StateId s) final { return fst_.NumArcs(s); }
1446 
1447  private:
1448  std::unique_ptr<const ReplaceFst<Arc, StateTable, CacheStore>> owned_fst_;
1451  LocalMatcher *current_matcher_;
1452  std::vector<std::unique_ptr<LocalMatcher>> matcher_;
1453  StateId s_; // Current state.
1454  Label label_; // Current label.
1455  MatchType match_type_; // Supplied by caller.
1456  mutable bool done_;
1457  mutable bool current_loop_; // Current arc is the implicit loop.
1458  mutable bool final_arc_; // Current arc for exiting recursion.
1459  mutable StateTuple tuple_; // Tuple corresponding to state_.
1460  mutable Arc arc_;
1461  Arc loop_;
1462 
1463  ReplaceFstMatcher &operator=(const ReplaceFstMatcher &) = delete;
1464 };
1465 
1466 template <class Arc, class StateTable, class CacheStore>
1468  StateIteratorData<Arc> *data) const {
1469  data->base =
1470  std::make_unique<StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>>(
1471  *this);
1472 }
1473 
1475 
1476 // Recursively replaces arcs in the root FSTs with other FSTs.
1477 // This version writes the result of replacement to an output MutableFst.
1478 //
1479 // Replace supports replacement of arcs in one Fst with another FST. This
1480 // replacement is recursive. Replace takes an array of FST(s). One FST
1481 // represents the root (or topology) machine. The root FST refers to other FSTs
1482 // by recursively replacing arcs labeled as non-terminals with the matching
1483 // non-terminal FST. Currently Replace uses the output symbols of the arcs to
1484 // determine whether the arc is a non-terminal arc or not. A non-terminal can be
1485 // any label that is not a non-zero terminal label in the output alphabet.
1486 //
1487 // Note that input argument is a vector of pairs. These correspond to the tuple
1488 // of non-terminal Label and corresponding FST.
1489 template <class Arc>
1490 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1491  &ifst_array,
1492  MutableFst<Arc> *ofst,
1494  opts.gc = true;
1495  opts.gc_limit = 0; // Caches only the last state for fastest copy.
1496  *ofst = ReplaceFst<Arc>(ifst_array, opts);
1497 }
1498 
1499 template <class Arc>
1500 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1501  &ifst_array,
1502  MutableFst<Arc> *ofst, const ReplaceUtilOptions &opts) {
1503  Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(opts));
1504 }
1505 
1506 // For backwards compatibility.
1507 template <class Arc>
1508 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1509  &ifst_array,
1510  MutableFst<Arc> *ofst, typename Arc::Label root,
1511  bool epsilon_on_replace) {
1512  Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root, epsilon_on_replace));
1513 }
1514 
1515 template <class Arc>
1516 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1517  &ifst_array,
1518  MutableFst<Arc> *ofst, typename Arc::Label root) {
1519  Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root));
1520 }
1521 
1522 } // namespace fst
1523 
1524 #endif // FST_REPLACE_H_
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:114
ReplaceFstOptions(const CacheOptions &opts, Label root=kNoLabel)
Definition: replace.h:370
ReplaceFingerprint(const std::vector< uint64_t > *size_array)
Definition: replace.h:114
uint64_t Properties(uint64_t mask) const override
Definition: replace.h:744
typename CacheStore::State State
Definition: replace.h:497
void SetState(StateId s) final
Definition: replace.h:1377
constexpr uint8_t kArcValueFlags
Definition: fst.h:453
ArcIterator(const ReplaceFst< Arc, StateTable, CacheStore > &fst, StateId s)
Definition: replace.h:1093
const StateTuple & Tuple(StateId id) const
Definition: replace.h:286
constexpr int kNoLabel
Definition: fst.h:201
std::vector< PrefixTuple > prefix_
Definition: replace.h:195
MatcherBase< Arc > * InitMatcher(MatchType match_type) const override
Definition: replace.h:1028
void Expand(StateId s, const StateTuple &tuple, const ArcIteratorData< Arc > &data)
Definition: replace.h:783
constexpr uint8_t kArcNoCache
Definition: fst.h:452
const StateTable & GetStateTable() const
Definition: replace.h:1042
const StackPrefix & GetStackPrefix(PrefixId id) const
Definition: replace.h:334
void Next() final
Definition: replace.h:1433
const Fst< Arc > * GetFst(Label fst_id) const
Definition: replace.h:896
ReplaceStateTuple(PrefixId prefix_id=-1, StateId fst_id=kNoStateId, StateId fst_state=kNoStateId)
Definition: replace.h:83
typename StateTable::PrefixId PrefixId
Definition: replace.h:499
PrefixId FindPrefixId(const StackPrefix &prefix)
Definition: replace.h:330
ReplaceFstOptions(const ReplaceUtilOptions &opts)
Definition: replace.h:398
PrefixId prefix_id
Definition: replace.h:87
size_t NumOutputEpsilons(StateId s)
Definition: replace.h:705
bool EpsilonOnOutput(ReplaceLabelType label_type)
Definition: replace.h:426
ReplaceFstMatcher(const ReplaceFst< Arc, StateTable, CacheStore > *fst, MatchType match_type)
Definition: replace.h:1302
typename CacheStore::State State
Definition: replace.h:994
ReplaceLabelType
Definition: replace-util.h:43
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:97
constexpr uint32_t kMultiEpsList
Definition: matcher.h:1227
const SymbolTable * OutputSymbols() const
Definition: fst.h:762
MatchType
Definition: fst.h:193
uint8_t ArcIteratorFlags() const
Definition: replace.h:888
uint64_t Properties(uint64_t props) const override
Definition: replace.h:1374
constexpr uint64_t kError
Definition: properties.h:51
uint64_t Properties() const override
Definition: replace.h:741
ReplaceFst(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_array, const ReplaceFstOptions< Arc, StateTable, CacheStore > &opts)
Definition: replace.h:1009
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
Definition: replace.h:1024
bool ReplaceTransducer(ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, Label call_output_label)
Definition: replace.h:433
uint64_t operator()(const ReplaceStateTuple< StateId, PrefixId > &tuple) const
Definition: replace.h:130
void Push(StateId fst_id, StateId nextstate)
Definition: replace.h:184
VectorHashReplaceStateTable(const VectorHashReplaceStateTable< Arc, PrefixId > &table)
Definition: replace.h:270
SetType
Definition: set-weight.h:52
MatchType Type(bool test) const override
Definition: replace.h:1356
typename Arc::Weight Weight
Definition: replace.h:1277
typename ReplaceFst< Arc, StateTable, CacheStore >::Arc Arc
Definition: cache.h:1149
constexpr uint8_t kArcILabelValue
Definition: fst.h:444
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:103
constexpr uint64_t kNotOLabelSorted
Definition: properties.h:100
bool EpsilonOnInput(ReplaceLabelType label_type)
Definition: replace.h:420
void Replace(const std::vector< std::pair< typename Arc::Label, const Fst< Arc > * >> &ifst_array, MutableFst< Arc > *ofst, std::vector< std::pair< typename Arc::Label, typename Arc::Label >> *parens, const PdtReplaceOptions< Arc > &opts)
Definition: replace.h:793
bool CyclicDependencies() const
Definition: replace.h:1040
VectorHashReplaceStateTable(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_list, Label root)
Definition: replace.h:249
ReplaceFstOptions(const CacheImplOptions< CacheStore > &opts, Label root=kNoLabel)
Definition: replace.h:366
bool Find(Label label) final
Definition: replace.h:1399
const PrefixTuple & Top() const
Definition: replace.h:190
const Fst< Arc > & GetFst(Label nonterminal) const
Definition: replace.h:1046
constexpr int kNoStateId
Definition: fst.h:202
const Arc & Value() const
Definition: fst.h:537
DefaultReplaceStateTable(const DefaultReplaceStateTable< Arc, PrefixId > &table)
Definition: replace.h:327
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
Definition: replace.h:757
std::vector< std::pair< typename Arc::Label, const Fst< Arc > * >> FstList
Definition: replace.h:417
bool ComputeArc(const StateTuple &tuple, const Arc &arc, Arc *arcp, uint8_t flags=kArcValueFlags)
Definition: replace.h:833
#define FSTERROR()
Definition: util.h:53
bool ComputeFinalArc(const StateTuple &tuple, Arc *arcp, uint8_t flags=kArcValueFlags)
Definition: replace.h:802
ReplaceFstImpl(const ReplaceFstImpl &impl)
Definition: replace.h:577
size_t operator()(const ReplaceStateTuple< S, P > &t) const
Definition: replace.h:139
constexpr uint64_t kNotILabelSorted
Definition: properties.h:95
typename Arc::Weight Weight
Definition: replace.h:495
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:108
void InitStateIterator(StateIteratorData< Arc > *data) const override
Definition: replace.h:1467
constexpr uint64_t kOLabelSorted
Definition: properties.h:98
constexpr uint64_t kCopyProperties
Definition: properties.h:162
typename Arc::StateId StateId
Definition: replace.h:234
uint64_t ReplaceProperties(const std::vector< uint64_t > &inprops, size_t root, bool epsilon_on_call, bool epsilon_on_return, bool out_epsilon_on_call, bool out_epsilon_on_return, bool replace_transducer, bool no_empty_fst, bool all_ilabel_sorted, bool all_olabel_sorted, bool all_negative_or_dense)
Definition: properties.cc:252
const Fst< Arc > & GetFst() const override
Definition: replace.h:1372
Weight Final(StateId s)
Definition: replace.h:627
ReplaceFstOptions(int64_t root, bool epsilon_replace_arc)
Definition: replace.h:405
#define VLOG(level)
Definition: log.h:50
size_t NumArcs(StateId s)
Definition: replace.h:639
typename Arc::Label Label
Definition: replace.h:233
bool Done() const
Definition: fst.h:533
ReplaceFst * Copy(bool safe=false) const override
Definition: replace.h:1018
std::unique_ptr< StateIteratorBase< Arc > > base
Definition: fst.h:382
PrefixId FindPrefixId(const StackPrefix &prefix)
Definition: replace.h:288
const StackPrefix & GetStackPrefix(PrefixId id) const
Definition: replace.h:292
ReplaceFst(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_array, Label root)
Definition: replace.h:1004
ReplaceFstOptions(Label root, ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, Label return_label)
Definition: replace.h:381
uint64_t operator()(const ReplaceStateTuple< StateId, PrefixId > &tuple) const
Definition: replace.h:117
typename Arc::Label Label
Definition: replace.h:348
constexpr uint8_t kArcWeightValue
Definition: fst.h:448
constexpr uint64_t kILabelSorted
Definition: properties.h:93
ReplaceFstMatcher * Copy(bool safe=false) const override
Definition: replace.h:1352
void Expand(StateId s)
Definition: replace.h:766
uint64_t ReplaceFstProperties(typename Arc::Label root_label, const FstList< Arc > &fst_list, ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, typename Arc::Label call_output_label, bool *sorted_and_non_empty)
Definition: replace.h:445
constexpr uint64_t kFstProperties
Definition: properties.h:325
ReplaceFst(const ReplaceFst &fst, bool safe=false)
Definition: replace.h:1014
const Arc & Value() const final
Definition: replace.h:1422
Label GetFstId(Label nonterminal) const
Definition: replace.h:900
typename internal::ReplaceFstImpl< Arc, StateTable, CacheStore >::Arc Arc
Definition: fst.h:211
size_t Depth() const
Definition: replace.h:192
const SymbolTable * InputSymbols() const
Definition: fst.h:760
Arc::StateId CountStates(const Fst< Arc > &fst)
Definition: expanded-fst.h:169
constexpr uint8_t kArcNextStateValue
Definition: fst.h:450
bool CyclicDependencies() const
Definition: replace-util.h:133
typename Arc::Label Label
Definition: replace.h:1275
size_t operator()(const ReplaceStackPrefix< Label, StateId > &prefix) const
Definition: replace.h:216
size_t NumInputEpsilons(StateId s)
Definition: replace.h:669
ssize_t Priority(StateId s) final
Definition: replace.h:1445
ReplaceFstOptions(Label root, ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, Label call_output_label, Label return_label)
Definition: replace.h:389
ReplaceFstMatcher(const ReplaceFstMatcher &matcher, bool safe=false)
Definition: replace.h:1318
ReplaceFstOptions(Label root)
Definition: replace.h:379
bool operator==(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:50
bool CyclicDependencies() const
Definition: replace.h:602
typename StateTable::StateTuple StateTuple
Definition: replace.h:1282
ReplaceFstImpl(const FstList< Arc > &fst_list, const ReplaceFstOptions< Arc, StateTable, CacheStore > &opts)
Definition: replace.h:522
ReplaceFstMatcher(const ReplaceFst< Arc, StateTable, CacheStore > &fst, MatchType match_type)
Definition: replace.h:1285
typename Arc::StateId StateId
Definition: replace.h:310
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
Definition: expand.h:316
ReplaceStackPrefix(const ReplaceStackPrefix &other)
Definition: replace.h:181
typename Arc::StateId StateId
Definition: replace.h:494
bool Done() const final
Definition: replace.h:1418
StateId FindState(const StateTuple &tuple)
Definition: replace.h:282
bool IsNonTerminal(Label label) const
Definition: replace.h:655
typename Arc::StateId StateId
Definition: replace.h:1276
DefaultReplaceStateTable(const std::vector< std::pair< Label, const Fst< Arc > * >> &, Label)
Definition: replace.h:324
StateIterator(const ReplaceFst< Arc, StateTable, CacheStore > &fst)
Definition: replace.h:1062
typename Arc::Label Label
Definition: replace.h:493
PrefixTuple(Label fst_id=kNoLabel, StateId nextstate=kNoStateId)
Definition: replace.h:172
bool operator()(const ReplaceStateTuple< StateId, PrefixId > &tuple) const
Definition: replace.h:105
StateTable * GetStateTable() const
Definition: replace.h:894
std::unordered_map< Label, Label > NonTerminalHash
Definition: replace.h:502
void Next()
Definition: fst.h:541