FST  openfst-1.7.1
OpenFst Library
lookahead-matcher.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Classes to add lookahead to FST matchers, useful for improving composition
5 // efficiency with certain inputs.
6 
7 #ifndef FST_LOOKAHEAD_MATCHER_H_
8 #define FST_LOOKAHEAD_MATCHER_H_
9 
10 #include <memory>
11 #include <utility>
12 #include <vector>
13 
14 #include <fst/flags.h>
15 #include <fst/log.h>
16 
17 #include <fst/add-on.h>
18 #include <fst/const-fst.h>
19 #include <fst/fst.h>
20 #include <fst/label-reachable.h>
21 #include <fst/matcher.h>
22 
23 
24 DECLARE_string(save_relabel_ipairs);
25 DECLARE_string(save_relabel_opairs);
26 
27 namespace fst {
28 
29 // Lookahead matchers extend the matcher interface with the following additional
30 // methods:
31 //
32 // template <class FST>
33 // class LookAheadMatcher {
34 // public:
35 // using Arc = typename FST::Arc;
36 // using Label = typename Arc::Label;
37 // using StateId = typename Arc::StateId;
38 // using Weight = typename Arc::Weight;
39 //
40 // // Required constructors.
41 // // This makes a copy of the FST.
42 // LookAheadMatcher(const FST &fst, MatchType match_type);
43 // // This doesn't copy the FST.
44 // LookAheadMatcher(const FST *fst, MatchType match_type);
45 // // This makes a copy of the FST.
46 // // See Copy() below.
47 // LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false);
48 //
49 // // If safe = true, the copy is thread-safe (except the lookahead FST is
50 // // preserved). See Fst<>::Copy() for further doc.
51 // LookAheadMatcher<FST> *Copy(bool safe = false) const override;
52 
53 // // Below are methods for looking ahead for a match to a label and more
54 // // generally, to a rational set. Each returns false if there is definitely
55 // // not a match and returns true if there possibly is a match.
56 //
57 // // Optionally pre-specifies the lookahead FST that will be passed to
58 // // LookAheadFst() for possible precomputation. If copy is true, then the FST
59 // // argument is a copy of the FST used in the previous call to this method
60 // // (to avoid unnecessary updates).
61 // void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override;
62 //
63 // // Are there paths from a state in the lookahead FST that can be read from
64 // // the current matcher state?
65 // bool LookAheadFst(const Fst<Arc> &fst, StateId s) override;
66 //
67 // // Can the label be read from the current matcher state after possibly
68 // // following epsilon transitions?
69 // bool LookAheadLabel(Label label) const override;
70 //
71 // // The following methods allow looking ahead for an arbitrary rational set
72 // // of strings, specified by an FST and a state from which to begin the
73 // // matching. If the lookahead FST is a transducer, this looks on the side
74 // // different from the matcher's match_type (cf. composition).
75 // // Is there a single non-epsilon arc found in the lookahead FST that
76 // // begins the path (after possibly following any epsilons) in the last call
77 // // to LookAheadFst? If so, return true and copy it to the arc argument;
78 // // otherwise, return false. Non-trivial implementations are useful for
79 // // label-pushing in composition.
80 // bool LookAheadPrefix(Arc *arc) override;
81 //
82 // // Gives an estimate of the combined weight of the paths in the lookahead
83 // // and matcher FSTs for the last call to LookAheadFst. Non-trivial
84 // // implementations are useful for weight-pushing in composition.
85 // Weight LookAheadWeight() const override;
86 // };
87 
88 // Look-ahead flags.
89 // Matcher is a lookahead matcher when match_type is MATCH_INPUT.
90 constexpr uint32 kInputLookAheadMatcher = 0x00000010;
91 
92 // Matcher is a lookahead matcher when match_type is MATCH_OUTPUT.
93 constexpr uint32 kOutputLookAheadMatcher = 0x00000020;
94 
95 // Is a non-trivial implementation of LookAheadWeight() method defined and
96 // if so, should it be used?
97 constexpr uint32 kLookAheadWeight = 0x00000040;
98 
99 // Is a non-trivial implementation of LookAheadPrefix() method defined and
100 // if so, should it be used?
101 constexpr uint32 kLookAheadPrefix = 0x00000080;
102 
103 // Look-ahead of matcher FST non-epsilon arcs?
104 constexpr uint32 kLookAheadNonEpsilons = 0x00000100;
105 
106 // Look-ahead of matcher FST epsilon arcs?
107 constexpr uint32 kLookAheadEpsilons = 0x00000200;
108 
109 // Ignore epsilon paths for the lookahead prefix? This gives correct results in
110 // composition only with an appropriate composition filter since it depends on
111 // the filter blocking the ignored paths.
112 constexpr uint32 kLookAheadNonEpsilonPrefix = 0x00000400;
113 
114 // For LabelLookAheadMatcher, save relabeling data to file?
115 constexpr uint32 kLookAheadKeepRelabelData = 0x00000800;
116 
117 // Flags used for lookahead matchers.
118 constexpr uint32 kLookAheadFlags = 0x00000ff0;
119 
120 // LookAhead Matcher interface, templated on the Arc definition; used
121 // for lookahead matcher specializations that are returned by the
122 // InitMatcher() Fst method.
123 template <class Arc>
124 class LookAheadMatcherBase : public MatcherBase<Arc> {
125  public:
126  using Label = typename Arc::Label;
127  using StateId = typename Arc::StateId;
128  using Weight = typename Arc::Weight;
129 
130  virtual void InitLookAheadFst(const Fst<Arc> &, bool copy = false) = 0;
131  virtual bool LookAheadFst(const Fst<Arc> &, StateId) = 0;
132  virtual bool LookAheadLabel(Label) const = 0;
133 
134  // Suggested concrete implementation of lookahead methods.
135 
136  bool LookAheadPrefix(Arc *arc) const {
137  if (prefix_arc_.nextstate != kNoStateId) {
138  *arc = prefix_arc_;
139  return true;
140  } else {
141  return false;
142  }
143  }
144 
145  Weight LookAheadWeight() const { return weight_; }
146 
147  protected:
148  // Concrete implementations for lookahead helper methods.
149 
150  void ClearLookAheadWeight() { weight_ = Weight::One(); }
151 
152  void SetLookAheadWeight(Weight weight) { weight_ = std::move(weight); }
153 
154  void ClearLookAheadPrefix() { prefix_arc_.nextstate = kNoStateId; }
155 
156  void SetLookAheadPrefix(Arc arc) { prefix_arc_ = std::move(arc); }
157 
158  private:
159  Arc prefix_arc_;
160  Weight weight_;
161 };
162 
163 // Doesn't actually lookahead, just declares that the future looks good.
164 template <class M>
166  : public LookAheadMatcherBase<typename M::FST::Arc> {
167  public:
168  using FST = typename M::FST;
169  using Arc = typename FST::Arc;
170  using Label = typename Arc::Label;
171  using StateId = typename Arc::StateId;
172  using Weight = typename Arc::Weight;
173 
174  // This makes a copy of the FST.
176  : matcher_(fst, match_type) {}
177 
178  // This doesn't copy the FST.
180  : matcher_(fst, match_type) {}
181 
182  // This makes a copy of the FST.
184  bool safe = false)
185  : matcher_(lmatcher.matcher_, safe) {}
186 
187  TrivialLookAheadMatcher<M> *Copy(bool safe = false) const override {
188  return new TrivialLookAheadMatcher<M>(*this, safe);
189  }
190 
191  MatchType Type(bool test) const override { return matcher_.Type(test); }
192 
193  void SetState(StateId s) final { return matcher_.SetState(s); }
194 
195  bool Find(Label label) final { return matcher_.Find(label); }
196 
197  bool Done() const final { return matcher_.Done(); }
198 
199  const Arc &Value() const final { return matcher_.Value(); }
200 
201  void Next() final { matcher_.Next(); }
202 
203  Weight Final(StateId s) const final { return matcher_.Final(s); }
204 
205  ssize_t Priority(StateId s) final { return matcher_.Priority(s); }
206 
207  const FST &GetFst() const override { return matcher_.GetFst(); }
208 
209  uint64 Properties(uint64 props) const override {
210  return matcher_.Properties(props);
211  }
212 
213  uint32 Flags() const override {
214  return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher;
215  }
216 
217  // Lookahead methods (all trivial).
218 
219  void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override {}
220 
221  bool LookAheadFst(const Fst<Arc> &, StateId) final { return true; }
222 
223  bool LookAheadLabel(Label) const final { return true; }
224 
225  bool LookAheadPrefix(Arc *) const { return false; }
226 
227  Weight LookAheadWeight() const { return Weight::One(); }
228 
229  private:
230  M matcher_;
231 };
232 
233 // Look-ahead of one transition. Template argument flags accepts flags to
234 // control behavior.
235 template <class M,
236  uint32 flags = kLookAheadNonEpsilons | kLookAheadEpsilons |
237  kLookAheadWeight | kLookAheadPrefix>
238 class ArcLookAheadMatcher : public LookAheadMatcherBase<typename M::FST::Arc> {
239  public:
240  using FST = typename M::FST;
241  using Arc = typename FST::Arc;
242  using Label = typename Arc::Label;
243  using StateId = typename Arc::StateId;
244  using Weight = typename Arc::Weight;
246 
253 
254  enum : uint32 { kFlags = flags };
255 
256  // This makes a copy of the FST.
258  const FST &fst, MatchType match_type,
259  std::shared_ptr<MatcherData> data = std::shared_ptr<MatcherData>())
260  : matcher_(fst, match_type),
261  fst_(matcher_.GetFst()),
262  lfst_(nullptr),
263  state_(kNoStateId) {}
264 
265  // This doesn't copy the FST.
267  const FST *fst, MatchType match_type,
268  std::shared_ptr<MatcherData> data = std::shared_ptr<MatcherData>())
269  : matcher_(fst, match_type),
270  fst_(matcher_.GetFst()),
271  lfst_(nullptr),
272  state_(kNoStateId) {}
273 
274  // This makes a copy of the FST.
276  bool safe = false)
277  : matcher_(lmatcher.matcher_, safe),
278  fst_(matcher_.GetFst()),
279  lfst_(lmatcher.lfst_),
280  state_(kNoStateId) {}
281 
282  // General matcher methods.
283  ArcLookAheadMatcher<M, flags> *Copy(bool safe = false) const override {
284  return new ArcLookAheadMatcher<M, flags>(*this, safe);
285  }
286 
287  MatchType Type(bool test) const override { return matcher_.Type(test); }
288 
289  void SetState(StateId s) final {
290  state_ = s;
291  matcher_.SetState(s);
292  }
293 
294  bool Find(Label label) final { return matcher_.Find(label); }
295 
296  bool Done() const final { return matcher_.Done(); }
297 
298  const Arc &Value() const final { return matcher_.Value(); }
299 
300  void Next() final { matcher_.Next(); }
301 
302  Weight Final(StateId s) const final { return matcher_.Final(s); }
303 
304  ssize_t Priority(StateId s) final { return matcher_.Priority(s); }
305 
306  const FST &GetFst() const override { return fst_; }
307 
308  uint64 Properties(uint64 props) const override {
309  return matcher_.Properties(props);
310  }
311 
312  uint32 Flags() const override {
313  return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher |
314  kFlags;
315  }
316 
317  const MatcherData *GetData() const { return nullptr; }
318 
319  std::shared_ptr<MatcherData> GetSharedData() const {
320  return std::shared_ptr<MatcherData>();
321  }
322 
323  // Look-ahead methods.
324 
325  void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override {
326  lfst_ = &fst;
327  }
328 
329  // Checks if there is a matching (possibly super-final) transition
330  // at (state_, s).
331  bool LookAheadFst(const Fst<Arc> &, StateId) final;
332 
333  bool LookAheadLabel(Label label) const final { return matcher_.Find(label); }
334 
335  private:
336  mutable M matcher_;
337  const FST &fst_; // Matcher FST.
338  const Fst<Arc> *lfst_; // Look-ahead FST.
339  StateId state_; // Matcher state.
340 };
341 
342 template <class M, uint32 flags>
344  StateId s) {
345  if (&fst != lfst_) InitLookAheadFst(fst);
346  bool result = false;
347  ssize_t nprefix = 0;
348  if (kFlags & kLookAheadWeight) ClearLookAheadWeight();
349  if (kFlags & kLookAheadPrefix) ClearLookAheadPrefix();
350  if (fst_.Final(state_) != Weight::Zero() &&
351  lfst_->Final(s) != Weight::Zero()) {
352  if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
353  ++nprefix;
354  if (kFlags & kLookAheadWeight) {
356  Plus(LookAheadWeight(), Times(fst_.Final(state_), lfst_->Final(s))));
357  }
358  result = true;
359  }
360  if (matcher_.Find(kNoLabel)) {
361  if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
362  ++nprefix;
363  if (kFlags & kLookAheadWeight) {
364  for (; !matcher_.Done(); matcher_.Next()) {
365  SetLookAheadWeight(Plus(LookAheadWeight(), matcher_.Value().weight));
366  }
367  }
368  result = true;
369  }
370  for (ArcIterator<Fst<Arc>> aiter(*lfst_, s); !aiter.Done(); aiter.Next()) {
371  const auto &arc = aiter.Value();
372  Label label = kNoLabel;
373  switch (matcher_.Type(false)) {
374  case MATCH_INPUT:
375  label = arc.olabel;
376  break;
377  case MATCH_OUTPUT:
378  label = arc.ilabel;
379  break;
380  default:
381  FSTERROR() << "ArcLookAheadMatcher::LookAheadFst: Bad match type";
382  return true;
383  }
384  if (label == 0) {
385  if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
386  if (!(kFlags & kLookAheadNonEpsilonPrefix)) ++nprefix;
387  if (kFlags & kLookAheadWeight) {
388  SetLookAheadWeight(Plus(LookAheadWeight(), arc.weight));
389  }
390  result = true;
391  } else if (matcher_.Find(label)) {
392  if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
393  for (; !matcher_.Done(); matcher_.Next()) {
394  ++nprefix;
395  if (kFlags & kLookAheadWeight) {
397  Times(arc.weight, matcher_.Value().weight)));
398  }
399  if ((kFlags & kLookAheadPrefix) && nprefix == 1)
400  SetLookAheadPrefix(arc);
401  }
402  result = true;
403  }
404  }
405  if (kFlags & kLookAheadPrefix) {
406  if (nprefix == 1) {
407  ClearLookAheadWeight(); // Avoids double counting.
408  } else {
410  }
411  }
412  return result;
413 }
414 
415 // Template argument flags accepts flags to control behavior. It must include
416 // precisely one of kInputLookAheadMatcher or kOutputLookAheadMatcher.
417 template <class M,
418  uint32 flags = kLookAheadEpsilons | kLookAheadWeight |
419  kLookAheadPrefix | kLookAheadNonEpsilonPrefix |
421  class Accumulator = DefaultAccumulator<typename M::Arc>,
424  : public LookAheadMatcherBase<typename M::FST::Arc> {
425  public:
426  using FST = typename M::FST;
427  using Arc = typename FST::Arc;
428  using Label = typename Arc::Label;
429  using StateId = typename Arc::StateId;
430  using Weight = typename Arc::Weight;
431  using MatcherData = typename Reachable::Data;
432 
439 
440  enum : uint32 { kFlags = flags };
441 
442  // This makes a copy of the FST.
444  const FST &fst, MatchType match_type,
445  std::shared_ptr<MatcherData> data = std::shared_ptr<MatcherData>(),
446  Accumulator *accumulator = nullptr)
447  : matcher_(fst, match_type),
448  lfst_(nullptr),
449  state_(kNoStateId),
450  error_(false) {
451  Init(fst, match_type, data, accumulator);
452  }
453 
454  // This doesn't copy the FST.
456  const FST *fst, MatchType match_type,
457  std::shared_ptr<MatcherData> data = std::shared_ptr<MatcherData>(),
458  Accumulator *accumulator = nullptr)
459  : matcher_(fst, match_type),
460  lfst_(nullptr),
461  state_(kNoStateId),
462  error_(false) {
463  Init(*fst, match_type, data, accumulator);
464  }
465 
466  // This makes a copy of the FST.
469  bool safe = false)
470  : matcher_(lmatcher.matcher_, safe),
471  lfst_(lmatcher.lfst_),
472  label_reachable_(lmatcher.label_reachable_
473  ? new Reachable(*lmatcher.label_reachable_, safe)
474  : nullptr),
475  state_(kNoStateId),
476  error_(lmatcher.error_) {}
477 
479  bool safe = false) const override {
481  safe);
482  }
483 
484  MatchType Type(bool test) const override { return matcher_.Type(test); }
485 
486  void SetState(StateId s) final {
487  if (state_ == s) return;
488  state_ = s;
489  match_set_state_ = false;
490  reach_set_state_ = false;
491  }
492 
493  bool Find(Label label) final {
494  if (!match_set_state_) {
495  matcher_.SetState(state_);
496  match_set_state_ = true;
497  }
498  return matcher_.Find(label);
499  }
500 
501  bool Done() const final { return matcher_.Done(); }
502 
503  const Arc &Value() const final { return matcher_.Value(); }
504 
505  void Next() final { matcher_.Next(); }
506 
507  Weight Final(StateId s) const final { return matcher_.Final(s); }
508 
509  ssize_t Priority(StateId s) final { return matcher_.Priority(s); }
510 
511  const FST &GetFst() const override { return matcher_.GetFst(); }
512 
513  uint64 Properties(uint64 inprops) const override {
514  auto outprops = matcher_.Properties(inprops);
515  if (error_ || (label_reachable_ && label_reachable_->Error())) {
516  outprops |= kError;
517  }
518  return outprops;
519  }
520 
521  uint32 Flags() const override {
522  if (label_reachable_ && label_reachable_->GetData()->ReachInput()) {
523  return matcher_.Flags() | kFlags | kInputLookAheadMatcher;
524  } else if (label_reachable_ && !label_reachable_->GetData()->ReachInput()) {
525  return matcher_.Flags() | kFlags | kOutputLookAheadMatcher;
526  } else {
527  return matcher_.Flags();
528  }
529  }
530 
531  const MatcherData *GetData() const {
532  return label_reachable_ ? label_reachable_->GetData() : nullptr;
533  };
534 
535  std::shared_ptr<MatcherData> GetSharedData() const {
536  return label_reachable_ ? label_reachable_->GetSharedData()
537  : std::shared_ptr<MatcherData>();
538  }
539  // Checks if there is a matching (possibly super-final) transition at
540  // (state_, s).
541  template <class LFST>
542  bool LookAheadFst(const LFST &fst, StateId s);
543 
544  // Required to make class concrete.
545  bool LookAheadFst(const Fst<Arc> &fst, StateId s) final {
546  return LookAheadFst<Fst<Arc>>(fst, s);
547  }
548 
549  void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override {
550  lfst_ = &fst;
551  if (label_reachable_) {
552  const bool reach_input = Type(false) == MATCH_OUTPUT;
553  label_reachable_->ReachInit(fst, reach_input, copy);
554  }
555  }
556 
557  template <class LFST>
558  void InitLookAheadFst(const LFST &fst, bool copy = false) {
559  lfst_ = static_cast<const Fst<Arc> *>(&fst);
560  if (label_reachable_) {
561  const bool reach_input = Type(false) == MATCH_OUTPUT;
562  label_reachable_->ReachInit(fst, reach_input, copy);
563  }
564  }
565 
566  bool LookAheadLabel(Label label) const final {
567  if (label == 0) return true;
568  if (label_reachable_) {
569  if (!reach_set_state_) {
570  label_reachable_->SetState(state_);
571  reach_set_state_ = true;
572  }
573  return label_reachable_->Reach(label);
574  } else {
575  return true;
576  }
577  }
578 
579  private:
580  void Init(const FST &fst, MatchType match_type,
581  std::shared_ptr<MatcherData> data,
582  Accumulator *accumulator) {
583  if (!(kFlags & (kInputLookAheadMatcher | kOutputLookAheadMatcher))) {
584  FSTERROR() << "LabelLookaheadMatcher: Bad matcher flags: " << kFlags;
585  error_ = true;
586  }
587  const bool reach_input = match_type == MATCH_INPUT;
588  if (data) {
589  if (reach_input == data->ReachInput()) {
590  label_reachable_.reset(new Reachable(data, accumulator));
591  }
592  } else if ((reach_input && (kFlags & kInputLookAheadMatcher)) ||
593  (!reach_input && (kFlags & kOutputLookAheadMatcher))) {
594  label_reachable_.reset(new Reachable(fst, reach_input, accumulator,
595  kFlags & kLookAheadKeepRelabelData));
596  }
597  }
598 
599  mutable M matcher_;
600  const Fst<Arc> *lfst_; // Look-ahead FST.
601  std::unique_ptr<Reachable> label_reachable_; // Label reachability info.
602  StateId state_; // Matcher state.
603  bool match_set_state_; // matcher_.SetState called?
604  mutable bool reach_set_state_; // reachable_.SetState called?
605  bool error_; // Error encountered?
606 };
607 
608 template <class M, uint32 flags, class Accumulator, class Reachable>
609 template <class LFST>
610 inline bool LabelLookAheadMatcher<M, flags, Accumulator,
611  Reachable>::LookAheadFst(const LFST &fst,
612  StateId s) {
613  if (static_cast<const Fst<Arc> *>(&fst) != lfst_) InitLookAheadFst(fst);
616  if (!label_reachable_) return true;
617  label_reachable_->SetState(state_, s);
618  reach_set_state_ = true;
619  bool compute_weight = kFlags & kLookAheadWeight;
620  bool compute_prefix = kFlags & kLookAheadPrefix;
621  ArcIterator<LFST> aiter(fst, s);
622  aiter.SetFlags(kArcNoCache, kArcNoCache); // Makes caching optional.
623  const bool reach_arc = label_reachable_->Reach(
624  &aiter, 0, internal::NumArcs(*lfst_, s), compute_weight);
625  const auto lfinal = internal::Final(*lfst_, s);
626  const bool reach_final =
627  lfinal != Weight::Zero() && label_reachable_->ReachFinal();
628  if (reach_arc) {
629  const auto begin = label_reachable_->ReachBegin();
630  const auto end = label_reachable_->ReachEnd();
631  if (compute_prefix && end - begin == 1 && !reach_final) {
632  aiter.Seek(begin);
633  SetLookAheadPrefix(aiter.Value());
634  compute_weight = false;
635  } else if (compute_weight) {
636  SetLookAheadWeight(label_reachable_->ReachWeight());
637  }
638  }
639  if (reach_final && compute_weight) {
640  SetLookAheadWeight(reach_arc ? Plus(LookAheadWeight(), lfinal) : lfinal);
641  }
642  return reach_arc || reach_final;
643 }
644 
645 // Label-lookahead relabeling class.
646 template <class Arc, class Data = LabelReachableData<typename Arc::Label>>
648  public:
649  using Label = typename Arc::Label;
651 
652  // Relabels matcher FST (initialization function object).
653  template <typename Impl>
654  explicit LabelLookAheadRelabeler(std::shared_ptr<Impl> *impl);
655 
656  // Relabels arbitrary FST. Class LFST should be a label-lookahead FST.
657  template <class LFST>
658  static void Relabel(MutableFst<Arc> *fst, const LFST &mfst,
659  bool relabel_input) {
660  const auto *data = mfst.GetAddOn();
661  Reachable reachable(data->First() ? data->SharedFirst()
662  : data->SharedSecond());
663  reachable.Relabel(fst, relabel_input);
664  }
665 
666  // Returns relabeling pairs (cf. relabel.h::Relabel()). Class LFST should be a
667  // label-lookahead FST. If avoid_collisions is true, extra pairs are added to
668  // ensure no collisions when relabeling automata that have labels unseen here.
669  template <class LFST>
670  static void RelabelPairs(const LFST &mfst,
671  std::vector<std::pair<Label, Label>> *pairs,
672  bool avoid_collisions = false) {
673  const auto *data = mfst.GetAddOn();
674  Reachable reachable(data->First() ? data->SharedFirst()
675  : data->SharedSecond());
676  reachable.RelabelPairs(pairs, avoid_collisions);
677  }
678 };
679 
680 template <class Arc, class Data>
681 template <typename Impl>
683  std::shared_ptr<Impl> *impl) {
684  Fst<Arc> &fst = (*impl)->GetFst();
685  auto data = (*impl)->GetSharedAddOn();
686  const auto name = (*impl)->Type();
687  const bool is_mutable = fst.Properties(kMutable, false);
688  std::unique_ptr<MutableFst<Arc>> mfst;
689  if (is_mutable) {
690  mfst.reset(static_cast<MutableFst<Arc> *>(&fst));
691  } else {
692  mfst.reset(new VectorFst<Arc>(fst));
693  }
694  if (data->First()) { // reach_input.
695  Reachable reachable(data->SharedFirst());
696  reachable.Relabel(mfst.get(), true);
697  if (!FLAGS_save_relabel_ipairs.empty()) {
698  std::vector<std::pair<Label, Label>> pairs;
699  reachable.RelabelPairs(&pairs, true);
700  WriteLabelPairs(FLAGS_save_relabel_ipairs, pairs);
701  }
702  } else {
703  Reachable reachable(data->SharedSecond());
704  reachable.Relabel(mfst.get(), false);
705  if (!FLAGS_save_relabel_opairs.empty()) {
706  std::vector<std::pair<Label, Label>> pairs;
707  reachable.RelabelPairs(&pairs, true);
708  WriteLabelPairs(FLAGS_save_relabel_opairs, pairs);
709  }
710  }
711  if (!is_mutable) {
712  *impl = std::make_shared<Impl>(*mfst, name);
713  (*impl)->SetAddOn(data);
714  }
715 }
716 
717 // Generic lookahead matcher, templated on the FST definition (a wrapper around
718 // a pointer to specific one).
719 template <class F>
721  public:
722  using FST = F;
723  using Arc = typename FST::Arc;
724  using Label = typename Arc::Label;
725  using StateId = typename Arc::StateId;
726  using Weight = typename Arc::Weight;
728 
729  // This makes a copy of the FST.
730  LookAheadMatcher(const FST &fst, MatchType match_type)
731  : owned_fst_(fst.Copy()),
732  base_(owned_fst_->InitMatcher(match_type)),
733  lookahead_(false) {
734  if (!base_) base_.reset(new SortedMatcher<FST>(owned_fst_.get(),
735  match_type));
736  }
737 
738  // This doesn't copy the FST.
739  LookAheadMatcher(const FST *fst, MatchType match_type)
740  : base_(fst->InitMatcher(match_type)),
741  lookahead_(false) {
742  if (!base_) base_.reset(new SortedMatcher<FST>(fst, match_type));
743  }
744 
745  // This makes a copy of the FST.
746  LookAheadMatcher(const LookAheadMatcher<FST> &matcher, bool safe = false)
747  : base_(matcher.base_->Copy(safe)),
748  lookahead_(matcher.lookahead_) { }
749 
750  // Takes ownership of base.
752  : base_(base), lookahead_(false) {}
753 
754  LookAheadMatcher<FST> *Copy(bool safe = false) const {
755  return new LookAheadMatcher<FST>(*this, safe);
756  }
757 
758  MatchType Type(bool test) const { return base_->Type(test); }
759 
760  void SetState(StateId s) { base_->SetState(s); }
761 
762  bool Find(Label label) { return base_->Find(label); }
763 
764  bool Done() const { return base_->Done(); }
765 
766  const Arc &Value() const { return base_->Value(); }
767 
768  void Next() { base_->Next(); }
769 
770  Weight Final(StateId s) const { return base_->Final(s); }
771 
772  ssize_t Priority(StateId s) { return base_->Priority(s); }
773 
774  const FST &GetFst() const {
775  return static_cast<const FST &>(base_->GetFst());
776  }
777 
778  uint64 Properties(uint64 props) const { return base_->Properties(props); }
779 
780  uint32 Flags() const { return base_->Flags(); }
781 
782  bool LookAheadLabel(Label label) const {
783  if (LookAheadCheck()) {
784  return static_cast<LBase *>(base_.get())->LookAheadLabel(label);
785  } else {
786  return true;
787  }
788  }
789 
790  bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
791  if (LookAheadCheck()) {
792  return static_cast<LBase *>(base_.get())->LookAheadFst(fst, s);
793  } else {
794  return true;
795  }
796  }
797 
799  if (LookAheadCheck()) {
800  return static_cast<LBase *>(base_.get())->LookAheadWeight();
801  } else {
802  return Weight::One();
803  }
804  }
805 
806  bool LookAheadPrefix(Arc *arc) const {
807  if (LookAheadCheck()) {
808  return static_cast<LBase *>(base_.get())->LookAheadPrefix(arc);
809  } else {
810  return false;
811  }
812  }
813 
814  void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) {
815  if (LookAheadCheck()) {
816  static_cast<LBase *>(base_.get())->InitLookAheadFst(fst, copy);
817  }
818  }
819 
820  private:
821  bool LookAheadCheck() const {
822  if (!lookahead_) {
823  lookahead_ =
824  base_->Flags() & (kInputLookAheadMatcher | kOutputLookAheadMatcher);
825  if (!lookahead_) {
826  FSTERROR() << "LookAheadMatcher: No look-ahead matcher defined";
827  }
828  }
829  return lookahead_;
830  }
831 
832  std::unique_ptr<const FST> owned_fst_;
833  std::unique_ptr<MatcherBase<Arc>> base_;
834  mutable bool lookahead_;
835 
836  LookAheadMatcher &operator=(const LookAheadMatcher &) = delete;
837 };
838 
839 } // namespace fst
840 
841 #endif // FST_LOOKAHEAD_MATCHER_H_
typename Arc::Label Label
ssize_t Priority(StateId s) final
LabelLookAheadMatcher(const FST &fst, MatchType match_type, std::shared_ptr< MatcherData > data=std::shared_ptr< MatcherData >(), Accumulator *accumulator=nullptr)
ssize_t Priority(StateId s) final
uint32 Flags() const override
bool LookAheadFst(const Fst< Arc > &fst, StateId s)
constexpr uint32 kLookAheadEpsilons
constexpr uint32 kLookAheadNonEpsilonPrefix
virtual bool LookAheadLabel(Label) const =0
bool LookAheadPrefix(Arc *arc) const
constexpr int kNoLabel
Definition: fst.h:179
void SetState(StateId s)
typename Arc::StateId StateId
uint64_t uint64
Definition: types.h:32
bool Find(Label label)
bool LookAheadLabel(Label) const final
typename Arc::Label Label
void InitLookAheadFst(const LFST &fst, bool copy=false)
typename Arc::StateId StateId
constexpr uint32 kLookAheadFlags
bool LookAheadPrefix(Arc *) const
typename Arc::Weight Weight
TrivialLookAheadMatcher< M > * Copy(bool safe=false) const override
static void Relabel(MutableFst< Arc > *fst, const LFST &mfst, bool relabel_input)
bool LookAheadLabel(Label label) const
LabelLookAheadMatcher< M, flags, Accumulator, Reachable > * Copy(bool safe=false) const override
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:82
MatchType
Definition: fst.h:171
const FST & GetFst() const override
const Arc & Value() const
TrivialLookAheadMatcher(const TrivialLookAheadMatcher< M > &lmatcher, bool safe=false)
uint64 Properties(uint64 props) const
ExpectationWeight< X1, X2 > Times(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
LookAheadMatcher< FST > * Copy(bool safe=false) const
Weight Final(StateId s) const final
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:88
TrivialLookAheadMatcher(const FST &fst, MatchType match_type)
LabelLookAheadMatcher(const FST *fst, MatchType match_type, std::shared_ptr< MatcherData > data=std::shared_ptr< MatcherData >(), Accumulator *accumulator=nullptr)
MatchType Type(bool test) const
void SetState(StateId s) final
constexpr uint32 kLookAheadWeight
virtual bool LookAheadFst(const Fst< Arc > &, StateId)=0
constexpr int kNoStateId
Definition: fst.h:180
bool Find(Label label) final
const Arc & Value() const final
const Arc & Value() const
Definition: fst.h:503
virtual uint64 Properties(uint64 mask, bool test) const =0
uint32 Flags() const override
bool LookAheadLabel(Label label) const final
#define FSTERROR()
Definition: util.h:35
LabelLookAheadRelabeler(std::shared_ptr< Impl > *impl)
uint64 Properties(uint64 props) const override
MatchType Type(bool test) const override
constexpr uint32 kLookAheadNonEpsilons
LookAheadMatcher(const FST *fst, MatchType match_type)
ArcLookAheadMatcher(const FST *fst, MatchType match_type, std::shared_ptr< MatcherData > data=std::shared_ptr< MatcherData >())
static void RelabelPairs(const LFST &mfst, std::vector< std::pair< Label, Label >> *pairs, bool avoid_collisions=false)
MatchType Type(bool test) const override
const FST & GetFst() const
uint64 Properties(uint64 inprops) const override
const MatcherData * GetData() const
Weight Final(StateId s) const final
void Seek(size_t a)
Definition: fst.h:523
MatchType Type(bool test) const override
DECLARE_string(save_relabel_ipairs)
const FST & GetFst() const override
ssize_t Priority(StateId s)
void SetFlags(uint32 flags, uint32 mask)
Definition: fst.h:541
ArcLookAheadMatcher(const FST &fst, MatchType match_type, std::shared_ptr< MatcherData > data=std::shared_ptr< MatcherData >())
void SetState(StateId s) final
ExpectationWeight< X1, X2 > Plus(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
constexpr uint32 kInputLookAheadMatcher
constexpr uint32 kOutputLookAheadMatcher
bool LookAheadPrefix(Arc *arc) const
const Arc & Value() const final
virtual const Fst< Arc > & GetFst() const =0
void SetLookAheadWeight(Weight weight)
ArcLookAheadMatcher(const ArcLookAheadMatcher< M, flags > &lmatcher, bool safe=false)
virtual const string & Type() const =0
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false) override
bool LookAheadFst(const Fst< Arc > &fst, StateId s) final
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false) override
const MatcherData * GetData() const
TrivialLookAheadMatcher(const FST *fst, MatchType match_type)
uint32_t uint32
Definition: types.h:31
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false)
typename Arc::Weight Weight
typename Arc::StateId StateId
virtual MatcherBase< Arc > * Copy(bool safe=false) const =0
LabelLookAheadMatcher(const LabelLookAheadMatcher< M, flags, Accumulator, Reachable > &lmatcher, bool safe=false)
constexpr uint64 kError
Definition: properties.h:33
LookAheadMatcher(const FST &fst, MatchType match_type)
virtual void InitLookAheadFst(const Fst< Arc > &, bool copy=false)=0
uint64 Properties(uint64 props) const override
Weight Final(StateId s) const
typename Arc::Weight Weight
Definition: matcher.h:120
ArcLookAheadMatcher< M, flags > * Copy(bool safe=false) const override
typename Arc::Label Label
Definition: matcher.h:118
virtual MatchType Type(bool) const =0
std::shared_ptr< MatcherData > GetSharedData() const
constexpr uint32 kLookAheadPrefix
constexpr uint32 kLookAheadKeepRelabelData
Weight Final(StateId s) const final
Weight LookAheadWeight() const
typename Arc::Weight Weight
void SetState(StateId s) final
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false) override
LookAheadMatcher(const LookAheadMatcher< FST > &matcher, bool safe=false)
ssize_t Priority(StateId s) final
std::shared_ptr< MatcherData > GetSharedData() const
constexpr uint64 kMutable
Definition: properties.h:30
bool LookAheadFst(const Fst< Arc > &, StateId) final
bool WriteLabelPairs(const string &filename, const std::vector< std::pair< Label, Label >> &pairs)
Definition: util.h:338
uint32 Flags() const override
const Arc & Value() const final
const FST & GetFst() const override
LookAheadMatcher(MatcherBase< Arc > *base)
typename Arc::StateId StateId
typename FST::Arc Arc
bool Find(Label label) final
typename Reachable::Data MatcherData
bool Find(Label label) final
bool LookAheadLabel(Label label) const final
bool LookAheadFst(const Fst< Arc > &, StateId) final
typename Arc::StateId StateId
Definition: matcher.h:119