FST  openfst-1.8.3
OpenFst Library
replace.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Functions and classes for the recursive replacement of FSTs.
19 
20 #ifndef FST_REPLACE_H_
21 #define FST_REPLACE_H_
22 
23 #include <sys/types.h>
24 
25 #include <cstddef>
26 #include <cstdint>
27 #include <memory>
28 #include <set>
29 #include <string>
30 #include <utility>
31 #include <vector>
32 
33 #include <fst/log.h>
34 #include <fst/arc.h>
35 #include <fst/bi-table.h>
36 #include <fst/cache.h>
37 #include <fst/expanded-fst.h>
38 #include <fst/float-weight.h>
39 #include <fst/fst-decl.h> // For optional argument declarations.
40 #include <fst/fst.h>
41 #include <fst/impl-to-fst.h>
42 #include <fst/matcher.h>
43 #include <fst/mutable-fst.h>
44 #include <fst/properties.h>
45 #include <fst/replace-util.h>
46 #include <fst/state-table.h>
47 #include <fst/symbol-table.h>
48 #include <fst/util.h>
49 #include <unordered_map>
50 
51 namespace fst {
52 
53 // Replace state tables have the form:
54 //
55 // template <class Arc, class P>
56 // class ReplaceStateTable {
57 // public:
58 // using Label = typename Arc::Label;
59 // using StateId = typename Arc::StateId;
60 //
61 // using PrefixId = P;
62 // using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
63 // using StackPrefix = ReplaceStackPrefix<Label, StateId>;
64 //
65 // // Required constructor.
66 // ReplaceStateTable(
67 // const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_list,
68 // Label root);
69 //
70 // // Required copy constructor that does not copy state.
71 // ReplaceStateTable(const ReplaceStateTable<Arc, PrefixId> &table);
72 //
73 // // Looks up state ID by tuple, adding it if it doesn't exist.
74 // StateId FindState(const StateTuple &tuple);
75 //
76 // // Looks up state tuple by ID.
77 // const StateTuple &Tuple(StateId id) const;
78 //
79 // Looks up prefix ID by stack prefix, adding it if it doesn't exist.
80 // PrefixId FindPrefixId(const StackPrefix &stack_prefix);
81 //
82 // // Looks up stack prefix by ID.
83 // const StackPrefix &GetStackPrefix(PrefixId id) const;
84 // };
85 
86 // Tuple that uniquely defines a state in replace.
87 template <class S, class P>
89  using StateId = S;
90  using PrefixId = P;
91 
95 
96  template <typename H>
97  friend H AbslHashValue(H h, const ReplaceStateTuple &t) {
98  return H::combine(std::move(h), t.prefix_id, t.fst_id, t.fst_state);
99  }
100 
101  PrefixId prefix_id; // Index in prefix table.
102  StateId fst_id; // Current FST being walked.
103  StateId fst_state; // Current state in FST being walked (not to be
104  // confused with the thse StateId of the combined FST).
105 };
106 
107 // Equality of replace state tuples.
108 template <class StateId, class PrefixId>
111  return x.prefix_id == y.prefix_id && x.fst_id == y.fst_id &&
112  x.fst_state == y.fst_state;
113 }
114 
115 // Functor returning true for tuples corresponding to states in the root FST.
116 template <class StateId, class PrefixId>
118  public:
120  return tuple.prefix_id == 0;
121  }
122 };
123 
124 // Functor for fingerprinting replace state tuples.
125 template <class StateId, class PrefixId>
127  public:
128  explicit ReplaceFingerprint(const std::vector<uint64_t> *size_array)
129  : size_array_(size_array) {}
130 
131  uint64_t operator()(const ReplaceStateTuple<StateId, PrefixId> &tuple) const {
132  return tuple.prefix_id * size_array_->back() +
133  size_array_->at(tuple.fst_id - 1) + tuple.fst_state;
134  }
135 
136  private:
137  const std::vector<uint64_t> *size_array_;
138 };
139 
140 // Useful when the fst_state uniquely define the tuple.
141 template <class StateId, class PrefixId>
143  public:
144  uint64_t operator()(const ReplaceStateTuple<StateId, PrefixId> &tuple) const {
145  return tuple.fst_state;
146  }
147 };
148 
149 // A generic hash function for replace state tuples.
150 template <typename S, typename P>
151 class ReplaceHash {
152  public:
153  size_t operator()(const ReplaceStateTuple<S, P> &t) const {
154  // We want three prime numbers that are all reasonably large and whose
155  // differences are far from each other. (E.g., we want prime1-prime0 to be
156  // far from prime2-prime1). It would be safer still to use large prime
157  // numbers (i.e., prime numbers that use all 64 bits on a 64-bit machine and
158  // all 32-bits on a 32-bit machine), so that all 64-bits (respectively
159  // 32-bits) of the resulting hash would look random. However, these
160  // modest-sized prime numbers are good enough for hash tables (such as
161  // std::unordered_set and std::unordered_set) that use the low-order bits
162  // of the hash.
163  //
164  // It is important that all three components are multiplied by a prime
165  // number. E.g., don't compute
166  // t.prefix_id + t.fst_id * prime1 + t.fst_state * prime2
167  // which is just the identity on t.prefix_id. Using the identity will
168  // result in long probe sequences in open-addressed hash tables (such as
169  // std::unordered_map).
170  static constexpr size_t prime0 = 7853;
171  static constexpr size_t prime1 = 9001;
172  static constexpr size_t prime2 = 100003;
173  return t.prefix_id * prime0 + t.fst_id * prime1 + t.fst_state * prime2;
174  }
175 };
176 
177 // Container for stack prefix.
178 template <class Label, class StateId>
180  public:
181  struct PrefixTuple {
183  : fst_id(fst_id), nextstate(nextstate) {}
184 
185  Label fst_id;
187  };
188 
189  ReplaceStackPrefix() = default;
190 
192  : prefix_(other.prefix_) {}
193 
194  void Push(StateId fst_id, StateId nextstate) {
195  prefix_.push_back(PrefixTuple(fst_id, nextstate));
196  }
197 
198  void Pop() { prefix_.pop_back(); }
199 
200  const PrefixTuple &Top() const { return prefix_[prefix_.size() - 1]; }
201 
202  size_t Depth() const { return prefix_.size(); }
203 
204  public:
205  std::vector<PrefixTuple> prefix_;
206 };
207 
208 // Equality stack prefix classes.
209 template <class Label, class StateId>
212  if (x.prefix_.size() != y.prefix_.size()) return false;
213  for (size_t i = 0; i < x.prefix_.size(); ++i) {
214  if (x.prefix_[i].fst_id != y.prefix_[i].fst_id ||
215  x.prefix_[i].nextstate != y.prefix_[i].nextstate) {
216  return false;
217  }
218  }
219  return true;
220 }
221 
222 // Hash function for stack prefix to prefix id.
223 template <class Label, class StateId>
225  public:
226  size_t operator()(const ReplaceStackPrefix<Label, StateId> &prefix) const {
227  size_t sum = 0;
228  for (const auto &pair : prefix.prefix_) {
229  static constexpr size_t prime = 7863;
230  sum += pair.fst_id + pair.nextstate * prime;
231  }
232  return sum;
233  }
234 };
235 
236 // Replace state tables.
237 
238 // A two-level state table for replace. Warning: calls CountStates to compute
239 // the number of states of each component FST.
240 template <class Arc, class P = ssize_t>
242  public:
243  using Label = typename Arc::Label;
244  using StateId = typename Arc::StateId;
245 
246  using PrefixId = P;
247 
249  using StateTable =
255  using StackPrefixTable =
258 
260  const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_list,
261  Label root)
262  : root_size_(0) {
263  size_array_.push_back(0);
264  for (const auto &[label, fst] : fst_list) {
265  if (label == root) {
266  root_size_ = CountStates(*fst);
267  size_array_.push_back(size_array_.back());
268  } else {
269  size_array_.push_back(size_array_.back() + CountStates(*fst));
270  }
271  }
272  state_table_ = std::make_unique<StateTable>(
273  ReplaceRootSelector<StateId, PrefixId>(),
274  ReplaceFstStateFingerprint<StateId, PrefixId>(),
275  ReplaceFingerprint<StateId, PrefixId>(&size_array_), root_size_,
276  root_size_ + size_array_.back());
277  }
278 
281  : root_size_(table.root_size_),
282  size_array_(table.size_array_),
283  prefix_table_(table.prefix_table_) {
284  state_table_ = std::make_unique<StateTable>(
285  ReplaceRootSelector<StateId, PrefixId>(),
286  ReplaceFstStateFingerprint<StateId, PrefixId>(),
287  ReplaceFingerprint<StateId, PrefixId>(&size_array_), root_size_,
288  root_size_ + size_array_.back());
289  }
290 
291  StateId FindState(const StateTuple &tuple) {
292  return state_table_->FindState(tuple);
293  }
294 
295  const StateTuple &Tuple(StateId id) const { return state_table_->Tuple(id); }
296 
297  PrefixId FindPrefixId(const StackPrefix &prefix) {
298  return prefix_table_.FindId(prefix);
299  }
300 
301  const StackPrefix &GetStackPrefix(PrefixId id) const {
302  return prefix_table_.FindEntry(id);
303  }
304 
305  private:
306  StateId root_size_;
307  std::vector<uint64_t> size_array_;
308  std::unique_ptr<StateTable> state_table_;
309  StackPrefixTable prefix_table_;
310 };
311 
312 // Default replace state table.
313 template <class Arc, class P /* = size_t */>
315  : public CompactHashStateTable<ReplaceStateTuple<typename Arc::StateId, P>,
316  ReplaceHash<typename Arc::StateId, P>> {
317  public:
318  using Label = typename Arc::Label;
319  using StateId = typename Arc::StateId;
320 
321  using PrefixId = P;
323  using StateTable =
326  using StackPrefixTable =
329 
330  using StateTable::FindState;
331  using StateTable::Tuple;
332 
334  const std::vector<std::pair<Label, const Fst<Arc> *>> &, Label) {}
335 
337  : StateTable(), prefix_table_(table.prefix_table_) {}
338 
339  PrefixId FindPrefixId(const StackPrefix &prefix) {
340  return prefix_table_.FindId(prefix);
341  }
342 
343  const StackPrefix &GetStackPrefix(PrefixId id) const {
344  return prefix_table_.FindEntry(id);
345  }
346 
347  private:
348  StackPrefixTable prefix_table_;
349 };
350 
351 // By default ReplaceFst will copy the input label of the replace arc.
352 // The call_label_type and return_label_type options specify how to manage
353 // the labels of the call arc and the return arc of the replace FST
354 template <class Arc, class StateTable = DefaultReplaceStateTable<Arc>,
355  class CacheStore = DefaultCacheStore<Arc>>
356 struct ReplaceFstOptions : CacheImplOptions<CacheStore> {
357  using Label = typename Arc::Label;
358 
359  // Index of root rule for expansion.
361  // How to label call arc.
363  // How to label return arc.
365  // Specifies output label to put on call arc; if kNoLabel, use existing label
366  // on call arc. Otherwise, use this field as the output label.
367  Label call_output_label = kNoLabel;
368  // Specifies label to put on return arc.
369  Label return_label = 0;
370  // Take ownership of input FSTs?
371  bool take_ownership = false;
372  // Pointer to optional pre-constructed state table.
373  StateTable *state_table = nullptr;
374 
376  Label root = kNoLabel)
377  : CacheImplOptions<CacheStore>(opts), root(root) {}
378 
379  explicit ReplaceFstOptions(const CacheOptions &opts, Label root = kNoLabel)
380  : CacheImplOptions<CacheStore>(opts), root(root) {}
381 
382  // FIXME(kbg): There are too many constructors here. Come up with a consistent
383  // position for call_output_label (probably the very end) so that it is
384  // possible to express all the remaining constructors with a single
385  // default-argument constructor. Also move clients off of the "backwards
386  // compatibility" constructor, for good.
387 
388  explicit ReplaceFstOptions(Label root) : root(root) {}
389 
390  explicit ReplaceFstOptions(Label root, ReplaceLabelType call_label_type,
391  ReplaceLabelType return_label_type,
392  Label return_label)
393  : root(root),
394  call_label_type(call_label_type),
395  return_label_type(return_label_type),
396  return_label(return_label) {}
397 
398  explicit ReplaceFstOptions(Label root, ReplaceLabelType call_label_type,
399  ReplaceLabelType return_label_type,
400  Label call_output_label, Label return_label)
401  : root(root),
402  call_label_type(call_label_type),
403  return_label_type(return_label_type),
404  call_output_label(call_output_label),
405  return_label(return_label) {}
406 
407  explicit ReplaceFstOptions(const ReplaceUtilOptions &opts)
408  : ReplaceFstOptions(opts.root, opts.call_label_type,
409  opts.return_label_type, opts.return_label) {}
410 
412 
413  // For backwards compatibility.
414  ReplaceFstOptions(int64_t root, bool epsilon_replace_arc)
415  : root(root),
416  call_label_type(epsilon_replace_arc ? REPLACE_LABEL_NEITHER
418  call_output_label(epsilon_replace_arc ? 0 : kNoLabel) {}
419 };
420 
421 // Forward declaration.
422 template <class Arc, class StateTable, class CacheStore>
424 
425 template <class Arc>
426 using FstList = std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>;
427 
428 // Returns true if label type on arc results in epsilon input label.
429 inline bool EpsilonOnInput(ReplaceLabelType label_type) {
430  return label_type == REPLACE_LABEL_NEITHER ||
431  label_type == REPLACE_LABEL_OUTPUT;
432 }
433 
434 // Returns true if label type on arc results in epsilon input label.
435 inline bool EpsilonOnOutput(ReplaceLabelType label_type) {
436  return label_type == REPLACE_LABEL_NEITHER ||
437  label_type == REPLACE_LABEL_INPUT;
438 }
439 
440 // Returns true if for either the call or return arc ilabel != olabel.
441 template <class Label>
442 bool ReplaceTransducer(ReplaceLabelType call_label_type,
443  ReplaceLabelType return_label_type,
444  Label call_output_label) {
445  return call_label_type == REPLACE_LABEL_INPUT ||
446  call_label_type == REPLACE_LABEL_OUTPUT ||
447  (call_label_type == REPLACE_LABEL_BOTH &&
448  call_output_label != kNoLabel) ||
449  return_label_type == REPLACE_LABEL_INPUT ||
450  return_label_type == REPLACE_LABEL_OUTPUT;
451 }
452 
453 template <class Arc>
454 uint64_t ReplaceFstProperties(typename Arc::Label root_label,
455  const FstList<Arc> &fst_list,
456  ReplaceLabelType call_label_type,
457  ReplaceLabelType return_label_type,
458  typename Arc::Label call_output_label,
459  bool *sorted_and_non_empty) {
460  using Label = typename Arc::Label;
461  std::vector<uint64_t> inprops;
462  bool all_ilabel_sorted = true;
463  bool all_olabel_sorted = true;
464  bool all_non_empty = true;
465  // All nonterminals are negative?
466  bool all_negative = true;
467  // All nonterminals are positive and form a dense range containing 1?
468  bool dense_range = true;
469  Label root_fst_idx = 0;
470  for (Label i = 0; i < fst_list.size(); ++i) {
471  const auto label = fst_list[i].first;
472  if (label >= 0) all_negative = false;
473  if (label > fst_list.size() || label <= 0) dense_range = false;
474  if (label == root_label) root_fst_idx = i;
475  const auto *fst = fst_list[i].second;
476  if (fst->Start() == kNoStateId) all_non_empty = false;
477  if (!fst->Properties(kILabelSorted, false)) all_ilabel_sorted = false;
478  if (!fst->Properties(kOLabelSorted, false)) all_olabel_sorted = false;
479  inprops.push_back(fst->Properties(kCopyProperties, false));
480  }
481  const auto props = ReplaceProperties(
482  inprops, root_fst_idx, EpsilonOnInput(call_label_type),
483  EpsilonOnInput(return_label_type), EpsilonOnOutput(call_label_type),
484  EpsilonOnOutput(return_label_type),
485  ReplaceTransducer(call_label_type, return_label_type, call_output_label),
486  all_non_empty, all_ilabel_sorted, all_olabel_sorted,
487  all_negative || dense_range);
488  const bool sorted = props & (kILabelSorted | kOLabelSorted);
489  *sorted_and_non_empty = all_non_empty && sorted;
490  return props;
491 }
492 
493 namespace internal {
494 
495 // The replace implementation class supports a dynamic expansion of a recursive
496 // transition network represented as label/FST pairs with dynamic replacable
497 // arcs.
498 template <class Arc, class StateTable, class CacheStore>
500  : public CacheBaseImpl<typename CacheStore::State, CacheStore> {
501  public:
502  using Label = typename Arc::Label;
503  using StateId = typename Arc::StateId;
504  using Weight = typename Arc::Weight;
505 
506  using State = typename CacheStore::State;
508  using PrefixId = typename StateTable::PrefixId;
511  using NonTerminalHash = std::unordered_map<Label, Label>;
512 
513  using FstImpl<Arc>::SetType;
520 
521  using CacheImpl::HasArcs;
522  using CacheImpl::HasFinal;
523  using CacheImpl::HasStart;
524  using CacheImpl::PushArc;
525  using CacheImpl::SetArcs;
526  using CacheImpl::SetFinal;
527  using CacheImpl::SetStart;
528 
529  friend class ReplaceFstMatcher<Arc, StateTable, CacheStore>;
530 
531  ReplaceFstImpl(const FstList<Arc> &fst_list,
533  : CacheImpl(opts),
534  call_label_type_(opts.call_label_type),
535  return_label_type_(opts.return_label_type),
536  call_output_label_(opts.call_output_label),
537  return_label_(opts.return_label),
538  state_table_(opts.state_table ? opts.state_table
539  : new StateTable(fst_list, opts.root)) {
540  SetType("replace");
541  // If the label is epsilon, then all replace label options are equivalent,
542  // so we set the label types to NEITHER for simplicity.
543  if (call_output_label_ == 0) call_label_type_ = REPLACE_LABEL_NEITHER;
544  if (return_label_ == 0) return_label_type_ = REPLACE_LABEL_NEITHER;
545  if (!fst_list.empty()) {
546  SetInputSymbols(fst_list[0].second->InputSymbols());
547  SetOutputSymbols(fst_list[0].second->OutputSymbols());
548  }
549  fst_array_.push_back(nullptr);
550  for (Label i = 0; i < fst_list.size(); ++i) {
551  const auto label = fst_list[i].first;
552  const auto *fst = fst_list[i].second;
553  nonterminal_hash_[label] = fst_array_.size();
554  nonterminal_set_.insert(label);
555  fst_array_.emplace_back(opts.take_ownership ? fst : fst->Copy());
556  if (i) {
557  if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) {
558  FSTERROR() << "ReplaceFstImpl: Input symbols of FST " << i
559  << " do not match input symbols of base FST (0th FST)";
560  SetProperties(kError, kError);
561  }
562  if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) {
563  FSTERROR() << "ReplaceFstImpl: Output symbols of FST " << i
564  << " do not match output symbols of base FST (0th FST)";
565  SetProperties(kError, kError);
566  }
567  }
568  }
569  const auto nonterminal = nonterminal_hash_[opts.root];
570  if ((nonterminal == 0) && (fst_array_.size() > 1)) {
571  FSTERROR() << "ReplaceFstImpl: No FST corresponding to root label "
572  << opts.root << " in the input tuple vector";
573  SetProperties(kError, kError);
574  }
575  root_ = (nonterminal > 0) ? nonterminal : 1;
576  bool all_non_empty_and_sorted = false;
577  SetProperties(ReplaceFstProperties(opts.root, fst_list, call_label_type_,
578  return_label_type_, call_output_label_,
579  &all_non_empty_and_sorted));
580  // Enables optional caching as long as sorted and all non-empty.
581  always_cache_ = !all_non_empty_and_sorted;
582  VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = "
583  << (always_cache_ ? "true" : "false");
584  }
585 
587  : CacheImpl(impl),
588  call_label_type_(impl.call_label_type_),
589  return_label_type_(impl.return_label_type_),
590  call_output_label_(impl.call_output_label_),
591  return_label_(impl.return_label_),
592  always_cache_(impl.always_cache_),
593  state_table_(new StateTable(*(impl.state_table_))),
594  nonterminal_set_(impl.nonterminal_set_),
595  nonterminal_hash_(impl.nonterminal_hash_),
596  root_(impl.root_) {
597  SetType("replace");
598  SetProperties(impl.Properties(), kCopyProperties);
599  SetInputSymbols(impl.InputSymbols());
600  SetOutputSymbols(impl.OutputSymbols());
601  fst_array_.reserve(impl.fst_array_.size());
602  fst_array_.emplace_back(nullptr);
603  for (Label i = 1; i < impl.fst_array_.size(); ++i) {
604  fst_array_.emplace_back(impl.fst_array_[i]->Copy(true));
605  }
606  }
607 
608  // Computes the dependency graph of the replace class and returns
609  // true if the dependencies are cyclic. Cyclic dependencies will result
610  // in an un-expandable FST.
611  bool CyclicDependencies() const {
612  const ReplaceUtilOptions opts(root_);
613  ReplaceUtil<Arc> replace_util(fst_array_, nonterminal_hash_, opts);
614  return replace_util.CyclicDependencies();
615  }
616 
618  if (!HasStart()) {
619  if (fst_array_.size() == 1) {
620  SetStart(kNoStateId);
621  return kNoStateId;
622  } else {
623  const auto fst_start = fst_array_[root_]->Start();
624  if (fst_start == kNoStateId) return kNoStateId;
625  const auto prefix = GetPrefixId(StackPrefix());
626  const auto start =
627  state_table_->FindState(StateTuple(prefix, root_, fst_start));
628  SetStart(start);
629  return start;
630  }
631  } else {
632  return CacheImpl::Start();
633  }
634  }
635 
637  if (HasFinal(s)) return CacheImpl::Final(s);
638  const auto &tuple = state_table_->Tuple(s);
639  auto weight = Weight::Zero();
640  if (tuple.prefix_id == 0) {
641  const auto fst_state = tuple.fst_state;
642  weight = fst_array_[tuple.fst_id]->Final(fst_state);
643  }
644  if (always_cache_ || HasArcs(s)) SetFinal(s, weight);
645  return weight;
646  }
647 
648  size_t NumArcs(StateId s) {
649  if (HasArcs(s)) {
650  return CacheImpl::NumArcs(s);
651  } else if (always_cache_) { // If always caching, expands and caches state.
652  Expand(s);
653  return CacheImpl::NumArcs(s);
654  } else { // Otherwise computes the number of arcs without expanding.
655  const auto tuple = state_table_->Tuple(s);
656  if (tuple.fst_state == kNoStateId) return 0;
657  auto num_arcs = fst_array_[tuple.fst_id]->NumArcs(tuple.fst_state);
658  if (ComputeFinalArc(tuple, nullptr)) ++num_arcs;
659  return num_arcs;
660  }
661  }
662 
663  // Returns whether a given label is a non-terminal.
664  bool IsNonTerminal(Label label) const {
665  if (label < *nonterminal_set_.begin() ||
666  label > *nonterminal_set_.rbegin()) {
667  return false;
668  } else {
669  return nonterminal_hash_.count(label);
670  }
671  // TODO(allauzen): be smarter and take advantage of all_dense or
672  // all_negative. Also use this in ComputeArc. This would require changes to
673  // Replace so that recursing into an empty FST lead to a non co-accessible
674  // state instead of deleting the arc as done currently. The current use
675  // correct, since labels are sorted if all_non_empty is true.
676  }
677 
679  if (HasArcs(s)) {
680  return CacheImpl::NumInputEpsilons(s);
681  } else if (always_cache_ || !Properties(kILabelSorted)) {
682  // If always caching or if the number of input epsilons is too expensive
683  // to compute without caching (i.e., not ilabel-sorted), then expands and
684  // caches state.
685  Expand(s);
686  return CacheImpl::NumInputEpsilons(s);
687  } else {
688  // Otherwise, computes the number of input epsilons without caching.
689  const auto tuple = state_table_->Tuple(s);
690  if (tuple.fst_state == kNoStateId) return 0;
691  size_t num = 0;
692  if (!EpsilonOnInput(call_label_type_)) {
693  // If EpsilonOnInput(c) is false, all input epsilon arcs
694  // are also input epsilons arcs in the underlying machine.
695  num = fst_array_[tuple.fst_id]->NumInputEpsilons(tuple.fst_state);
696  } else {
697  // Otherwise, one need to consider that all non-terminal arcs
698  // in the underlying machine also become input epsilon arc.
699  ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
700  for (; !aiter.Done() && ((aiter.Value().ilabel == 0) ||
701  IsNonTerminal(aiter.Value().olabel));
702  aiter.Next()) {
703  ++num;
704  }
705  }
706  if (EpsilonOnInput(return_label_type_) &&
707  ComputeFinalArc(tuple, nullptr)) {
708  ++num;
709  }
710  return num;
711  }
712  }
713 
715  if (HasArcs(s)) {
717  } else if (always_cache_ || !Properties(kOLabelSorted)) {
718  // If always caching or if the number of output epsilons is too expensive
719  // to compute without caching (i.e., not olabel-sorted), then expands and
720  // caches state.
721  Expand(s);
723  } else {
724  // Otherwise, computes the number of output epsilons without caching.
725  const auto tuple = state_table_->Tuple(s);
726  if (tuple.fst_state == kNoStateId) return 0;
727  size_t num = 0;
728  if (!EpsilonOnOutput(call_label_type_)) {
729  // If EpsilonOnOutput(c) is false, all output epsilon arcs are also
730  // output epsilons arcs in the underlying machine.
731  num = fst_array_[tuple.fst_id]->NumOutputEpsilons(tuple.fst_state);
732  } else {
733  // Otherwise, one need to consider that all non-terminal arcs in the
734  // underlying machine also become output epsilon arc.
735  ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
736  for (; !aiter.Done() && ((aiter.Value().olabel == 0) ||
737  IsNonTerminal(aiter.Value().olabel));
738  aiter.Next()) {
739  ++num;
740  }
741  }
742  if (EpsilonOnOutput(return_label_type_) &&
743  ComputeFinalArc(tuple, nullptr)) {
744  ++num;
745  }
746  return num;
747  }
748  }
749 
750  uint64_t Properties() const override { return Properties(kFstProperties); }
751 
752  // Sets error if found, and returns other FST impl properties.
753  uint64_t Properties(uint64_t mask) const override {
754  if (mask & kError) {
755  for (Label i = 1; i < fst_array_.size(); ++i) {
756  if (fst_array_[i]->Properties(kError, false)) {
757  SetProperties(kError, kError);
758  }
759  }
760  }
761  return FstImpl<Arc>::Properties(mask);
762  }
763 
764  // Returns the base arc iterator, and if arcs have not been computed yet,
765  // extends and recurses for new arcs.
767  if (!HasArcs(s)) Expand(s);
768  CacheImpl::InitArcIterator(s, data);
769  // TODO(allauzen): Set behaviour of generic iterator.
770  // Warning: ArcIterator<ReplaceFst<A>>::InitCache() relies on current
771  // behaviour.
772  }
773 
774  // Extends current state (walk arcs one level deep).
775  void Expand(StateId s) {
776  const auto tuple = state_table_->Tuple(s);
777  if (tuple.fst_state == kNoStateId) { // Local FST is empty.
778  SetArcs(s);
779  return;
780  }
781  ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
782  Arc arc;
783  // Creates a final arc when needed.
784  if (ComputeFinalArc(tuple, &arc)) PushArc(s, std::move(arc));
785  // Expands all arcs leaving the state.
786  for (; !aiter.Done(); aiter.Next()) {
787  if (ComputeArc(tuple, aiter.Value(), &arc)) PushArc(s, std::move(arc));
788  }
789  SetArcs(s);
790  }
791 
792  void Expand(StateId s, const StateTuple &tuple,
793  const ArcIteratorData<Arc> &data) {
794  if (tuple.fst_state == kNoStateId) { // Local FST is empty.
795  SetArcs(s);
796  return;
797  }
798  ArcIterator<Fst<Arc>> aiter(data);
799  Arc arc;
800  // Creates a final arc when needed.
801  if (ComputeFinalArc(tuple, &arc)) AddArc(s, arc);
802  // Expands all arcs leaving the state.
803  for (; !aiter.Done(); aiter.Next()) {
804  if (ComputeArc(tuple, aiter.Value(), &arc)) AddArc(s, arc);
805  }
806  SetArcs(s);
807  }
808 
809  // If acpp is null, only returns true if a final arcp is required, but does
810  // not actually compute it.
811  bool ComputeFinalArc(const StateTuple &tuple, Arc *arcp,
812  uint8_t flags = kArcValueFlags) {
813  const auto fst_state = tuple.fst_state;
814  if (fst_state == kNoStateId) return false;
815  // If state is final, pops the stack.
816  if (fst_array_[tuple.fst_id]->Final(fst_state) != Weight::Zero() &&
817  tuple.prefix_id) {
818  if (arcp) {
819  arcp->ilabel = (EpsilonOnInput(return_label_type_)) ? 0 : return_label_;
820  arcp->olabel =
821  (EpsilonOnOutput(return_label_type_)) ? 0 : return_label_;
822  if (flags & kArcNextStateValue) {
823  const auto &stack = state_table_->GetStackPrefix(tuple.prefix_id);
824  const auto prefix_id = PopPrefix(stack);
825  const auto &top = stack.Top();
826  arcp->nextstate = state_table_->FindState(
827  StateTuple(prefix_id, top.fst_id, top.nextstate));
828  }
829  if (flags & kArcWeightValue) {
830  arcp->weight = fst_array_[tuple.fst_id]->Final(fst_state);
831  }
832  }
833  return true;
834  } else {
835  return false;
836  }
837  }
838 
839  // Computes an arc in the FST corresponding to one in the underlying machine.
840  // Returns false if the underlying arc corresponds to no arc in the resulting
841  // FST.
842  bool ComputeArc(const StateTuple &tuple, const Arc &arc, Arc *arcp,
843  uint8_t flags = kArcValueFlags) {
844  if (!EpsilonOnInput(call_label_type_) &&
845  (flags == (flags & (kArcILabelValue | kArcWeightValue)))) {
846  *arcp = arc;
847  return true;
848  }
849  if (arc.olabel == 0 || arc.olabel < *nonterminal_set_.begin() ||
850  arc.olabel > *nonterminal_set_.rbegin()) { // Expands local FST.
851  const auto nextstate =
852  flags & kArcNextStateValue
853  ? state_table_->FindState(
854  StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
855  : kNoStateId;
856  *arcp = Arc(arc.ilabel, arc.olabel, arc.weight, nextstate);
857  } else {
858  // Checks for non-terminal.
859  if (const auto it = nonterminal_hash_.find(arc.olabel);
860  it != nonterminal_hash_.end()) { // Recurses into non-terminal.
861  const auto nonterminal = it->second;
862  const auto nt_prefix =
863  PushPrefix(state_table_->GetStackPrefix(tuple.prefix_id),
864  tuple.fst_id, arc.nextstate);
865  // If the start state is valid, replace; othewise, the arc is implicitly
866  // deleted.
867  const auto nt_start = fst_array_[nonterminal]->Start();
868  if (nt_start != kNoStateId) {
869  const auto nt_nextstate = flags & kArcNextStateValue
870  ? state_table_->FindState(StateTuple(
871  nt_prefix, nonterminal, nt_start))
872  : kNoStateId;
873  const auto ilabel =
874  (EpsilonOnInput(call_label_type_)) ? 0 : arc.ilabel;
875  const auto olabel =
876  (EpsilonOnOutput(call_label_type_))
877  ? 0
878  : ((call_output_label_ == kNoLabel) ? arc.olabel
879  : call_output_label_);
880  *arcp = Arc(ilabel, olabel, arc.weight, nt_nextstate);
881  } else {
882  return false;
883  }
884  } else {
885  const auto nextstate =
886  flags & kArcNextStateValue
887  ? state_table_->FindState(
888  StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
889  : kNoStateId;
890  *arcp = Arc(arc.ilabel, arc.olabel, arc.weight, nextstate);
891  }
892  }
893  return true;
894  }
895 
896  // Returns the arc iterator flags supported by this FST.
897  uint8_t ArcIteratorFlags() const {
898  uint8_t flags = kArcValueFlags;
899  if (!always_cache_) flags |= kArcNoCache;
900  return flags;
901  }
902 
903  StateTable *GetStateTable() const { return state_table_.get(); }
904 
905  const Fst<Arc> *GetFst(Label fst_id) const {
906  return fst_array_[fst_id].get();
907  }
908 
909  Label GetFstId(Label nonterminal) const {
910  const auto it = nonterminal_hash_.find(nonterminal);
911  if (it == nonterminal_hash_.end()) {
912  FSTERROR() << "ReplaceFstImpl::GetFstId: Nonterminal not found: "
913  << nonterminal;
914  }
915  return it->second;
916  }
917 
918  // Returns true if label type on call arc results in epsilon input label.
919  bool EpsilonOnCallInput() { return EpsilonOnInput(call_label_type_); }
920 
921  private:
922  // The unique index into stack prefix table.
923  PrefixId GetPrefixId(const StackPrefix &prefix) {
924  return state_table_->FindPrefixId(prefix);
925  }
926 
927  // The prefix ID after a stack pop.
928  PrefixId PopPrefix(StackPrefix prefix) {
929  prefix.Pop();
930  return GetPrefixId(prefix);
931  }
932 
933  // The prefix ID after a stack push.
934  PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
935  prefix.Push(fst_id, nextstate);
936  return GetPrefixId(prefix);
937  }
938 
939  // Runtime options
940  ReplaceLabelType call_label_type_; // How to label call arc.
941  ReplaceLabelType return_label_type_; // How to label return arc.
// Output label placed on call arcs; kNoLabel keeps the component arc's own
// olabel (see ComputeArc()).
942  int64_t call_output_label_; // Specifies output label to put on call arc.
943  int64_t return_label_; // Specifies label to put on return arc.
// When true, ArcIteratorFlags() does not advertise kArcNoCache, so arc
// iterators always expand and cache states.
944  bool always_cache_; // Disable optional caching of arc iterator?
945 
946  // State table.
947  std::unique_ptr<StateTable> state_table_;
948 
949  // Replace components.
// Every nonterminal label; ReplaceFstMatcher registers each one as a
// multi-epsilon label on its per-component matchers.
950  std::set<Label> nonterminal_set_;
// Maps a nonterminal label to its index in fst_array_ (see GetFstId()).
951  NonTerminalHash nonterminal_hash_;
// Component FSTs indexed by fst_id; entries may be null
// (ReplaceFstMatcher::InitMatchers() skips null entries).
952  std::vector<std::unique_ptr<const Fst<Arc>>> fst_array_;
// NOTE(review): presumably identifies the root FST; it is set outside this
// view — confirm.
953  Label root_;
954 };
955 
956 } // namespace internal
957 
958 //
959 // ReplaceFst supports dynamic replacement of arcs in one FST with another FST.
960 // This replacement is recursive. ReplaceFst can be used to support a variety of
961 // delayed constructions such as recursive
962 // transition networks, union, or closure. It is constructed with an array of
963 // FST(s). One FST represents the root (or topology) machine. The root FST
964 // refers to other FSTs by recursively replacing arcs labeled as non-terminals
965 // with the matching non-terminal FST. Currently the ReplaceFst uses the output
966 // symbols of the arcs to determine whether the arc is a non-terminal arc or
967 // not. A non-terminal can be any label that is not a non-zero terminal label in
968 // the output alphabet.
969 //
970 // Note that the constructor uses a vector of pairs. These correspond to the
971 // tuple of non-terminal Label and corresponding FST. For example to implement
972 // the closure operation we need 2 FSTs. The first root FST is a single
973 // self-loop arc on the start state.
974 //
975 // The ReplaceFst class supports an optionally caching arc iterator.
976 //
977 // The ReplaceFst needs to be built such that it is known to be ilabel- or
978 // olabel-sorted (see usage below).
979 //
980 // Observe that Matcher<Fst<A>> will use the optionally caching arc iterator
981 // when available (the FST is ilabel-sorted and matching on the input, or the
982 // FST is olabel-sorted and matching on the output). In order to obtain the
983 // most efficient behaviour, it is recommended to set call_label_type to
984 // REPLACE_LABEL_INPUT or REPLACE_LABEL_BOTH and return_label_type to
985 // REPLACE_LABEL_OUTPUT or REPLACE_LABEL_NEITHER (this means that the call arc
986 // does not have epsilon on the input side and the return arc has epsilon on the
987 // input side) and to match on the input side.
988 //
989 // This class attaches interface to implementation and handles reference
990 // counting, delegating most methods to ImplToFst.
991 template <class A, class T /* = DefaultReplaceStateTable<A> */,
992  class CacheStore /* = DefaultCacheStore<A> */>
993 class ReplaceFst
994  : public ImplToFst<internal::ReplaceFstImpl<A, T, CacheStore>> {
995  public:
996  using Arc = A;
997  using Label = typename Arc::Label;
998  using StateId = typename Arc::StateId;
999  using Weight = typename Arc::Weight;
1000 
// Policy types selected by the template parameters.
1001  using StateTable = T;
1002  using Store = CacheStore;
1003  using State = typename CacheStore::State;
1006 
1008 
// These helpers work directly on the implementation through
// GetImpl()/GetMutableImpl().
1009  friend class ArcIterator<ReplaceFst<Arc, StateTable, CacheStore>>;
1010  friend class StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>;
1011  friend class ReplaceFstMatcher<Arc, StateTable, CacheStore>;
1012 
// Builds the replacement from (nonterminal label, FST) pairs, using
// ReplaceFstOptions constructed from `root` alone.
1013  ReplaceFst(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_array,
1014  Label root)
1015  : ImplToFst<Impl>(std::make_shared<Impl>(
1016  fst_array, ReplaceFstOptions<Arc, StateTable, CacheStore>(root))) {}
1017 
1018  ReplaceFst(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_array,
1020  : ImplToFst<Impl>(std::make_shared<Impl>(fst_array, opts)) {}
1021 
1022  // See Fst<>::Copy() for doc.
// Forwards `fst` and `safe` to ImplToFst's copy constructor; see Fst<>::Copy()
// for the meaning of `safe`.
1023  ReplaceFst(const ReplaceFst &fst, bool safe = false)
1024  : ImplToFst<Impl>(fst, safe) {}
1025 
1026  // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc.
1027  ReplaceFst *Copy(bool safe = false) const override {
1028  return new ReplaceFst(*this, safe);
1029  }
1030 
// Declared here and defined after the StateIterator specialization; the
// definition installs a StateIterator<ReplaceFst> in data->base.
1031  inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
1032 
1033  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
1034  GetMutableImpl()->InitArcIterator(s, data);
1035  }
1036 
1037  MatcherBase<Arc> *InitMatcher(MatchType match_type) const override {
1038  if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
1039  ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) ||
1040  (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) {
1042  match_type);
1043  } else {
1044  VLOG(2) << "Not using replace matcher";
1045  return nullptr;
1046  }
1047  }
1048 
1049  bool CyclicDependencies() const { return GetImpl()->CyclicDependencies(); }
1050 
1051  const StateTable &GetStateTable() const {
1052  return *GetImpl()->GetStateTable();
1053  }
1054 
1055  const Fst<Arc> &GetFst(Label nonterminal) const {
1056  return *GetImpl()->GetFst(GetImpl()->GetFstId(nonterminal));
1057  }
1058 
1059  private:
1062 
// Copy assignment is disabled; copy via the copy constructor or Copy().
1063  ReplaceFst &operator=(const ReplaceFst &) = delete;
1064 };
1065 
1066 // Specialization for ReplaceFst.
1067 template <class Arc, class StateTable, class CacheStore>
1068 class StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>
1069  : public CacheStateIterator<ReplaceFst<Arc, StateTable, CacheStore>> {
1070  public:
1072  : CacheStateIterator<ReplaceFst<Arc, StateTable, CacheStore>>(
1073  fst, fst.GetMutableImpl()) {}
1074 };
1075 
1076 // Specialization for ReplaceFst, implementing optional caching. It can be
1077 // used as follows:
1078 //
1079 // ReplaceFst<A> replace;
1080 // ArcIterator<ReplaceFst<A>> aiter(replace, s);
1081 // // Note: ArcIterator<Fst<A>> is always a caching arc iterator.
1082 // aiter.SetFlags(kArcNoCache, kArcNoCache);
1083 // // Uses the arc iterator, no arc will be cached, no state will be expanded.
1084 // // Arc flags can be used to decide which components of the arc need to be
1085 // // computed.
1086 // aiter.SetFlags(kArcILabelValue, kArcValueFlags);
1087 // // Wants the ilabel for this arc.
1088 // aiter.Value(); // Does not compute the destination state.
1089 // aiter.Next();
1090 // aiter.SetFlags(kArcNextStateValue, kArcNextStateValue);
1091 // // Wants the ilabel and next state for this arc.
1092 // aiter.Value(); // Does compute the destination state and inserts it
1093 // // in the replace state table.
1094 // // No additional arcs have been cached at this point.
1095 template <class Arc, class StateTable, class CacheStore>
1096 class ArcIterator<ReplaceFst<Arc, StateTable, CacheStore>> {
1097  public:
1098  using StateId = typename Arc::StateId;
1099 
// Replace-state tuple supplied by the state table; this iterator reads its
// fst_id and fst_state fields.
1100  using StateTuple = typename StateTable::StateTuple;
1101 
1103  : fst_(fst),
1104  s_(s),
1105  pos_(0),
1106  offset_(0),
1107  flags_(kArcValueFlags),
1108  arcs_(nullptr),
1109  data_flags_(0),
1110  final_flags_(0) {
1111  cache_data_.ref_count = nullptr;
1112  local_data_.ref_count = nullptr;
1113  // If FST does not support optional caching, forces caching.
1114  if (!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
1115  !(fst_.GetImpl()->HasArcs(s_))) {
1116  fst_.GetMutableImpl()->Expand(s_);
1117  }
1118  // If state is already cached, use cached arcs array.
1119  if (fst_.GetImpl()->HasArcs(s_)) {
1120  (fst_.GetImpl())
1121  ->internal::template CacheBaseImpl<
1122  typename CacheStore::State,
1123  CacheStore>::InitArcIterator(s_, &cache_data_);
1124  num_arcs_ = cache_data_.narcs;
1125  arcs_ = cache_data_.arcs; // arcs_ is a pointer to the cached arcs.
1126  data_flags_ = kArcValueFlags; // All the arc member values are valid.
1127  } else { // Otherwise delay decision until Value() is called.
1128  tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(s_);
1129  if (tuple_.fst_state == kNoStateId) {
1130  num_arcs_ = 0;
1131  } else {
1132  // The decision to cache or not to cache has been defered until Value()
1133  // or
1134  // SetFlags() is called. However, the arc iterator is set up now to be
1135  // ready for non-caching in order to keep the Value() method simple and
1136  // efficient.
1137  const auto *rfst = fst_.GetImpl()->GetFst(tuple_.fst_id);
1138  rfst->InitArcIterator(tuple_.fst_state, &local_data_);
1139  // arcs_ is a pointer to the arcs in the underlying machine.
1140  arcs_ = local_data_.arcs;
1141  // Computes the final arc (but not its destination state) if a final arc
1142  // is required.
1143  bool has_final_arc = fst_.GetMutableImpl()->ComputeFinalArc(
1144  tuple_, &final_arc_, kArcValueFlags & ~kArcNextStateValue);
1145  // Sets the arc value flags that hold for final_arc_.
1146  final_flags_ = kArcValueFlags & ~kArcNextStateValue;
1147  // Computes the number of arcs.
1148  num_arcs_ = local_data_.narcs;
1149  if (has_final_arc) ++num_arcs_;
1150  // Sets the offset between the underlying arc positions and the
1151  // positions
1152  // in the arc iterator.
1153  offset_ = num_arcs_ - local_data_.narcs;
1154  // Defers the decision to cache or not until Value() or SetFlags() is
1155  // called.
1156  data_flags_ = 0;
1157  }
1158  }
1159  }
1160 
1162  if (cache_data_.ref_count) --(*cache_data_.ref_count);
1163  if (local_data_.ref_count) --(*local_data_.ref_count);
1164  }
1165 
// Forces state s_ to be expanded and cached, then switches the iterator to the
// cached arcs: every arc value flag becomes valid and the position offset is
// reset to 0. Declared const because it only modifies mutable members.
1166  void ExpandAndCache() const {
1167  // TODO(allauzen): revisit this.
1168  // fst_.GetImpl()->Expand(s_, tuple_, local_data_);
1169  // (fst_.GetImpl())->CacheImpl<A>*>::InitArcIterator(s_,
1170  // &cache_data_);
1171  //
1172  fst_.InitArcIterator(s_, &cache_data_); // Expand and cache state.
1173  arcs_ = cache_data_.arcs; // arcs_ is a pointer to the cached arcs.
1174  data_flags_ = kArcValueFlags; // All the arc member values are valid.
1175  offset_ = 0; // No offset.
1176  }
1177 
1178  void Init() {
1179  if (flags_ & kArcNoCache) { // If caching is disabled
1180  // arcs_ is a pointer to the arcs in the underlying machine.
1181  arcs_ = local_data_.arcs;
1182  // Sets the arcs value flags that hold for arcs_.
1183  data_flags_ = kArcWeightValue;
1184  if (!fst_.GetMutableImpl()->EpsilonOnCallInput()) {
1185  data_flags_ |= kArcILabelValue;
1186  }
1187  // Sets the offset between the underlying arc positions and the positions
1188  // in the arc iterator.
1189  offset_ = num_arcs_ - local_data_.narcs;
1190  } else {
1191  ExpandAndCache();
1192  }
1193  }
1194 
1195  bool Done() const { return pos_ >= num_arcs_; }
1196 
// Returns the arc at the current position. Positions with pos_ - offset_ < 0
// denote the final (return) arc held in final_arc_; all other positions index
// arcs_[pos_ - offset_]. When the available data does not carry every value
// flag requested via SetFlags(), the arc is recomputed on the fly through the
// implementation's ComputeArc()/ComputeFinalArc().
1197  const Arc &Value() const {
1198  // If data_flags_ is 0, non-caching was not requested.
1199  if (!data_flags_) {
1200  // TODO(allauzen): Revisit this.
1201  if (flags_ & kArcNoCache) {
1202  // Should never happen.
1203  FSTERROR() << "ReplaceFst: Inconsistent arc iterator flags";
1204  }
1205  ExpandAndCache();
1206  }
1207  if (pos_ - offset_ >= 0) { // The requested arc is not the final arc.
1208  const auto &arc = arcs_[pos_ - offset_];
1209  if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) {
1210  // If the value flags match the required value flags then returns the
1211  // arc.
1212  return arc;
1213  } else {
1214  // Otherwise, compute the corresponding arc on-the-fly.
1215  fst_.GetMutableImpl()->ComputeArc(tuple_, arc, &arc_,
1216  flags_ & kArcValueFlags);
1217  return arc_;
1218  }
1219  } else { // The requested arc is the final arc.
1220  if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) {
1221  // If the arc value flags that hold for the final arc do not match the
1222  // requested value flags, then
1223  // final_arc_ needs to be updated.
1224  fst_.GetMutableImpl()->ComputeFinalArc(tuple_, &final_arc_,
1225  flags_ & kArcValueFlags);
1226  final_flags_ = flags_ & kArcValueFlags;
1227  }
1228  return final_arc_;
1229  }
1230  }
1231 
1232  void Next() { ++pos_; }
1233 
1234  size_t Position() const { return pos_; }
1235 
1236  void Reset() { pos_ = 0; }
1237 
1238  void Seek(size_t pos) { pos_ = pos; }
1239 
1240  uint8_t Flags() const { return flags_; }
1241 
// Updates the iterator's behavioral flags: clears the bits in `mask`, then ORs
// in the bits of `flags` that the FST supports (per ArcIteratorFlags()).
// NOTE(review): bits of `flags` outside `mask` are ORed in too; callers
// presumably pass `flags` as a subset of `mask`.
1242  void SetFlags(uint8_t flags, uint8_t mask) {
1243  // Updates the flags taking into account what flags are supported
1244  // by the FST.
1245  flags_ &= ~mask;
1246  flags_ |= (flags & fst_.GetImpl()->ArcIteratorFlags());
1247  // If non-caching is not requested (and caching has not already been
1248  // performed), then flush data_flags_ to request caching during the next
1249  // call to Value().
1250  if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) {
1251  if (!fst_.GetImpl()->HasArcs(s_)) data_flags_ = 0;
1252  }
1253  // If data_flags_ has been flushed but non-caching is requested before
1254  // calling Value(), then set up the iterator for non-caching.
// This tests the caller's `flags`, not the filtered flags_; Init() falls back
// to ExpandAndCache() when kArcNoCache did not survive the filtering.
1255  if ((flags & kArcNoCache) && (!data_flags_)) Init();
1256  }
1257 
1258  private:
1259  const ReplaceFst<Arc, StateTable, CacheStore> &fst_; // Reference to the FST.
1260  StateId s_; // State in the FST.
// Only filled in when s_ was not already cached at construction time.
1261  mutable StateTuple tuple_; // Tuple corresponding to s_.
1262 
1263  ssize_t pos_; // Current position.
// 1 when a final (return) arc occupies the first position in non-cached mode,
// otherwise 0.
1264  mutable ssize_t offset_; // Offset between position in iterator and in arcs_.
1265  ssize_t num_arcs_; // Number of arcs at s_.
1266  uint8_t flags_; // Behavioral flags for the arc iterator
1267  mutable Arc arc_; // Memory to temporarily store computed arcs.
1268 
1269  mutable ArcIteratorData<Arc> cache_data_; // Arc iterator data in cache.
1270  mutable ArcIteratorData<Arc> local_data_; // Arc iterator data in the component FST.
1271 
// Points either at cache_data_.arcs or at local_data_.arcs.
1272  mutable const Arc *arcs_; // Array of arcs.
// 0 means the caching decision is still pending (see Value()).
1273  mutable uint8_t data_flags_; // Arc value flags valid for data in arcs_.
1274  mutable Arc final_arc_; // Final arc (when required).
1275  mutable uint8_t final_flags_; // Arc value flags valid for final_arc_.
1276 
1277  ArcIterator(const ArcIterator &) = delete;
1278  ArcIterator &operator=(const ArcIterator &) = delete;
1279 };
1280 
1281 template <class Arc, class StateTable, class CacheStore>
1282 class ReplaceFstMatcher : public MatcherBase<Arc> {
1283  public:
1284  using Label = typename Arc::Label;
1285  using StateId = typename Arc::StateId;
1286  using Weight = typename Arc::Weight;
1287 
1290 
1291  using StateTuple = typename StateTable::StateTuple;
1292 
1293  // This makes a copy of the FST.
1295  MatchType match_type)
1296  : owned_fst_(fst.Copy()),
1297  fst_(*owned_fst_),
1298  impl_(fst_.GetMutableImpl()),
1299  s_(fst::kNoStateId),
1300  match_type_(match_type),
1301  current_loop_(false),
1302  final_arc_(false),
1303  loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
1304  if (match_type_ == fst::MATCH_OUTPUT) {
1305  std::swap(loop_.ilabel, loop_.olabel);
1306  }
1307  InitMatchers();
1308  }
1309 
1310  // This doesn't copy the FST.
1312  MatchType match_type)
1313  : fst_(*fst),
1314  impl_(fst_.GetMutableImpl()),
1315  s_(fst::kNoStateId),
1316  match_type_(match_type),
1317  current_loop_(false),
1318  final_arc_(false),
1319  loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
1320  if (match_type_ == fst::MATCH_OUTPUT) {
1321  std::swap(loop_.ilabel, loop_.olabel);
1322  }
1323  InitMatchers();
1324  }
1325 
1326  // This makes a copy of the FST.
// Copies the other matcher's FST (forwarding `safe` to ReplaceFst::Copy()),
// keeps its match type, and rebuilds the per-component matchers. No current
// state is carried over: s_ starts at kNoStateId.
1327  ReplaceFstMatcher(const ReplaceFstMatcher &matcher, bool safe = false)
1328  : owned_fst_(matcher.fst_.Copy(safe)),
1329  fst_(*owned_fst_),
1330  impl_(fst_.GetMutableImpl()),
1331  s_(fst::kNoStateId),
1332  match_type_(matcher.match_type_),
1333  current_loop_(false),
1334  final_arc_(false),
1335  loop_(fst::kNoLabel, 0, Weight::One(), fst::kNoStateId) {
// For output matching, move kNoLabel onto the output (matched) side.
1336  if (match_type_ == fst::MATCH_OUTPUT) {
1337  std::swap(loop_.ilabel, loop_.olabel);
1338  }
1339  InitMatchers();
1340  }
1341 
1342  // Creates a local matcher for each component FST in the RTN. LocalMatcher is
1343  // a multi-epsilon wrapper matcher. MultiEpsilonMatcher is used to match each
1344  // non-terminal arc, since these non-terminal
1345  // turn into epsilons on recursion.
1346  void InitMatchers() {
1347  const auto &fst_array = impl_->fst_array_;
1348  matcher_.resize(fst_array.size());
1349  for (Label i = 0; i < fst_array.size(); ++i) {
1350  if (fst_array[i]) {
1351  matcher_[i] = std::make_unique<LocalMatcher>(*fst_array[i], match_type_,
1352  kMultiEpsList);
1353  auto it = impl_->nonterminal_set_.begin();
1354  for (; it != impl_->nonterminal_set_.end(); ++it) {
1355  matcher_[i]->AddMultiEpsLabel(*it);
1356  }
1357  }
1358  }
1359  }
1360 
1361  ReplaceFstMatcher *Copy(bool safe = false) const override {
1362  return new ReplaceFstMatcher(*this, safe);
1363  }
1364 
1365  MatchType Type(bool test) const override {
1366  if (match_type_ == MATCH_NONE) return match_type_;
1367  const auto true_prop =
1368  match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted;
1369  const auto false_prop =
1370  match_type_ == MATCH_INPUT ? kNotILabelSorted : kNotOLabelSorted;
1371  const auto props = fst_.Properties(true_prop | false_prop, test);
1372  if (props & true_prop) {
1373  return match_type_;
1374  } else if (props & false_prop) {
1375  return MATCH_NONE;
1376  } else {
1377  return MATCH_UNKNOWN;
1378  }
1379  }
1380 
1381  const Fst<Arc> &GetFst() const override { return fst_; }
1382 
1383  uint64_t Properties(uint64_t props) const override { return props; }
1384 
1385  // Sets the state from which our matching happens.
// Looks up the replace tuple for `s` and switches to the matcher of that
// tuple's component FST; setting the same state again is a no-op.
// NOTE(review): when the tuple has no component state (fst_state ==
// kNoStateId) this only sets done_, which nothing visible in this class reads,
// and leaves current_matcher_ on the previous component — confirm callers
// never Find() at such states.
1386  void SetState(StateId s) final {
1387  if (s_ == s) return;
1388  s_ = s;
1389  tuple_ = impl_->GetStateTable()->Tuple(s_);
1390  if (tuple_.fst_state == kNoStateId) {
1391  done_ = true;
1392  return;
1393  }
1394  // Gets current matcher, used for non-epsilon matching.
1395  current_matcher_ = matcher_[tuple_.fst_id].get();
1396  current_matcher_->SetState(tuple_.fst_state);
1397  loop_.nextstate = s_;
1398  final_arc_ = false;
1399  }
1400 
1401  // Searches for label from previous set state. If label == 0, first
1402  // hallucinate an epsilon loop; otherwise use the underlying matcher to
1403  // search for the label or epsilons. Note since the ReplaceFst recursion
1404  // on non-terminal arcs causes epsilon transitions to be created we use
1405  // MultiEpsilonMatcher to search for possible matches of non-terminals. If the
1406  // component FST
1407  // reaches a final state we also need to add the exiting final arc.
1408  bool Find(Label label) final {
1409  bool found = false;
1410  label_ = label;
1411  if (label_ == 0 || label_ == kNoLabel) {
1412  // Computes loop directly, avoiding Replace::ComputeArc.
1413  if (label_ == 0) {
1414  current_loop_ = true;
1415  found = true;
1416  }
1417  // Searches for matching multi-epsilons.
1418  final_arc_ = impl_->ComputeFinalArc(tuple_, nullptr);
1419  found = current_matcher_->Find(kNoLabel) || final_arc_ || found;
1420  } else {
1421  // Searches on a sub machine directly using sub machine matcher.
1422  found = current_matcher_->Find(label_);
1423  }
1424  return found;
1425  }
1426 
1427  bool Done() const final {
1428  return !current_loop_ && !final_arc_ && current_matcher_->Done();
1429  }
1430 
1431  const Arc &Value() const final {
1432  if (current_loop_) return loop_;
1433  if (final_arc_) {
1434  impl_->ComputeFinalArc(tuple_, &arc_);
1435  return arc_;
1436  }
1437  const auto &component_arc = current_matcher_->Value();
1438  impl_->ComputeArc(tuple_, component_arc, &arc_);
1439  return arc_;
1440  }
1441 
1442  void Next() final {
1443  if (current_loop_) {
1444  current_loop_ = false;
1445  return;
1446  }
1447  if (final_arc_) {
1448  final_arc_ = false;
1449  return;
1450  }
1451  current_matcher_->Next();
1452  }
1453 
1454  ssize_t Priority(StateId s) final { return fst_.NumArcs(s); }
1455 
1456  private:
// Holds the FST copy made by the copying constructors; left null by the
// pointer-taking constructor, which does not copy the FST.
1457  std::unique_ptr<const ReplaceFst<Arc, StateTable, CacheStore>> owned_fst_;
// Matcher of the current state's component FST; selected in SetState().
1460  LocalMatcher *current_matcher_;
// One matcher per fst_array_ slot; null where the slot holds no FST.
1461  std::vector<std::unique_ptr<LocalMatcher>> matcher_;
1462  StateId s_; // Current state.
1463  Label label_; // Current label.
1464  MatchType match_type_; // Supplied by caller.
// NOTE(review): only written by SetState(); nothing visible in this class
// reads it (Done() does not consult it) — confirm whether it is needed.
1465  mutable bool done_;
1466  mutable bool current_loop_; // Current arc is the implicit loop.
1467  mutable bool final_arc_; // Current arc for exiting recursion.
1468  mutable StateTuple tuple_; // Tuple corresponding to s_.
// Scratch storage for arcs computed by Value().
1469  mutable Arc arc_;
// Implicit self-loop reported by Find(0); nextstate is refreshed by SetState()
// and the labels are swapped for MATCH_OUTPUT by the constructors.
1470  Arc loop_;
1471 
1472  ReplaceFstMatcher &operator=(const ReplaceFstMatcher &) = delete;
1473 };
1474 
1475 template <class Arc, class StateTable, class CacheStore>
1477  StateIteratorData<Arc> *data) const {
1478  data->base =
1479  std::make_unique<StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>>(
1480  *this);
1481 }
1482 
1484 
1485 // Recursively replaces arcs in the root FSTs with other FSTs.
1486 // This version writes the result of replacement to an output MutableFst.
1487 //
1488 // Replace supports replacement of arcs in one Fst with another FST. This
1489 // replacement is recursive. Replace takes an array of FST(s). One FST
1490 // represents the root (or topology) machine. The root FST refers to other FSTs
1491 // by recursively replacing arcs labeled as non-terminals with the matching
1492 // non-terminal FST. Currently Replace uses the output symbols of the arcs to
1493 // determine whether the arc is a non-terminal arc or not. A non-terminal can be
1494 // any label that is not a non-zero terminal label in the output alphabet.
1495 //
1496 // Note that input argument is a vector of pairs. These correspond to the tuple
1497 // of non-terminal Label and corresponding FST.
1498 template <class Arc>
1499 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1500  &ifst_array,
1501  MutableFst<Arc> *ofst,
1503  opts.gc = true;
1504  opts.gc_limit = 0; // Caches only the last state for fastest copy.
1505  *ofst = ReplaceFst<Arc>(ifst_array, opts);
1506 }
1507 
1508 template <class Arc>
1509 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1510  &ifst_array,
1511  MutableFst<Arc> *ofst, const ReplaceUtilOptions &opts) {
1512  Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(opts));
1513 }
1514 
1515 // For backwards compatibility.
1516 template <class Arc>
1517 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1518  &ifst_array,
1519  MutableFst<Arc> *ofst, typename Arc::Label root,
1520  bool epsilon_on_replace) {
1521  Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root, epsilon_on_replace));
1522 }
1523 
1524 template <class Arc>
1525 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1526  &ifst_array,
1527  MutableFst<Arc> *ofst, typename Arc::Label root) {
1528  Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root));
1529 }
1530 
1531 } // namespace fst
1532 
1533 #endif // FST_REPLACE_H_
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:124
ReplaceFstOptions(const CacheOptions &opts, Label root=kNoLabel)
Definition: replace.h:379
ReplaceFingerprint(const std::vector< uint64_t > *size_array)
Definition: replace.h:128
uint64_t Properties(uint64_t mask) const override
Definition: replace.h:753
typename CacheStore::State State
Definition: replace.h:506
void SetState(StateId s) final
Definition: replace.h:1386
constexpr uint8_t kArcValueFlags
Definition: fst.h:453
ArcIterator(const ReplaceFst< Arc, StateTable, CacheStore > &fst, StateId s)
Definition: replace.h:1102
const StateTuple & Tuple(StateId id) const
Definition: replace.h:295
constexpr int kNoLabel
Definition: fst.h:195
std::vector< PrefixTuple > prefix_
Definition: replace.h:205
MatcherBase< Arc > * InitMatcher(MatchType match_type) const override
Definition: replace.h:1037
void Expand(StateId s, const StateTuple &tuple, const ArcIteratorData< Arc > &data)
Definition: replace.h:792
constexpr uint8_t kArcNoCache
Definition: fst.h:452
const StateTable & GetStateTable() const
Definition: replace.h:1051
const StackPrefix & GetStackPrefix(PrefixId id) const
Definition: replace.h:343
void Next() final
Definition: replace.h:1442
const Fst< Arc > * GetFst(Label fst_id) const
Definition: replace.h:905
ReplaceStateTuple(PrefixId prefix_id=-1, StateId fst_id=kNoStateId, StateId fst_state=kNoStateId)
Definition: replace.h:92
typename StateTable::PrefixId PrefixId
Definition: replace.h:508
PrefixId FindPrefixId(const StackPrefix &prefix)
Definition: replace.h:339
ReplaceFstOptions(const ReplaceUtilOptions &opts)
Definition: replace.h:407
size_t NumOutputEpsilons(StateId s)
Definition: replace.h:714
bool EpsilonOnOutput(ReplaceLabelType label_type)
Definition: replace.h:435
ReplaceFstMatcher(const ReplaceFst< Arc, StateTable, CacheStore > *fst, MatchType match_type)
Definition: replace.h:1311
ReplaceLabelType
Definition: replace-util.h:48
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:107
constexpr uint32_t kMultiEpsList
Definition: matcher.h:1234
const SymbolTable * OutputSymbols() const
Definition: fst.h:761
MatchType
Definition: fst.h:187
uint8_t ArcIteratorFlags() const
Definition: replace.h:897
uint64_t Properties(uint64_t props) const override
Definition: replace.h:1383
constexpr uint64_t kError
Definition: properties.h:52
uint64_t Properties() const override
Definition: replace.h:750
ReplaceFst(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_array, const ReplaceFstOptions< Arc, StateTable, CacheStore > &opts)
Definition: replace.h:1018
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
Definition: replace.h:1033
bool ReplaceTransducer(ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, Label call_output_label)
Definition: replace.h:442
uint64_t operator()(const ReplaceStateTuple< StateId, PrefixId > &tuple) const
Definition: replace.h:144
void Push(StateId fst_id, StateId nextstate)
Definition: replace.h:194
VectorHashReplaceStateTable(const VectorHashReplaceStateTable< Arc, PrefixId > &table)
Definition: replace.h:279
SetType
Definition: set-weight.h:59
MatchType Type(bool test) const override
Definition: replace.h:1365
typename Arc::Weight Weight
Definition: replace.h:1286
typename ReplaceFst< Arc, StateTable, CacheStore >::Arc Arc
Definition: cache.h:1156
constexpr uint8_t kArcILabelValue
Definition: fst.h:444
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:113
constexpr uint64_t kNotOLabelSorted
Definition: properties.h:101
bool EpsilonOnInput(ReplaceLabelType label_type)
Definition: replace.h:429
void Replace(const std::vector< std::pair< typename Arc::Label, const Fst< Arc > * >> &ifst_array, MutableFst< Arc > *ofst, std::vector< std::pair< typename Arc::Label, typename Arc::Label >> *parens, const PdtReplaceOptions< Arc > &opts)
Definition: replace.h:802
bool CyclicDependencies() const
Definition: replace.h:1049
VectorHashReplaceStateTable(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_list, Label root)
Definition: replace.h:259
ReplaceFstOptions(const CacheImplOptions< CacheStore > &opts, Label root=kNoLabel)
Definition: replace.h:375
bool Find(Label label) final
Definition: replace.h:1408
const PrefixTuple & Top() const
Definition: replace.h:200
const Fst< Arc > & GetFst(Label nonterminal) const
Definition: replace.h:1055
constexpr int kNoStateId
Definition: fst.h:196
const Arc & Value() const
Definition: fst.h:536
DefaultReplaceStateTable(const DefaultReplaceStateTable< Arc, PrefixId > &table)
Definition: replace.h:336
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
Definition: replace.h:766
std::vector< std::pair< typename Arc::Label, const Fst< Arc > * >> FstList
Definition: replace.h:426
bool ComputeArc(const StateTuple &tuple, const Arc &arc, Arc *arcp, uint8_t flags=kArcValueFlags)
Definition: replace.h:842
#define FSTERROR()
Definition: util.h:56
bool ComputeFinalArc(const StateTuple &tuple, Arc *arcp, uint8_t flags=kArcValueFlags)
Definition: replace.h:811
ReplaceFstImpl(const ReplaceFstImpl &impl)
Definition: replace.h:586
size_t operator()(const ReplaceStateTuple< S, P > &t) const
Definition: replace.h:153
constexpr uint64_t kNotILabelSorted
Definition: properties.h:96
typename Arc::Weight Weight
Definition: replace.h:504
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:118
void InitStateIterator(StateIteratorData< Arc > *data) const override
Definition: replace.h:1476
constexpr uint64_t kOLabelSorted
Definition: properties.h:99
constexpr uint64_t kCopyProperties
Definition: properties.h:163
typename Arc::StateId StateId
Definition: replace.h:244
uint64_t ReplaceProperties(const std::vector< uint64_t > &inprops, size_t root, bool epsilon_on_call, bool epsilon_on_return, bool out_epsilon_on_call, bool out_epsilon_on_return, bool replace_transducer, bool no_empty_fst, bool all_ilabel_sorted, bool all_olabel_sorted, bool all_negative_or_dense)
Definition: properties.cc:251
const Fst< Arc > & GetFst() const override
Definition: replace.h:1381
Weight Final(StateId s)
Definition: replace.h:636
ReplaceFstOptions(int64_t root, bool epsilon_replace_arc)
Definition: replace.h:414
#define VLOG(level)
Definition: log.h:54
size_t NumArcs(StateId s)
Definition: replace.h:648
typename Arc::Label Label
Definition: replace.h:243
bool Done() const
Definition: fst.h:532
ReplaceFst * Copy(bool safe=false) const override
Definition: replace.h:1027
std::unique_ptr< StateIteratorBase< Arc > > base
Definition: fst.h:382
PrefixId FindPrefixId(const StackPrefix &prefix)
Definition: replace.h:297
const StackPrefix & GetStackPrefix(PrefixId id) const
Definition: replace.h:301
ReplaceFst(const std::vector< std::pair< Label, const Fst< Arc > * >> &fst_array, Label root)
Definition: replace.h:1013
ReplaceFstOptions(Label root, ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, Label return_label)
Definition: replace.h:390
uint64_t operator()(const ReplaceStateTuple< StateId, PrefixId > &tuple) const
Definition: replace.h:131
typename Arc::Label Label
Definition: replace.h:357
constexpr uint8_t kArcWeightValue
Definition: fst.h:448
constexpr uint64_t kILabelSorted
Definition: properties.h:94
ReplaceFstMatcher * Copy(bool safe=false) const override
Definition: replace.h:1361
void Expand(StateId s)
Definition: replace.h:775
uint64_t ReplaceFstProperties(typename Arc::Label root_label, const FstList< Arc > &fst_list, ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, typename Arc::Label call_output_label, bool *sorted_and_non_empty)
Definition: replace.h:454
constexpr uint64_t kFstProperties
Definition: properties.h:326
ReplaceFst(const ReplaceFst &fst, bool safe=false)
Definition: replace.h:1023
const Arc & Value() const final
Definition: replace.h:1431
Label GetFstId(Label nonterminal) const
Definition: replace.h:909
typename internal::ReplaceFstImpl< Arc, StateTable, CacheStore >::Arc Arc
Definition: fst.h:205
size_t Depth() const
Definition: replace.h:202
const SymbolTable * InputSymbols() const
Definition: fst.h:759
Arc::StateId CountStates(const Fst< Arc > &fst)
Definition: expanded-fst.h:179
constexpr uint8_t kArcNextStateValue
Definition: fst.h:450
bool CyclicDependencies() const
Definition: replace-util.h:138
typename Arc::Label Label
Definition: replace.h:1284
friend H AbslHashValue(H h, const ReplaceStateTuple &t)
Definition: replace.h:97
size_t operator()(const ReplaceStackPrefix< Label, StateId > &prefix) const
Definition: replace.h:226
size_t NumInputEpsilons(StateId s)
Definition: replace.h:678
ssize_t Priority(StateId s) final
Definition: replace.h:1454
ReplaceFstOptions(Label root, ReplaceLabelType call_label_type, ReplaceLabelType return_label_type, Label call_output_label, Label return_label)
Definition: replace.h:398
ReplaceFstMatcher(const ReplaceFstMatcher &matcher, bool safe=false)
Definition: replace.h:1327
ReplaceFstOptions(Label root)
Definition: replace.h:388
bool operator==(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:51
bool CyclicDependencies() const
Definition: replace.h:611
typename StateTable::StateTuple StateTuple
Definition: replace.h:1291
ReplaceFstImpl(const FstList< Arc > &fst_list, const ReplaceFstOptions< Arc, StateTable, CacheStore > &opts)
Definition: replace.h:531
ReplaceFstMatcher(const ReplaceFst< Arc, StateTable, CacheStore > &fst, MatchType match_type)
Definition: replace.h:1294
typename Arc::StateId StateId
Definition: replace.h:319
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
Definition: expand.h:323
ReplaceStackPrefix(const ReplaceStackPrefix &other)
Definition: replace.h:191
typename Arc::StateId StateId
Definition: replace.h:503
bool Done() const final
Definition: replace.h:1427
StateId FindState(const StateTuple &tuple)
Definition: replace.h:291
bool IsNonTerminal(Label label) const
Definition: replace.h:664
typename Arc::StateId StateId
Definition: replace.h:1285
DefaultReplaceStateTable(const std::vector< std::pair< Label, const Fst< Arc > * >> &, Label)
Definition: replace.h:333
StateIterator(const ReplaceFst< Arc, StateTable, CacheStore > &fst)
Definition: replace.h:1071
typename Arc::Label Label
Definition: replace.h:502
PrefixTuple(Label fst_id=kNoLabel, StateId nextstate=kNoStateId)
Definition: replace.h:182
bool operator()(const ReplaceStateTuple< StateId, PrefixId > &tuple) const
Definition: replace.h:119
StateTable * GetStateTable() const
Definition: replace.h:903
std::unordered_map< Label, Label > NonTerminalHash
Definition: replace.h:511
void Next()
Definition: fst.h:540