FST  openfst-1.8.2
OpenFst Library
lookahead-matcher.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Classes to add lookahead to FST matchers, useful for improving composition
19 // efficiency with certain inputs.
20 
21 #ifndef FST_LOOKAHEAD_MATCHER_H_
22 #define FST_LOOKAHEAD_MATCHER_H_
23 
24 #include <cstdint>
25 #include <memory>
26 #include <utility>
27 #include <vector>
28 
29 #include <fst/flags.h>
30 #include <fst/log.h>
31 
32 #include <fst/add-on.h>
33 #include <fst/const-fst.h>
34 #include <fst/fst.h>
35 #include <fst/label-reachable.h>
36 #include <fst/matcher.h>
37 
38 DECLARE_string(save_relabel_ipairs);
39 DECLARE_string(save_relabel_opairs);
40 
41 namespace fst {
42 
43 // Lookahead matchers extend the matcher interface with the following additional
44 // methods:
45 //
46 // template <class FST>
47 // class LookAheadMatcher {
48 // public:
49 // using Arc = typename FST::Arc;
50 // using Label = typename Arc::Label;
51 // using StateId = typename Arc::StateId;
52 // using Weight = typename Arc::Weight;
53 //
54 // // Required constructors.
55 // // This makes a copy of the FST.
56 // LookAheadMatcher(const FST &fst, MatchType match_type);
57 // // This doesn't copy the FST.
58 // LookAheadMatcher(const FST *fst, MatchType match_type);
59 // // This makes a copy of the FST.
60 // // See Copy() below.
61 // LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false);
62 //
63 // // If safe = true, the copy is thread-safe (except the lookahead FST is
64 // // preserved). See Fst<>::Copy() for further doc.
65 // LookAheadMatcher *Copy(bool safe = false) const override;
66 
67 // // Below are methods for looking ahead for a match to a label and more
68 // // generally, to a rational set. Each returns false if there is definitely
69 // // not a match and returns true if there possibly is a match.
70 //
71 // // Optionally pre-specifies the lookahead FST that will be passed to
72 // // LookAheadFst() for possible precomputation. If copy is true, then the FST
73 // // argument is a copy of the FST used in the previous call to this method
74 // // (to avoid unnecessary updates).
75 // void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override;
76 //
77 // // Are there paths from a state in the lookahead FST that can be read from
78 // // the current matcher state?
79 // bool LookAheadFst(const Fst<Arc> &fst, StateId s) override;
80 //
81 // // Can the label be read from the current matcher state after possibly
82 // // following epsilon transitions?
83 // bool LookAheadLabel(Label label) const override;
84 //
85 // // The following methods allow looking ahead for an arbitrary rational set
86 // // of strings, specified by an FST and a state from which to begin the
87 // // matching. If the lookahead FST is a transducer, this looks on the side
88 // // different from the matcher's match_type (cf. composition).
89 // // Is there a single non-epsilon arc found in the lookahead FST that
90 // // begins the path (after possibly following any epsilons) in the last call
91 // // to LookAheadFst? If so, return true and copy it to the arc argument;
92 // // otherwise, return false. Non-trivial implementations are useful for
93 // // label-pushing in composition.
94 // bool LookAheadPrefix(Arc *arc) override;
95 //
96 // // Gives an estimate of the combined weight of the paths in the lookahead
97 // // and matcher FSTs for the last call to LookAheadFst. Non-trivial
98 // // implementations are useful for weight-pushing in composition.
99 // Weight LookAheadWeight() const override;
100 // };
101 
// Look-ahead flags. Every flag below occupies exactly one bit (bits 4 through
// 11 of a uint32_t); kLookAheadFlags is the union of all of them.

// Matcher is a lookahead matcher when match_type is MATCH_INPUT.
inline constexpr uint32_t kInputLookAheadMatcher = uint32_t{1} << 4;

// Matcher is a lookahead matcher when match_type is MATCH_OUTPUT.
inline constexpr uint32_t kOutputLookAheadMatcher = uint32_t{1} << 5;

// Is a non-trivial implementation of LookAheadWeight() method defined and
// if so, should it be used?
inline constexpr uint32_t kLookAheadWeight = uint32_t{1} << 6;

// Is a non-trivial implementation of LookAheadPrefix() method defined and
// if so, should it be used?
inline constexpr uint32_t kLookAheadPrefix = uint32_t{1} << 7;

// Look-ahead of matcher FST non-epsilon arcs?
inline constexpr uint32_t kLookAheadNonEpsilons = uint32_t{1} << 8;

// Look-ahead of matcher FST epsilon arcs?
inline constexpr uint32_t kLookAheadEpsilons = uint32_t{1} << 9;

// Ignore epsilon paths for the lookahead prefix? This gives correct results in
// composition only with an appropriate composition filter since it depends on
// the filter blocking the ignored paths.
inline constexpr uint32_t kLookAheadNonEpsilonPrefix = uint32_t{1} << 10;

// For LabelLookAheadMatcher, save relabeling data to file?
inline constexpr uint32_t kLookAheadKeepRelabelData = uint32_t{1} << 11;

// Flags used for lookahead matchers: the union of every flag above.
inline constexpr uint32_t kLookAheadFlags =
    kInputLookAheadMatcher | kOutputLookAheadMatcher | kLookAheadWeight |
    kLookAheadPrefix | kLookAheadNonEpsilons | kLookAheadEpsilons |
    kLookAheadNonEpsilonPrefix | kLookAheadKeepRelabelData;
static_assert(kLookAheadFlags == 0x00000ff0,
              "kLookAheadFlags must cover exactly bits 4-11");
133 
134 // LookAhead Matcher interface, templated on the Arc definition; used
135 // for lookahead matcher specializations that are returned by the
136 // InitMatcher() Fst method.
137 template <class Arc>
138 class LookAheadMatcherBase : public MatcherBase<Arc> {
139  public:
140  using Label = typename Arc::Label;
141  using StateId = typename Arc::StateId;
142  using Weight = typename Arc::Weight;
143 
144  virtual void InitLookAheadFst(const Fst<Arc> &, bool copy = false) = 0;
145  virtual bool LookAheadFst(const Fst<Arc> &, StateId) = 0;
146  virtual bool LookAheadLabel(Label) const = 0;
147 
148  // Suggested concrete implementation of lookahead methods.
149 
150  bool LookAheadPrefix(Arc *arc) const {
151  if (prefix_arc_.nextstate != kNoStateId) {
152  *arc = prefix_arc_;
153  return true;
154  } else {
155  return false;
156  }
157  }
158 
159  Weight LookAheadWeight() const { return weight_; }
160 
161  protected:
162  // Concrete implementations for lookahead helper methods.
163 
164  void ClearLookAheadWeight() { weight_ = Weight::One(); }
165 
166  void SetLookAheadWeight(Weight weight) { weight_ = std::move(weight); }
167 
168  void ClearLookAheadPrefix() { prefix_arc_.nextstate = kNoStateId; }
169 
170  void SetLookAheadPrefix(Arc arc) { prefix_arc_ = std::move(arc); }
171 
172  private:
173  Arc prefix_arc_;
174  Weight weight_;
175 };
176 
177 // Doesn't actually lookahead, just declares that the future looks good.
// It wraps a matcher M and forwards every matcher method to it. All lookahead
// queries are permissive: LookAheadFst() and LookAheadLabel() always return
// true, LookAheadPrefix() returns false and LookAheadWeight() returns One().
// InitLookAheadFst() ignores its arguments.
178 template <class M>
180  : public LookAheadMatcherBase<typename M::FST::Arc> {
181  public:
182  using FST = typename M::FST;
183  using Arc = typename FST::Arc;
184  using Label = typename Arc::Label;
185  using StateId = typename Arc::StateId;
186  using Weight = typename Arc::Weight;
187 
188  // This makes a copy of the FST.
190  : matcher_(fst, match_type) {}
191 
192  // This doesn't copy the FST.
194  : matcher_(fst, match_type) {}
195 
196  // This makes a copy of the FST.
198  bool safe = false)
199  : matcher_(lmatcher.matcher_, safe) {}
200 
201  TrivialLookAheadMatcher *Copy(bool safe = false) const override {
202  return new TrivialLookAheadMatcher(*this, safe);
203  }
204 
// General matcher methods: pure forwarding to the wrapped matcher.
205  MatchType Type(bool test) const override { return matcher_.Type(test); }
206 
207  void SetState(StateId s) final { return matcher_.SetState(s); }
208 
209  bool Find(Label label) final { return matcher_.Find(label); }
210 
211  bool Done() const final { return matcher_.Done(); }
212 
213  const Arc &Value() const final { return matcher_.Value(); }
214 
215  void Next() final { matcher_.Next(); }
216 
217  Weight Final(StateId s) const final { return matcher_.Final(s); }
218 
219  ssize_t Priority(StateId s) final { return matcher_.Priority(s); }
220 
221  const FST &GetFst() const override { return matcher_.GetFst(); }
222 
223  uint64_t Properties(uint64_t props) const override {
224  return matcher_.Properties(props);
225  }
226 
// Advertises lookahead on both the input and the output side, in addition to
// the wrapped matcher's own flags.
227  uint32_t Flags() const override {
228  return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher;
229  }
230 
231  // Lookahead methods (all trivial).
232 
233  void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override {}
234 
235  bool LookAheadFst(const Fst<Arc> &, StateId) final { return true; }
236 
237  bool LookAheadLabel(Label) const final { return true; }
238 
// The next two hide, rather than override, the non-virtual base-class methods
// of the same name; calls made through a LookAheadMatcherBase pointer do not
// reach them.
239  bool LookAheadPrefix(Arc *) const { return false; }
240 
241  Weight LookAheadWeight() const { return Weight::One(); }
242 
243  private:
244  M matcher_;
245 };
246 
247 // Look-ahead of one transition. Template argument flags accepts flags to
248 // control behavior.
// For each LookAheadFst() call it inspects the final weights and the arcs
// leaving the lookahead state, matching each arc's label against the wrapped
// matcher M at the current matcher state (no deeper search). The flags choose
// whether a lookahead weight and/or a single prefix arc are also computed.
249 template <class M, uint32_t flags = kLookAheadNonEpsilons | kLookAheadEpsilons |
250  kLookAheadWeight | kLookAheadPrefix>
251 class ArcLookAheadMatcher : public LookAheadMatcherBase<typename M::FST::Arc> {
252  public:
253  using FST = typename M::FST;
254  using Arc = typename FST::Arc;
255  using Label = typename Arc::Label;
256  using StateId = typename Arc::StateId;
257  using Weight = typename Arc::Weight;
259 
266 
267  static constexpr uint32_t kFlags = flags;
268 
// In both non-copy constructors the `data` argument is accepted but not used.
269  // This makes a copy of the FST.
270  ArcLookAheadMatcher(const FST &fst, MatchType match_type,
271  std::shared_ptr<MatcherData> data = nullptr)
272  : matcher_(fst, match_type),
273  fst_(matcher_.GetFst()),
274  lfst_(nullptr),
275  state_(kNoStateId) {}
276 
277  // This doesn't copy the FST.
278  ArcLookAheadMatcher(const FST *fst, MatchType match_type,
279  std::shared_ptr<MatcherData> data = nullptr)
280  : matcher_(fst, match_type),
281  fst_(matcher_.GetFst()),
282  lfst_(nullptr),
283  state_(kNoStateId) {}
284 
285  // This makes a copy of the FST.
// Shares lmatcher's lookahead-FST pointer; the matcher state is reset.
286  ArcLookAheadMatcher(const ArcLookAheadMatcher &lmatcher, bool safe = false)
287  : matcher_(lmatcher.matcher_, safe),
288  fst_(matcher_.GetFst()),
289  lfst_(lmatcher.lfst_),
290  state_(kNoStateId) {}
291 
292  // General matcher methods.
293  ArcLookAheadMatcher *Copy(bool safe = false) const override {
294  return new ArcLookAheadMatcher(*this, safe);
295  }
296 
297  MatchType Type(bool test) const override { return matcher_.Type(test); }
298 
// Remembers the state for LookAheadFst() and positions the wrapped matcher.
299  void SetState(StateId s) final {
300  state_ = s;
301  matcher_.SetState(s);
302  }
303 
304  bool Find(Label label) final { return matcher_.Find(label); }
305 
306  bool Done() const final { return matcher_.Done(); }
307 
308  const Arc &Value() const final { return matcher_.Value(); }
309 
310  void Next() final { matcher_.Next(); }
311 
312  Weight Final(StateId s) const final { return matcher_.Final(s); }
313 
314  ssize_t Priority(StateId s) final { return matcher_.Priority(s); }
315 
316  const FST &GetFst() const override { return fst_; }
317 
318  uint64_t Properties(uint64_t props) const override {
319  return matcher_.Properties(props);
320  }
321 
// Advertises lookahead on both sides plus the template-argument flags.
322  uint32_t Flags() const override {
323  return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher |
324  kFlags;
325  }
326 
// This matcher keeps no shared matcher data.
327  const MatcherData *GetData() const { return nullptr; }
328 
329  std::shared_ptr<MatcherData> GetSharedData() const { return nullptr; }
330 
331  // Look-ahead methods.
332 
// Only records the lookahead FST; there is no precomputation.
333  void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override {
334  lfst_ = &fst;
335  }
336 
337  // Checks if there is a matching (possibly super-final) transition
338  // at (state_, s).
339  bool LookAheadFst(const Fst<Arc> &, StateId) final;
340 
// matcher_.Find() is non-const, which is why matcher_ is declared mutable.
341  bool LookAheadLabel(Label label) const final { return matcher_.Find(label); }
342 
343  private:
344  mutable M matcher_;
// fst_ refers to the FST held by matcher_ (bound via matcher_.GetFst()).
345  const FST &fst_; // Matcher FST.
346  const Fst<Arc> *lfst_; // Look-ahead FST.
347  StateId state_; // Matcher state.
348 };
349 
// Looks one transition ahead from matcher state state_ and lookahead-FST
// state s. A "hit" is any of the following:
//   - both states are final;
//   - matcher_.Find(kNoLabel) succeeds (presumably the matcher's implicit,
//     non-consuming epsilon matches; confirm against matcher.h);
//   - s has an arc whose label on the side opposite to the match type is 0;
//   - that label is found by matcher_.
// Returns true iff there is at least one hit. If neither kLookAheadWeight nor
// kLookAheadPrefix is set, it returns true at the first hit. Otherwise the
// weights of all hits are summed (Plus) into the lookahead weight, and
// candidate prefixes are counted in nprefix.
350 template <class M, uint32_t flags>
352  StateId s) {
// Re-initializes when called with a different lookahead FST than last time.
353  if (&fst != lfst_) InitLookAheadFst(fst);
354  bool result = false;
355  ssize_t nprefix = 0;
356  if (kFlags & kLookAheadWeight) ClearLookAheadWeight();
357  if (kFlags & kLookAheadPrefix) ClearLookAheadPrefix();
// Hit 1: both states are final.
358  if (fst_.Final(state_) != Weight::Zero() &&
359  lfst_->Final(s) != Weight::Zero()) {
360  if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
361  ++nprefix;
362  if (kFlags & kLookAheadWeight) {
364  Plus(LookAheadWeight(), Times(fst_.Final(state_), lfst_->Final(s))));
365  }
366  result = true;
367  }
// Hit 2: the matcher matches kNoLabel at state_.
368  if (matcher_.Find(kNoLabel)) {
369  if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
370  ++nprefix;
371  if (kFlags & kLookAheadWeight) {
372  for (; !matcher_.Done(); matcher_.Next()) {
373  SetLookAheadWeight(Plus(LookAheadWeight(), matcher_.Value().weight));
374  }
375  }
376  result = true;
377  }
// Hits 3 and 4: arcs leaving s, using the label on the side this matcher does
// not match (olabel for MATCH_INPUT, ilabel for MATCH_OUTPUT).
378  for (ArcIterator<Fst<Arc>> aiter(*lfst_, s); !aiter.Done(); aiter.Next()) {
379  const auto &arc = aiter.Value();
380  Label label = kNoLabel;
381  switch (matcher_.Type(false)) {
382  case MATCH_INPUT:
383  label = arc.olabel;
384  break;
385  case MATCH_OUTPUT:
386  label = arc.ilabel;
387  break;
388  default:
// Reports the error but still answers permissively.
389  FSTERROR() << "ArcLookAheadMatcher::LookAheadFst: Bad match type";
390  return true;
391  }
392  if (label == 0) {
393  if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
// Epsilon arcs count as prefix candidates unless kLookAheadNonEpsilonPrefix.
394  if (!(kFlags & kLookAheadNonEpsilonPrefix)) ++nprefix;
395  if (kFlags & kLookAheadWeight) {
396  SetLookAheadWeight(Plus(LookAheadWeight(), arc.weight));
397  }
398  result = true;
399  } else if (matcher_.Find(label)) {
400  if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
// Each matching matcher arc is a separate candidate; the lookahead arc is
// recorded as prefix only while it is the first candidate seen.
401  for (; !matcher_.Done(); matcher_.Next()) {
402  ++nprefix;
403  if (kFlags & kLookAheadWeight) {
405  Times(arc.weight, matcher_.Value().weight)));
406  }
407  if ((kFlags & kLookAheadPrefix) && nprefix == 1)
408  SetLookAheadPrefix(arc);
409  }
410  result = true;
411  }
412  }
// With exactly one candidate the accumulated weight is reset. The else-branch
// body is not visible in this listing (presumably it clears the prefix;
// confirm).
413  if (kFlags & kLookAheadPrefix) {
414  if (nprefix == 1) {
415  ClearLookAheadWeight(); // Avoids double counting.
416  } else {
418  }
419  }
420  return result;
421 }
422 
423 // Template argument flags accepts flags to control behavior. It must include
424 // precisely one of kInputLookAheadMatcher or kOutputLookAheadMatcher.
// Matching is delegated to the wrapped matcher M. Lookahead queries use label
// reachability information (a Reachable, with an Accumulator for weights).
// When no reachability information is available, lookahead queries are
// permissive and return true.
425 template <class M,
426  uint32_t flags = kLookAheadEpsilons | kLookAheadWeight |
427  kLookAheadPrefix | kLookAheadNonEpsilonPrefix |
432  : public LookAheadMatcherBase<typename M::FST::Arc> {
433  public:
434  using Matcher = M;
435  using Accumulator = Accum;
436  using Reachable = R;
437 
438  using FST = typename M::FST;
439  using Arc = typename FST::Arc;
440  using Label = typename Arc::Label;
441  using StateId = typename Arc::StateId;
442  using Weight = typename Arc::Weight;
443  using MatcherData = typename Reachable::Data;
444 
451 
// `!a != !b` is a logical XOR: exactly one of the two side flags must be set.
452  static_assert(!(flags & kInputLookAheadMatcher) !=
453  !(flags & kOutputLookAheadMatcher),
454  "Must include precisely one of kInputLookAheadMatcher and "
455  "kOutputLookAheadMatcher");
456  static constexpr uint32_t kFlags = flags;
457 
// In both non-copy constructors, Init() builds label_reachable_, reusing
// `data` when it is non-null.
458  // This makes a copy of the FST.
460  std::shared_ptr<MatcherData> data = nullptr,
461  std::unique_ptr<Accumulator> accumulator = nullptr)
462  : matcher_(fst, match_type),
463  lfst_(nullptr),
464  state_(kNoStateId),
465  error_(false) {
466  Init(fst, match_type, data, std::move(accumulator));
467  }
468 
469  // This doesn't copy the FST.
470  std::shared_ptr<MatcherData> data = nullptr,
472  std::unique_ptr<Accumulator> accumulator = nullptr)
473  : matcher_(fst, match_type),
474  lfst_(nullptr),
475  state_(kNoStateId),
476  error_(false) {
477  Init(*fst, match_type, data, std::move(accumulator));
478  }
479 
480  // This makes a copy of the FST.
// Deep-copies lmatcher's reachability information (if any), shares its
// lookahead-FST pointer, copies its error bit and resets the matcher state.
482  bool safe = false)
483  : matcher_(lmatcher.matcher_, safe),
484  lfst_(lmatcher.lfst_),
485  label_reachable_(lmatcher.label_reachable_
486  ? new Reachable(*lmatcher.label_reachable_, safe)
487  : nullptr),
488  state_(kNoStateId),
489  error_(lmatcher.error_) {}
490 
491  LabelLookAheadMatcher *Copy(bool safe = false) const override {
492  return new LabelLookAheadMatcher(*this, safe);
493  }
494 
495  MatchType Type(bool test) const override { return matcher_.Type(test); }
496 
497  void SetState(StateId s) final {
498  if (state_ == s) return;
499  state_ = s;
500  match_set_state_ = false;
501  reach_set_state_ = false;
502  }
503 
504  bool Find(Label label) final {
505  if (!match_set_state_) {
506  matcher_.SetState(state_);
507  match_set_state_ = true;
508  }
509  return matcher_.Find(label);
510  }
511 
512  bool Done() const final { return matcher_.Done(); }
513 
514  const Arc &Value() const final { return matcher_.Value(); }
515 
516  void Next() final { matcher_.Next(); }
517 
518  Weight Final(StateId s) const final { return matcher_.Final(s); }
519 
520  ssize_t Priority(StateId s) final { return matcher_.Priority(s); }
521 
522  const FST &GetFst() const override { return matcher_.GetFst(); }
523 
524  uint64_t Properties(uint64_t inprops) const override {
525  auto outprops = matcher_.Properties(inprops);
526  if (error_ || (label_reachable_ && label_reachable_->Error())) {
527  outprops |= kError;
528  }
529  return outprops;
530  }
531 
532  uint32_t Flags() const override {
533  if (label_reachable_ && label_reachable_->GetData()->ReachInput()) {
534  return matcher_.Flags() | kFlags | kInputLookAheadMatcher;
535  } else if (label_reachable_ && !label_reachable_->GetData()->ReachInput()) {
536  return matcher_.Flags() | kFlags | kOutputLookAheadMatcher;
537  } else {
538  return matcher_.Flags();
539  }
540  }
541 
542  const MatcherData *GetData() const {
543  return label_reachable_ ? label_reachable_->GetData() : nullptr;
544  }
545 
546  std::shared_ptr<MatcherData> GetSharedData() const {
547  return label_reachable_ ? label_reachable_->GetSharedData() : nullptr;
548  }
549  // Checks if there is a matching (possibly super-final) transition at
550  // (state_, s).
551  template <class LFST>
552  bool LookAheadFst(const LFST &fst, StateId s);
553 
554  // Required to make class concrete.
555  bool LookAheadFst(const Fst<Arc> &fst, StateId s) final {
556  return LookAheadFst<Fst<Arc>>(fst, s);
557  }
558 
559  void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override {
560  lfst_ = &fst;
561  if (label_reachable_) {
562  const bool reach_input = Type(false) == MATCH_OUTPUT;
563  label_reachable_->ReachInit(fst, reach_input, copy);
564  }
565  }
566 
567  template <class LFST>
568  void InitLookAheadFst(const LFST &fst, bool copy = false) {
569  lfst_ = &fst;
570  if (label_reachable_) {
571  const bool reach_input = Type(false) == MATCH_OUTPUT;
572  label_reachable_->ReachInit(fst, reach_input, copy);
573  }
574  }
575 
576  bool LookAheadLabel(Label label) const final {
577  if (label == 0) return true;
578  if (label_reachable_) {
579  if (!reach_set_state_) {
580  label_reachable_->SetState(state_);
581  reach_set_state_ = true;
582  }
583  return label_reachable_->Reach(label);
584  } else {
585  return true;
586  }
587  }
588 
589  private:
590  void Init(const FST &fst, MatchType match_type,
591  std::shared_ptr<MatcherData> data,
592  std::unique_ptr<Accumulator> accumulator) {
593  const bool reach_input = match_type == MATCH_INPUT;
594  if (data) {
595  if (reach_input == data->ReachInput()) {
596  label_reachable_ =
597  std::make_unique<Reachable>(data, std::move(accumulator));
598  }
599  } else if ((reach_input && (kFlags & kInputLookAheadMatcher)) ||
600  (!reach_input && (kFlags & kOutputLookAheadMatcher))) {
601  label_reachable_ =
602  std::make_unique<Reachable>(fst, reach_input, std::move(accumulator),
603  kFlags & kLookAheadKeepRelabelData);
604  }
605  }
606 
607  mutable M matcher_;
608  const Fst<Arc> *lfst_; // Look-ahead FST.
609  std::unique_ptr<Reachable> label_reachable_; // Label reachability info.
610  StateId state_; // Matcher state.
611  bool match_set_state_; // matcher_.SetState called?
612  mutable bool reach_set_state_; // reachable_.SetState called?
613  bool error_; // Error encountered?
614 };
615 
// Label lookahead from matcher state state_ to lookahead-FST state s.
// Returns true if there is no reachability information, or if some arc of s
// or the final weight of s is reachable from state_.
// - If kLookAheadPrefix is set, exactly one arc position is reachable
//   (ReachEnd() - ReachBegin() == 1) and the final weight is not reachable,
//   that arc becomes the lookahead prefix and no weight is computed.
// - Otherwise, with kLookAheadWeight, the lookahead weight is the reachable
//   arcs' weight (ReachWeight()), Plus the final weight when it is reachable.
616 template <class M, uint32_t flags, class Accumulator, class Reachable>
617 template <class LFST>
618 inline bool LabelLookAheadMatcher<M, flags, Accumulator,
620  StateId s) {
621  if (&fst != lfst_) InitLookAheadFst(fst);
624  if (!label_reachable_) return true;
// Positions the reachability object at (state_, s) and records that
// LookAheadLabel() need not position it again.
625  label_reachable_->SetState(state_, s);
626  reach_set_state_ = true;
627  bool compute_weight = kFlags & kLookAheadWeight;
628  bool compute_prefix = kFlags & kLookAheadPrefix;
629  ArcIterator<LFST> aiter(fst, s);
630  aiter.SetFlags(kArcNoCache, kArcNoCache); // Makes caching optional.
// Scans arc positions 0 .. NumArcs(s) of s (presumably a half-open range;
// confirm in label-reachable.h).
631  const bool reach_arc = label_reachable_->Reach(
632  &aiter, 0, internal::NumArcs(*lfst_, s), compute_weight);
633  const auto lfinal = internal::Final(*lfst_, s);
634  const bool reach_final =
635  lfinal != Weight::Zero() && label_reachable_->ReachFinal();
636  if (reach_arc) {
637  const auto begin = label_reachable_->ReachBegin();
638  const auto end = label_reachable_->ReachEnd();
639  if (compute_prefix && end - begin == 1 && !reach_final) {
640  aiter.Seek(begin);
641  SetLookAheadPrefix(aiter.Value());
// The prefix arc carries its own weight, so the final-weight step below
// does not set one.
642  compute_weight = false;
643  } else if (compute_weight) {
644  SetLookAheadWeight(label_reachable_->ReachWeight());
645  }
646  }
647  if (reach_final && compute_weight) {
648  SetLookAheadWeight(reach_arc ? Plus(LookAheadWeight(), lfinal) : lfinal);
649  }
650  return reach_arc || reach_final;
651 }
652 
// Relabels `fst` with a Reachable built from the label-reachability `data`.
// - If data.First() is non-null, the input side is relabeled using the first
//   component.
// - Otherwise, the output side is relabeled using the second component.
// If the file name for the relabeled side is non-empty (save_relabel_ipairs
// for input, save_relabel_opairs for output), the relabeling pairs are also
// written there with WriteLabelPairs(). Those pairs include the extra pairs
// that avoid collisions.
template <class Reachable, class FST, class Data>
void RelabelForReachable(FST *fst, const Data &data,
                         const std::string &save_relabel_ipairs,
                         const std::string &save_relabel_opairs) {
  using Label = typename FST::Arc::Label;
  // Writes `reachable`'s relabeling pairs to `path`, unless `path` is empty.
  const auto save_pairs = [](Reachable &reachable, const std::string &path) {
    if (path.empty()) return;
    std::vector<std::pair<Label, Label>> pairs;
    reachable.RelabelPairs(&pairs, /*avoid_collisions=*/true);
    WriteLabelPairs(path, pairs);
  };
  if (data.First() == nullptr) {  // reach_output.
    Reachable reachable(data.SharedSecond());
    reachable.Relabel(fst, /*relabel_input=*/false);
    save_pairs(reachable, save_relabel_opairs);
  } else {  // reach_input.
    Reachable reachable(data.SharedFirst());
    reachable.Relabel(fst, /*relabel_input=*/true);
    save_pairs(reachable, save_relabel_ipairs);
  }
}
680 
681 // Label-lookahead relabeling class.
// Relabels FSTs to agree with the label-reachability data stored as a
// label-lookahead FST's add-on. It uses the add-on's first reachability data
// when present, and otherwise its second.
682 template <class Arc, class Data = LabelReachableData<typename Arc::Label>>
684  public:
685  using Label = typename Arc::Label;
687 
688  // Relabels matcher FST (initialization function object).
689  template <typename Impl>
690  explicit LabelLookAheadRelabeler(std::shared_ptr<Impl> *impl);
691 
692  // Relabels arbitrary FST. Class LFST should be a label-lookahead FST.
// The side to relabel is chosen by `relabel_input`, independently of which
// add-on component supplies the reachability data.
693  template <class LFST>
694  static void Relabel(MutableFst<Arc> *fst, const LFST &mfst,
695  bool relabel_input) {
696  const auto *data = mfst.GetAddOn();
697  Reachable reachable(data->First() ? data->SharedFirst()
698  : data->SharedSecond());
699  reachable.Relabel(fst, relabel_input);
700  }
701 
702  // Returns relabeling pairs (cf. relabel.h::Relabel()). Class LFST should be a
703  // label-lookahead FST. If avoid_collisions is true, extra pairs are added to
704  // ensure no collisions when relabeling automata that have labels unseen here.
705  template <class LFST>
706  static void RelabelPairs(const LFST &mfst,
707  std::vector<std::pair<Label, Label>> *pairs,
708  bool avoid_collisions = false) {
709  const auto *data = mfst.GetAddOn();
710  Reachable reachable(data->First() ? data->SharedFirst()
711  : data->SharedSecond());
712  reachable.RelabelPairs(pairs, avoid_collisions);
713  }
714 };
715 
// Relabels the FST held by *impl using its add-on reachability data, via
// RelabelForReachable. The --save_relabel_ipairs/--save_relabel_opairs flags
// name optional output files for the relabeling pairs.
// - A mutable FST is relabeled in place.
// - Otherwise a VectorFst copy is relabeled, and *impl is replaced by a new
//   Impl built from that copy, with the original type name and add-on data
//   reattached.
716 template <class Arc, class Data>
717 template <typename Impl>
719  std::shared_ptr<Impl> *impl) {
720  Fst<Arc> &fst = (*impl)->GetFst();
721  auto data = (*impl)->GetSharedAddOn();
722  const auto name = (*impl)->Type();
723  const bool is_mutable = fst.Properties(kMutable, false);
724  std::unique_ptr<MutableFst<Arc>> mfst;
725  if (is_mutable) {
726  // Borrow pointer from fst without increasing ref count; it will
727  // be released below. We do not want to call Copy() since that would
728  // do a deep copy when the Fst is modified.
// NOTE(review): between reset() and release(), mfst "owns" an FST that *impl
// still owns. If anything in between threw, that FST would be deleted twice.
// This relies on RelabelForReachable not throwing.
729  mfst.reset(down_cast<MutableFst<Arc> *>(&fst));
730  } else {
731  mfst = std::make_unique<VectorFst<Arc>>(fst);
732  }
733 
734  RelabelForReachable<Reachable>(mfst.get(), *data,
735  FST_FLAGS_save_relabel_ipairs,
736  FST_FLAGS_save_relabel_opairs);
737 
738  if (is_mutable) {
739  // Pointer was just borrowed, don't delete it.
740  mfst.release();
741  } else {
742  *impl = std::make_shared<Impl>(*mfst, name);
743  (*impl)->SetAddOn(data);
744  }
745 }
746 
747 // Generic lookahead matcher, templated on the FST definition (a wrapper around
748 // a pointer to specific one).
// It holds the MatcherBase returned by the FST's InitMatcher(), or a
// SortedMatcher if that returns null, and forwards the matcher interface to
// it. Lookahead calls are forwarded only when the base matcher's Flags()
// advertise kInputLookAheadMatcher or kOutputLookAheadMatcher. Otherwise an
// FSTERROR is reported and permissive defaults are returned (true, One(),
// false).
749 template <class F>
751  public:
752  using FST = F;
753  using Arc = typename FST::Arc;
754  using Label = typename Arc::Label;
755  using StateId = typename Arc::StateId;
756  using Weight = typename Arc::Weight;
758 
759  // This makes a copy of the FST.
// owned_fst_ is declared before base_, so it is initialized first and
// InitMatcher() runs on the copy.
760  LookAheadMatcher(const FST &fst, MatchType match_type)
761  : owned_fst_(fst.Copy()),
762  base_(owned_fst_->InitMatcher(match_type)),
763  lookahead_(false) {
764  if (!base_)
765  base_ =
766  std::make_unique<SortedMatcher<FST>>(owned_fst_.get(), match_type);
767  }
768 
769  // This doesn't copy the FST.
770  LookAheadMatcher(const FST *fst, MatchType match_type)
771  : base_(fst->InitMatcher(match_type)), lookahead_(false) {
772  if (!base_) base_ = std::make_unique<SortedMatcher<FST>>(fst, match_type);
773  }
774 
775  // This makes a copy of the FST.
// Only base_ is copied (via its Copy(safe)); owned_fst_ is left empty.
776  LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false)
777  : base_(matcher.base_->Copy(safe)), lookahead_(matcher.lookahead_) {}
778 
779  // Takes ownership of base.
781  : base_(base), lookahead_(false) {}
782 
783  LookAheadMatcher *Copy(bool safe = false) const {
784  return new LookAheadMatcher(*this, safe);
785  }
786 
// General matcher methods: pure forwarding to base_.
787  MatchType Type(bool test) const { return base_->Type(test); }
788 
789  void SetState(StateId s) { base_->SetState(s); }
790 
791  bool Find(Label label) { return base_->Find(label); }
792 
793  bool Done() const { return base_->Done(); }
794 
795  const Arc &Value() const { return base_->Value(); }
796 
797  void Next() { base_->Next(); }
798 
799  Weight Final(StateId s) const { return base_->Final(s); }
800 
801  ssize_t Priority(StateId s) { return base_->Priority(s); }
802 
803  const FST &GetFst() const { return down_cast<const FST &>(base_->GetFst()); }
804 
805  uint64_t Properties(uint64_t props) const { return base_->Properties(props); }
806 
807  uint32_t Flags() const { return base_->Flags(); }
808 
// Lookahead methods. LBase is presumably LookAheadMatcherBase<Arc> (its alias
// is not shown in this listing). If so, LookAheadWeight() and
// LookAheadPrefix() below resolve to that base class's non-virtual versions.
809  bool LookAheadLabel(Label label) const {
810  if (LookAheadCheck()) {
811  return down_cast<LBase *>(base_.get())->LookAheadLabel(label);
812  } else {
813  return true;
814  }
815  }
816 
817  bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
818  if (LookAheadCheck()) {
819  return down_cast<LBase *>(base_.get())->LookAheadFst(fst, s);
820  } else {
821  return true;
822  }
823  }
824 
826  if (LookAheadCheck()) {
827  return down_cast<LBase *>(base_.get())->LookAheadWeight();
828  } else {
829  return Weight::One();
830  }
831  }
832 
833  bool LookAheadPrefix(Arc *arc) const {
834  if (LookAheadCheck()) {
835  return down_cast<LBase *>(base_.get())->LookAheadPrefix(arc);
836  } else {
837  return false;
838  }
839  }
840 
841  void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) {
842  if (LookAheadCheck()) {
843  down_cast<LBase *>(base_.get())->InitLookAheadFst(fst, copy);
844  }
845  }
846 
847  private:
// Returns whether base_ advertises lookahead on either side. A positive
// answer is cached in lookahead_. A negative answer is recomputed, and an
// FSTERROR reported, on every call.
848  bool LookAheadCheck() const {
849  if (!lookahead_) {
850  lookahead_ =
851  base_->Flags() & (kInputLookAheadMatcher | kOutputLookAheadMatcher);
852  if (!lookahead_) {
853  FSTERROR() << "LookAheadMatcher: No look-ahead matcher defined";
854  }
855  }
856  return lookahead_;
857  }
858 
// owned_fst_ is set only by the copying FST constructor. It must stay
// declared before base_, which may refer to it.
859  std::unique_ptr<const FST> owned_fst_;
860  std::unique_ptr<MatcherBase<Arc>> base_;
861  mutable bool lookahead_;
862 
863  LookAheadMatcher &operator=(const LookAheadMatcher &) = delete;
864 };
865 
866 } // namespace fst
867 
868 #endif // FST_LOOKAHEAD_MATCHER_H_
typename Arc::Label Label
ssize_t Priority(StateId s) final
bool LookAheadFst(const Fst< Arc > &fst, StateId s)
ArcLookAheadMatcher(const FST *fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr)
virtual bool LookAheadLabel(Label) const =0
constexpr uint64_t kMutable
Definition: properties.h:48
bool LookAheadLabel(Label label) const final
bool LookAheadPrefix(Arc *arc) const
uint64_t Properties(uint64_t props) const
constexpr int kNoLabel
Definition: fst.h:201
constexpr uint8_t kArcNoCache
Definition: fst.h:452
virtual uint64_t Properties(uint64_t mask, bool test) const =0
void SetState(StateId s)
typename Arc::StateId StateId
bool Find(Label label)
bool LookAheadLabel(Label) const final
typename Arc::Label Label
TrivialLookAheadMatcher * Copy(bool safe=false) const override
const Arc & Value() const final
TrivialLookAheadMatcher(const TrivialLookAheadMatcher &lmatcher, bool safe=false)
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:60
typename Arc::StateId StateId
bool LookAheadPrefix(Arc *) const
typename Arc::Weight Weight
static void Relabel(MutableFst< Arc > *fst, const LFST &mfst, bool relabel_input)
bool LookAheadLabel(Label label) const
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:97
MatchType
Definition: fst.h:193
std::shared_ptr< MatcherData > GetSharedData() const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:63
const FST & GetFst() const override
constexpr uint64_t kError
Definition: properties.h:51
const Arc & Value() const
ArcLookAheadMatcher * Copy(bool safe=false) const override
constexpr uint32_t kLookAheadFlags
MatchType Type(bool test) const override
typename Arc::Weight Weight
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:103
To down_cast(From *f)
Definition: compat.h:50
TrivialLookAheadMatcher(const FST &fst, MatchType match_type)
LabelLookAheadMatcher(const FST *fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr, std::unique_ptr< Accumulator > accumulator=nullptr)
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false) override
MatchType Type(bool test) const
void SetState(StateId s) final
virtual bool LookAheadFst(const Fst< Arc > &, StateId)=0
void SetState(StateId s) final
constexpr int kNoStateId
Definition: fst.h:202
bool Find(Label label) final
Label Relabel(Label label)
const Arc & Value() const final
const Arc & Value() const
Definition: fst.h:537
void RelabelPairs(std::vector< std::pair< Label, Label >> *pairs, bool avoid_collisions=false)
ssize_t Priority(StateId s) final
constexpr uint32_t kOutputLookAheadMatcher
#define FSTERROR()
Definition: util.h:53
bool Find(Label label) final
LabelLookAheadRelabeler(std::shared_ptr< Impl > *impl)
void SetFlags(uint8_t flags, uint8_t mask)
Definition: fst.h:571
const FST & GetFst() const override
MatchType Type(bool test) const override
LabelLookAheadMatcher(const LabelLookAheadMatcher &lmatcher, bool safe=false)
constexpr uint32_t kLookAheadEpsilons
constexpr uint32_t kLookAheadNonEpsilonPrefix
constexpr uint32_t kLookAheadKeepRelabelData
LookAheadMatcher(const FST *fst, MatchType match_type)
static void RelabelPairs(const LFST &mfst, std::vector< std::pair< Label, Label >> *pairs, bool avoid_collisions=false)
const FST & GetFst() const
constexpr uint32_t kLookAheadNonEpsilons
const MatcherData * GetData() const
Weight Final(StateId s) const final
void Seek(size_t a)
Definition: fst.h:557
MatchType Type(bool test) const override
DECLARE_string(save_relabel_ipairs)
typename Arc::Label Label
const FST & GetFst() const override
constexpr uint32_t kInputLookAheadMatcher
ssize_t Priority(StateId s)
typename Reachable::Data MatcherData
LookAheadMatcher * Copy(bool safe=false) const
constexpr uint32_t kLookAheadPrefix
void SetState(StateId s) final
bool LookAheadFst(const Fst< Arc > &fst, StateId s) final
bool LookAheadPrefix(Arc *arc) const
virtual const Fst< Arc > & GetFst() const =0
void SetLookAheadWeight(Weight weight)
ArcLookAheadMatcher(const ArcLookAheadMatcher &lmatcher, bool safe=false)
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false) override
uint64_t Properties(uint64_t inprops) const override
LookAheadMatcher(const LookAheadMatcher &matcher, bool safe=false)
LabelLookAheadMatcher * Copy(bool safe=false) const override
uint32_t Flags() const override
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false) override
TrivialLookAheadMatcher(const FST *fst, MatchType match_type)
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false)
uint32_t Flags() const override
const MatcherData * GetData() const
ArcLookAheadMatcher(const FST &fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr)
constexpr uint32_t kLookAheadWeight
virtual MatcherBase * Copy(bool safe=false) const =0
typename Arc::StateId StateId
typename Arc::StateId StateId
LookAheadMatcher(const FST &fst, MatchType match_type)
uint64_t Properties(uint64_t props) const override
virtual void InitLookAheadFst(const Fst< Arc > &, bool copy=false)=0
Weight Final(StateId s) const
typename Arc::Weight Weight
Definition: matcher.h:138
typename Arc::Label Label
Definition: matcher.h:136
virtual MatchType Type(bool) const =0
void RelabelForReachable(FST *fst, const Data &data, const std::string &save_relabel_ipairs, const std::string &save_relabel_opairs)
Weight Final(StateId s) const final
Weight LookAheadWeight() const
typename Arc::Weight Weight
Weight Final(StateId s) const final
bool WriteLabelPairs(const std::string &source, const std::vector< std::pair< Label, Label >> &pairs)
Definition: util.h:406
ssize_t Priority(StateId s) final
std::shared_ptr< MatcherData > GetSharedData() const
uint32_t Flags() const override
bool LookAheadFst(const Fst< Arc > &, StateId) final
const Arc & Value() const final
uint64_t Properties(uint64_t props) const override
uint32_t Flags() const
LookAheadMatcher(MatcherBase< Arc > *base)
typename FST::Arc Arc
void InitLookAheadFst(const LFST &fst, bool copy=false)
LabelLookAheadMatcher(const FST &fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr, std::unique_ptr< Accumulator > accumulator=nullptr)
bool Find(Label label) final
bool LookAheadLabel(Label label) const final
bool LookAheadFst(const Fst< Arc > &, StateId) final
typename Arc::StateId StateId
Definition: matcher.h:137