FST  openfst-1.8.4
OpenFst Library
lookahead-matcher.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Classes to add lookahead to FST matchers, useful for improving composition
19 // efficiency with certain inputs.
20 
21 #ifndef FST_LOOKAHEAD_MATCHER_H_
22 #define FST_LOOKAHEAD_MATCHER_H_
23 
24 #include <sys/types.h>
25 
26 #include <algorithm>
27 #include <cstdint>
28 #include <memory>
29 #include <utility>
30 #include <vector>
31 
32 #include <fst/log.h>
33 #include <fst/accumulator.h>
34 #include <fst/add-on.h>
35 #include <fst/fst.h>
36 #include <fst/label-reachable.h>
37 #include <fst/matcher.h>
38 #include <fst/mutable-fst.h>
39 #include <fst/properties.h>
40 #include <fst/util.h>
41 #include <fst/vector-fst.h>
42 #include <string_view>
43 
44 DECLARE_string(save_relabel_ipairs);
45 DECLARE_string(save_relabel_opairs);
46 
47 namespace fst {
48 
49 // Lookahead matchers extend the matcher interface with the following
50 // additional methods:
51 //
52 // template <class FST>
53 // class LookAheadMatcher {
54 // public:
55 // using Arc = typename FST::Arc;
56 // using Label = typename Arc::Label;
57 // using StateId = typename Arc::StateId;
58 // using Weight = typename Arc::Weight;
59 //
60 // // Required constructors.
61 // // This makes a copy of the FST.
62 // LookAheadMatcher(const FST &fst, MatchType match_type);
63 // // This doesn't copy the FST.
64 // LookAheadMatcher(const FST *fst, MatchType match_type);
65 // // This makes a copy of the FST.
66 // // See Copy() below.
67 // LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false);
68 //
69 // // If safe = true, the copy is thread-safe (except the lookahead FST is
70 // // preserved). See Fst<>::Copy() for further doc.
71 // LookAheadMatcher *Copy(bool safe = false) const override;
72 
73 // // Below are methods for looking ahead for a match to a label and more
74 // // generally, to a rational set. Each returns false if there is definitely
75 // // not a match and returns true if there possibly is a match.
76 //
77 // // Optionally pre-specifies the lookahead FST that will be passed to
78 // // LookAheadFst() for possible precomputation. If copy is true, then the FST
79 // // argument is a copy of the FST used in the previous call to this method
80 // // (to avoid unnecessary updates).
81 // void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override;
82 //
83 // // Are there paths from a state in the lookahead FST that can be read from
84 // // the current matcher state?
85 // bool LookAheadFst(const Fst<Arc> &fst, StateId s) override;
86 //
87 // // Can the label be read from the current matcher state after possibly
88 // // following epsilon transitions?
89 // bool LookAheadLabel(Label label) const override;
90 //
91 // // The following methods allow looking ahead for an arbitrary rational set
92 // // of strings, specified by an FST and a state from which to begin the
93 // // matching. If the lookahead FST is a transducer, this looks on the side
94 // // different from the matcher's match_type (cf. composition).
95 // Is there a single non-epsilon arc found in the lookahead FST that
96 // // begins the path (after possibly following any epsilons) in the last call
97 // // to LookAheadFst? If so, return true and copy it to the arc argument;
98 // // otherwise, return false. Non-trivial implementations are useful for
99 // // label-pushing in composition.
100 // bool LookAheadPrefix(Arc *arc) override;
101 //
102 // // Gives an estimate of the combined weight of the paths in the lookahead
103 // // and matcher FSTs for the last call to LookAheadFst. Non-trivial
104 // // implementations are useful for weight-pushing in composition.
105 // Weight LookAheadWeight() const override;
106 // };
107 
// Look-ahead flags. Each flag below occupies its own bit, so flags can be
// OR-ed together (see the Flags() methods and `flags` template arguments of
// the matchers in this file).

// Matcher is a lookahead matcher when match_type is MATCH_INPUT.
inline constexpr uint32_t kInputLookAheadMatcher = 0x00000010;

// Matcher is a lookahead matcher when match_type is MATCH_OUTPUT.
inline constexpr uint32_t kOutputLookAheadMatcher = 0x00000020;

// Is a non-trivial implementation of LookAheadWeight() method defined and
// if so, should it be used?
inline constexpr uint32_t kLookAheadWeight = 0x00000040;

// Is a non-trivial implementation of LookAheadPrefix() method defined and
// if so, should it be used?
inline constexpr uint32_t kLookAheadPrefix = 0x00000080;

// Look-ahead of matcher FST non-epsilon arcs?
inline constexpr uint32_t kLookAheadNonEpsilons = 0x00000100;

// Look-ahead of matcher FST epsilon arcs?
inline constexpr uint32_t kLookAheadEpsilons = 0x00000200;

// Ignore epsilon paths for the lookahead prefix? This gives correct results in
// composition only with an appropriate composition filter since it depends on
// the filter blocking the ignored paths.
inline constexpr uint32_t kLookAheadNonEpsilonPrefix = 0x00000400;

// For LabelLookAheadMatcher, save relabeling data to file?
inline constexpr uint32_t kLookAheadKeepRelabelData = 0x00000800;

// Flags used for lookahead matchers: the mask of bits 4 through 11, i.e. the
// union of the eight flags above.
inline constexpr uint32_t kLookAheadFlags = 0x00000ff0;
139 
140 // LookAhead Matcher interface, templated on the Arc definition; used
141 // for lookahead matcher specializations that are returned by the
142 // InitMatcher() Fst method.
143 template <class Arc>
144 class LookAheadMatcherBase : public MatcherBase<Arc> {
145  public:
146  using Label = typename Arc::Label;
147  using StateId = typename Arc::StateId;
148  using Weight = typename Arc::Weight;
149 
150  virtual void InitLookAheadFst(const Fst<Arc> &, bool copy = false) = 0;
151  virtual bool LookAheadFst(const Fst<Arc> &, StateId) = 0;
152  virtual bool LookAheadLabel(Label) const = 0;
153 
154  // Suggested concrete implementation of lookahead methods.
155 
156  bool LookAheadPrefix(Arc *arc) const {
157  if (prefix_arc_.nextstate != kNoStateId) {
158  *arc = prefix_arc_;
159  return true;
160  } else {
161  return false;
162  }
163  }
164 
165  Weight LookAheadWeight() const { return weight_; }
166 
167  protected:
168  // Concrete implementations for lookahead helper methods.
169 
170  void ClearLookAheadWeight() { weight_ = Weight::One(); }
171 
172  void SetLookAheadWeight(Weight weight) { weight_ = std::move(weight); }
173 
174  void ClearLookAheadPrefix() { prefix_arc_.nextstate = kNoStateId; }
175 
176  void SetLookAheadPrefix(Arc arc) { prefix_arc_ = std::move(arc); }
177 
178  private:
179  Arc prefix_arc_;
180  Weight weight_;
181 };
182 
183 // Doesn't actually lookahead, just declares that the future looks good.
184 template <class M>
186  : public LookAheadMatcherBase<typename M::FST::Arc> {
187  public:
188  using FST = typename M::FST;
189  using Arc = typename FST::Arc;
190  using Label = typename Arc::Label;
191  using StateId = typename Arc::StateId;
192  using Weight = typename Arc::Weight;
193 
194  // This makes a copy of the FST.
196  : matcher_(fst, match_type) {}
197 
198  // This doesn't copy the FST.
200  : matcher_(fst, match_type) {}
201 
202  // This makes a copy of the FST.
204  bool safe = false)
205  : matcher_(lmatcher.matcher_, safe) {}
206 
207  TrivialLookAheadMatcher *Copy(bool safe = false) const override {
208  return new TrivialLookAheadMatcher(*this, safe);
209  }
210 
211  MatchType Type(bool test) const override { return matcher_.Type(test); }
212 
213  void SetState(StateId s) final { return matcher_.SetState(s); }
214 
215  bool Find(Label label) final { return matcher_.Find(label); }
216 
217  bool Done() const final { return matcher_.Done(); }
218 
219  const Arc &Value() const final { return matcher_.Value(); }
220 
221  void Next() final { matcher_.Next(); }
222 
223  Weight Final(StateId s) const final { return matcher_.Final(s); }
224 
225  ssize_t Priority(StateId s) final { return matcher_.Priority(s); }
226 
227  const FST &GetFst() const override { return matcher_.GetFst(); }
228 
229  uint64_t Properties(uint64_t props) const override {
230  return matcher_.Properties(props);
231  }
232 
233  uint32_t Flags() const override {
234  return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher;
235  }
236 
237  // Lookahead methods (all trivial).
238 
239  void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override {}
240 
241  bool LookAheadFst(const Fst<Arc> &, StateId) final { return true; }
242 
243  bool LookAheadLabel(Label) const final { return true; }
244 
245  bool LookAheadPrefix(Arc *) const { return false; }
246 
247  Weight LookAheadWeight() const { return Weight::One(); }
248 
249  private:
250  M matcher_;
251 };
252 
253 // Look-ahead of one transition. Template argument flags accepts flags to
254 // control behavior.
255 template <class M, uint32_t flags = kLookAheadNonEpsilons | kLookAheadEpsilons |
256  kLookAheadWeight | kLookAheadPrefix>
257 class ArcLookAheadMatcher : public LookAheadMatcherBase<typename M::FST::Arc> {
258  public:
259  using FST = typename M::FST;
260  using Arc = typename FST::Arc;
261  using Label = typename Arc::Label;
262  using StateId = typename Arc::StateId;
263  using Weight = typename Arc::Weight;
265 
272 
273  static constexpr uint32_t kFlags = flags;
274 
275  // This makes a copy of the FST.
276  ArcLookAheadMatcher(const FST &fst, MatchType match_type,
277  std::shared_ptr<MatcherData> data = nullptr)
278  : matcher_(fst, match_type),
279  fst_(matcher_.GetFst()),
280  lfst_(nullptr),
281  state_(kNoStateId) {}
282 
283  // This doesn't copy the FST.
284  ArcLookAheadMatcher(const FST *fst, MatchType match_type,
285  std::shared_ptr<MatcherData> data = nullptr)
286  : matcher_(fst, match_type),
287  fst_(matcher_.GetFst()),
288  lfst_(nullptr),
289  state_(kNoStateId) {}
290 
291  // This makes a copy of the FST.
292  ArcLookAheadMatcher(const ArcLookAheadMatcher &lmatcher, bool safe = false)
293  : matcher_(lmatcher.matcher_, safe),
294  fst_(matcher_.GetFst()),
295  lfst_(lmatcher.lfst_),
296  state_(kNoStateId) {}
297 
298  // General matcher methods.
299  ArcLookAheadMatcher *Copy(bool safe = false) const override {
300  return new ArcLookAheadMatcher(*this, safe);
301  }
302 
303  MatchType Type(bool test) const override { return matcher_.Type(test); }
304 
305  void SetState(StateId s) final {
306  state_ = s;
307  matcher_.SetState(s);
308  }
309 
310  bool Find(Label label) final { return matcher_.Find(label); }
311 
312  bool Done() const final { return matcher_.Done(); }
313 
314  const Arc &Value() const final { return matcher_.Value(); }
315 
316  void Next() final { matcher_.Next(); }
317 
318  Weight Final(StateId s) const final { return matcher_.Final(s); }
319 
320  ssize_t Priority(StateId s) final { return matcher_.Priority(s); }
321 
322  const FST &GetFst() const override { return fst_; }
323 
324  uint64_t Properties(uint64_t props) const override {
325  return matcher_.Properties(props);
326  }
327 
328  uint32_t Flags() const override {
329  return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher |
330  kFlags;
331  }
332 
333  const MatcherData *GetData() const { return nullptr; }
334 
335  std::shared_ptr<MatcherData> GetSharedData() const { return nullptr; }
336 
337  // Look-ahead methods.
338 
339  void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override {
340  lfst_ = &fst;
341  }
342 
343  // Checks if there is a matching (possibly super-final) transition
344  // at (state_, s).
345  bool LookAheadFst(const Fst<Arc> &, StateId) final;
346 
347  bool LookAheadLabel(Label label) const final { return matcher_.Find(label); }
348 
349  private:
350  mutable M matcher_;
351  const FST &fst_; // Matcher FST.
352  const Fst<Arc> *lfst_; // Look-ahead FST.
353  StateId state_; // Matcher state.
354 };
355 
356 template <class M, uint32_t flags>
358  StateId s) {
359  if (&fst != lfst_) InitLookAheadFst(fst);
360  bool result = false;
361  ssize_t nprefix = 0;
362  if (kFlags & kLookAheadWeight) ClearLookAheadWeight();
363  if (kFlags & kLookAheadPrefix) ClearLookAheadPrefix();
364  if (fst_.Final(state_) != Weight::Zero() &&
365  lfst_->Final(s) != Weight::Zero()) {
366  if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
367  ++nprefix;
368  if (kFlags & kLookAheadWeight) {
370  Plus(LookAheadWeight(), Times(fst_.Final(state_), lfst_->Final(s))));
371  }
372  result = true;
373  }
374  if (matcher_.Find(kNoLabel)) {
375  if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
376  ++nprefix;
377  if (kFlags & kLookAheadWeight) {
378  for (; !matcher_.Done(); matcher_.Next()) {
379  SetLookAheadWeight(Plus(LookAheadWeight(), matcher_.Value().weight));
380  }
381  }
382  result = true;
383  }
384  for (ArcIterator<Fst<Arc>> aiter(*lfst_, s); !aiter.Done(); aiter.Next()) {
385  const auto &arc = aiter.Value();
386  Label label = kNoLabel;
387  switch (matcher_.Type(false)) {
388  case MATCH_INPUT:
389  label = arc.olabel;
390  break;
391  case MATCH_OUTPUT:
392  label = arc.ilabel;
393  break;
394  default:
395  FSTERROR() << "ArcLookAheadMatcher::LookAheadFst: Bad match type";
396  return true;
397  }
398  if (label == 0) {
399  if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
400  if (!(kFlags & kLookAheadNonEpsilonPrefix)) ++nprefix;
401  if (kFlags & kLookAheadWeight) {
402  SetLookAheadWeight(Plus(LookAheadWeight(), arc.weight));
403  }
404  result = true;
405  } else if (matcher_.Find(label)) {
406  if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
407  for (; !matcher_.Done(); matcher_.Next()) {
408  ++nprefix;
409  if (kFlags & kLookAheadWeight) {
411  Times(arc.weight, matcher_.Value().weight)));
412  }
413  if ((kFlags & kLookAheadPrefix) && nprefix == 1)
414  SetLookAheadPrefix(arc);
415  }
416  result = true;
417  }
418  }
419  if (kFlags & kLookAheadPrefix) {
420  if (nprefix == 1) {
421  ClearLookAheadWeight(); // Avoids double counting.
422  } else {
424  }
425  }
426  return result;
427 }
428 
429 // Template argument flags accepts flags to control behavior. It must include
430 // precisely one of kInputLookAheadMatcher or kOutputLookAheadMatcher.
431 template <class M,
432  uint32_t flags = kLookAheadEpsilons | kLookAheadWeight |
433  kLookAheadPrefix | kLookAheadNonEpsilonPrefix |
438  : public LookAheadMatcherBase<typename M::FST::Arc> {
439  public:
440  using Matcher = M;
441  using Accumulator = Accum;
442  using Reachable = R;
443 
444  using FST = typename M::FST;
445  using Arc = typename FST::Arc;
446  using Label = typename Arc::Label;
447  using StateId = typename Arc::StateId;
448  using Weight = typename Arc::Weight;
449  using MatcherData = typename Reachable::Data;
450 
457 
458  static_assert(!(flags & kInputLookAheadMatcher) !=
459  !(flags & kOutputLookAheadMatcher),
460  "Must include precisely one of kInputLookAheadMatcher and "
461  "kOutputLookAheadMatcher");
462  static constexpr uint32_t kFlags = flags;
463 
464  // This makes a copy of the FST.
466  std::shared_ptr<MatcherData> data = nullptr,
467  std::unique_ptr<Accumulator> accumulator = nullptr)
468  : matcher_(fst, match_type),
469  lfst_(nullptr),
470  state_(kNoStateId),
471  error_(false) {
472  Init(fst, match_type, data, std::move(accumulator));
473  }
474 
475  // This doesn't copy the FST.
477  std::shared_ptr<MatcherData> data = nullptr,
478  std::unique_ptr<Accumulator> accumulator = nullptr)
479  : matcher_(fst, match_type),
480  lfst_(nullptr),
481  state_(kNoStateId),
482  error_(false) {
483  Init(*fst, match_type, data, std::move(accumulator));
484  }
485 
486  // This makes a copy of the FST.
488  bool safe = false)
489  : matcher_(lmatcher.matcher_, safe),
490  lfst_(lmatcher.lfst_),
491  label_reachable_(lmatcher.label_reachable_
492  ? new Reachable(*lmatcher.label_reachable_, safe)
493  : nullptr),
494  state_(kNoStateId),
495  error_(lmatcher.error_) {}
496 
497  LabelLookAheadMatcher *Copy(bool safe = false) const override {
498  return new LabelLookAheadMatcher(*this, safe);
499  }
500 
501  MatchType Type(bool test) const override { return matcher_.Type(test); }
502 
503  void SetState(StateId s) final {
504  if (state_ == s) return;
505  state_ = s;
506  match_set_state_ = false;
507  reach_set_state_ = false;
508  }
509 
510  bool Find(Label label) final {
511  if (!match_set_state_) {
512  matcher_.SetState(state_);
513  match_set_state_ = true;
514  }
515  return matcher_.Find(label);
516  }
517 
518  bool Done() const final { return matcher_.Done(); }
519 
520  const Arc &Value() const final { return matcher_.Value(); }
521 
522  void Next() final { matcher_.Next(); }
523 
524  Weight Final(StateId s) const final { return matcher_.Final(s); }
525 
526  ssize_t Priority(StateId s) final { return matcher_.Priority(s); }
527 
528  const FST &GetFst() const override { return matcher_.GetFst(); }
529 
530  uint64_t Properties(uint64_t inprops) const override {
531  auto outprops = matcher_.Properties(inprops);
532  if (error_ || (label_reachable_ && label_reachable_->Error())) {
533  outprops |= kError;
534  }
535  return outprops;
536  }
537 
538  uint32_t Flags() const override {
539  if (label_reachable_ && label_reachable_->GetData()->ReachInput()) {
540  return matcher_.Flags() | kFlags | kInputLookAheadMatcher;
541  } else if (label_reachable_ && !label_reachable_->GetData()->ReachInput()) {
542  return matcher_.Flags() | kFlags | kOutputLookAheadMatcher;
543  } else {
544  return matcher_.Flags();
545  }
546  }
547 
548  const MatcherData *GetData() const {
549  return label_reachable_ ? label_reachable_->GetData() : nullptr;
550  }
551 
552  std::shared_ptr<MatcherData> GetSharedData() const {
553  return label_reachable_ ? label_reachable_->GetSharedData() : nullptr;
554  }
555  // Checks if there is a matching (possibly super-final) transition at
556  // (state_, s).
557  template <class LFST>
558  bool LookAheadFst(const LFST &fst, StateId s);
559 
560  // Required to make class concrete.
561  bool LookAheadFst(const Fst<Arc> &fst, StateId s) final {
562  return LookAheadFst<Fst<Arc>>(fst, s);
563  }
564 
565  void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override {
566  lfst_ = &fst;
567  if (label_reachable_) {
568  const bool reach_input = Type(false) == MATCH_OUTPUT;
569  label_reachable_->ReachInit(fst, reach_input, copy);
570  }
571  }
572 
573  template <class LFST>
574  void InitLookAheadFst(const LFST &fst, bool copy = false) {
575  lfst_ = &fst;
576  if (label_reachable_) {
577  const bool reach_input = Type(false) == MATCH_OUTPUT;
578  label_reachable_->ReachInit(fst, reach_input, copy);
579  }
580  }
581 
582  bool LookAheadLabel(Label label) const final {
583  if (label == 0) return true;
584  if (label_reachable_) {
585  if (!reach_set_state_) {
586  label_reachable_->SetState(state_);
587  reach_set_state_ = true;
588  }
589  return label_reachable_->Reach(label);
590  } else {
591  return true;
592  }
593  }
594 
595  private:
596  void Init(const FST &fst, MatchType match_type,
597  std::shared_ptr<MatcherData> data,
598  std::unique_ptr<Accumulator> accumulator) {
599  const bool reach_input = match_type == MATCH_INPUT;
600  if (data) {
601  if (reach_input == data->ReachInput()) {
602  label_reachable_ =
603  std::make_unique<Reachable>(data, std::move(accumulator));
604  }
605  } else if ((reach_input && (kFlags & kInputLookAheadMatcher)) ||
606  (!reach_input && (kFlags & kOutputLookAheadMatcher))) {
607  label_reachable_ =
608  std::make_unique<Reachable>(fst, reach_input, std::move(accumulator),
609  kFlags & kLookAheadKeepRelabelData);
610  }
611  }
612 
613  mutable M matcher_;
614  const Fst<Arc> *lfst_; // Look-ahead FST.
615  std::unique_ptr<Reachable> label_reachable_; // Label reachability info.
616  StateId state_; // Matcher state.
617  bool match_set_state_; // matcher_.SetState called?
618  mutable bool reach_set_state_; // reachable_.SetState called?
619  bool error_; // Error encountered?
620 };
621 
622 template <class M, uint32_t flags, class Accumulator, class Reachable>
623 template <class LFST>
624 inline bool LabelLookAheadMatcher<M, flags, Accumulator,
626  StateId s) {
627  if (&fst != lfst_) InitLookAheadFst(fst);
630  if (!label_reachable_) return true;
631  label_reachable_->SetState(state_, s);
632  reach_set_state_ = true;
633  bool compute_weight = kFlags & kLookAheadWeight;
634  constexpr bool kComputePrefix = kFlags & kLookAheadPrefix;
635  ArcIterator<LFST> aiter(fst, s);
636  aiter.SetFlags(kArcNoCache, kArcNoCache); // Makes caching optional.
637  const bool reach_arc = label_reachable_->Reach(
638  &aiter, 0, internal::NumArcs(*lfst_, s), compute_weight);
639  const auto lfinal = internal::Final(*lfst_, s);
640  const bool reach_final =
641  lfinal != Weight::Zero() && label_reachable_->ReachFinal();
642  if (reach_arc) {
643  const auto begin = label_reachable_->ReachBegin();
644  const auto end = label_reachable_->ReachEnd();
645  if (kComputePrefix && end - begin == 1 && !reach_final) {
646  aiter.Seek(begin);
647  SetLookAheadPrefix(aiter.Value());
648  compute_weight = false;
649  } else if (compute_weight) {
650  SetLookAheadWeight(label_reachable_->ReachWeight());
651  }
652  }
653  if (reach_final && compute_weight) {
654  SetLookAheadWeight(reach_arc ? Plus(LookAheadWeight(), lfinal) : lfinal);
655  }
656  return reach_arc || reach_final;
657 }
658 
// Relabels the fst with Reachable::Reachable. Relabels input
// if data.First() is non-null, otherwise relabels output.
// Optionally saves the input/output label pairs to a file
// if save_relabel_ipairs/opairs is non-empty.
//
// Only the file name for the side that was relabeled is consulted: the
// ipairs name when input is relabeled, the opairs name otherwise.
template <class Reachable, class FST, class Data>
void RelabelForReachable(FST *fst, const Data &data,
                         std::string_view save_relabel_ipairs,
                         std::string_view save_relabel_opairs) {
  using Label = typename FST::Arc::Label;
  const bool relabel_input = data.First() != nullptr;  // reach_input.
  Reachable reachable(relabel_input ? data.SharedFirst()
                                    : data.SharedSecond());
  reachable.Relabel(fst, relabel_input);
  const std::string_view pairs_file =
      relabel_input ? save_relabel_ipairs : save_relabel_opairs;
  if (pairs_file.empty()) return;
  std::vector<std::pair<Label, Label>> pairs;
  reachable.RelabelPairs(&pairs, /*avoid_collisions=*/true);
  std::sort(pairs.begin(), pairs.end());  // Sort for deterministic output.
  WriteLabelPairs(pairs_file, pairs);
}
688 
689 // Label-lookahead relabeling class.
690 template <class Arc, class Data = LabelReachableData<typename Arc::Label>>
692  public:
693  using Label = typename Arc::Label;
695 
696  // Relabels matcher FST (initialization function object).
697  template <typename Impl>
698  explicit LabelLookAheadRelabeler(std::shared_ptr<Impl> *impl);
699 
700  // Relabels arbitrary FST. Class LFST should be a label-lookahead FST.
701  template <class LFST>
702  static void Relabel(MutableFst<Arc> *fst, const LFST &mfst,
703  bool relabel_input) {
704  const auto *data = mfst.GetAddOn();
705  Reachable reachable(data->First() ? data->SharedFirst()
706  : data->SharedSecond());
707  reachable.Relabel(fst, relabel_input);
708  }
709 
710  // Returns relabeling pairs (cf. relabel.h::Relabel()). Class LFST should be a
711  // label-lookahead FST. If avoid_collisions is true, extra pairs are added to
712  // ensure no collisions when relabeling automata that have labels unseen here.
713  template <class LFST>
714  static void RelabelPairs(const LFST &mfst,
715  std::vector<std::pair<Label, Label>> *pairs,
716  bool avoid_collisions = false) {
717  const auto *data = mfst.GetAddOn();
718  Reachable reachable(data->First() ? data->SharedFirst()
719  : data->SharedSecond());
720  reachable.RelabelPairs(pairs, avoid_collisions);
721  }
722 };
723 
724 template <class Arc, class Data>
725 template <typename Impl>
727  std::shared_ptr<Impl> *impl) {
728  Fst<Arc> &fst = (*impl)->GetFst();
729  auto data = (*impl)->GetSharedAddOn();
730  const auto name = (*impl)->Type();
731  const bool is_mutable = fst.Properties(kMutable, false);
732  std::unique_ptr<MutableFst<Arc>> mfst;
733  if (is_mutable) {
734  // Borrow pointer from fst without increasing ref count; it will
735  // be released below. We do not want to call Copy() since that would
736  // do a deep copy when the Fst is modified.
737  mfst.reset(down_cast<MutableFst<Arc> *>(&fst));
738  } else {
739  mfst = std::make_unique<VectorFst<Arc>>(fst);
740  }
741 
742  RelabelForReachable<Reachable>(mfst.get(), *data,
743  FST_FLAGS_save_relabel_ipairs,
744  FST_FLAGS_save_relabel_opairs);
745 
746  if (is_mutable) {
747  // Pointer was just borrowed, don't delete it.
748  mfst.release();
749  } else {
750  *impl = std::make_shared<Impl>(*mfst, name);
751  (*impl)->SetAddOn(data);
752  }
753 }
754 
755 // Generic lookahead matcher, templated on the FST definition (a wrapper around
756 // a pointer to specific one).
757 template <class F>
759  public:
760  using FST = F;
761  using Arc = typename FST::Arc;
762  using Label = typename Arc::Label;
763  using StateId = typename Arc::StateId;
764  using Weight = typename Arc::Weight;
766 
767  // This makes a copy of the FST.
768  LookAheadMatcher(const FST &fst, MatchType match_type)
769  : owned_fst_(fst.Copy()),
770  base_(owned_fst_->InitMatcher(match_type)),
771  lookahead_(false) {
772  if (!base_)
773  base_ =
774  std::make_unique<SortedMatcher<FST>>(owned_fst_.get(), match_type);
775  }
776 
777  // This doesn't copy the FST.
778  LookAheadMatcher(const FST *fst, MatchType match_type)
779  : base_(fst->InitMatcher(match_type)), lookahead_(false) {
780  if (!base_) base_ = std::make_unique<SortedMatcher<FST>>(fst, match_type);
781  }
782 
783  // This makes a copy of the FST.
784  LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false)
785  : base_(matcher.base_->Copy(safe)), lookahead_(matcher.lookahead_) {}
786 
787  // Takes ownership of base.
789  : base_(base), lookahead_(false) {}
790 
791  LookAheadMatcher *Copy(bool safe = false) const {
792  return new LookAheadMatcher(*this, safe);
793  }
794 
795  MatchType Type(bool test) const { return base_->Type(test); }
796 
797  void SetState(StateId s) { base_->SetState(s); }
798 
799  bool Find(Label label) { return base_->Find(label); }
800 
801  bool Done() const { return base_->Done(); }
802 
803  const Arc &Value() const { return base_->Value(); }
804 
805  void Next() { base_->Next(); }
806 
807  Weight Final(StateId s) const { return base_->Final(s); }
808 
809  ssize_t Priority(StateId s) { return base_->Priority(s); }
810 
811  const FST &GetFst() const { return down_cast<const FST &>(base_->GetFst()); }
812 
813  uint64_t Properties(uint64_t props) const { return base_->Properties(props); }
814 
815  uint32_t Flags() const { return base_->Flags(); }
816 
817  bool LookAheadLabel(Label label) const {
818  if (LookAheadCheck()) {
819  return down_cast<LBase *>(base_.get())->LookAheadLabel(label);
820  } else {
821  return true;
822  }
823  }
824 
825  bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
826  if (LookAheadCheck()) {
827  return down_cast<LBase *>(base_.get())->LookAheadFst(fst, s);
828  } else {
829  return true;
830  }
831  }
832 
834  if (LookAheadCheck()) {
835  return down_cast<LBase *>(base_.get())->LookAheadWeight();
836  } else {
837  return Weight::One();
838  }
839  }
840 
841  bool LookAheadPrefix(Arc *arc) const {
842  if (LookAheadCheck()) {
843  return down_cast<LBase *>(base_.get())->LookAheadPrefix(arc);
844  } else {
845  return false;
846  }
847  }
848 
849  void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) {
850  if (LookAheadCheck()) {
851  down_cast<LBase *>(base_.get())->InitLookAheadFst(fst, copy);
852  }
853  }
854 
855  private:
856  bool LookAheadCheck() const {
857  if (!lookahead_) {
858  lookahead_ =
859  base_->Flags() & (kInputLookAheadMatcher | kOutputLookAheadMatcher);
860  if (!lookahead_) {
861  FSTERROR() << "LookAheadMatcher: No look-ahead matcher defined";
862  }
863  }
864  return lookahead_;
865  }
866 
867  std::unique_ptr<const FST> owned_fst_;
868  std::unique_ptr<MatcherBase<Arc>> base_;
869  mutable bool lookahead_;
870 
871  LookAheadMatcher &operator=(const LookAheadMatcher &) = delete;
872 };
873 
874 } // namespace fst
875 
876 #endif // FST_LOOKAHEAD_MATCHER_H_
typename Arc::Label Label
ssize_t Priority(StateId s) final
bool LookAheadFst(const Fst< Arc > &fst, StateId s)
ArcLookAheadMatcher(const FST *fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr)
virtual bool LookAheadLabel(Label) const =0
constexpr uint64_t kMutable
Definition: properties.h:49
bool LookAheadLabel(Label label) const final
bool LookAheadPrefix(Arc *arc) const
uint64_t Properties(uint64_t props) const
constexpr int kNoLabel
Definition: fst.h:194
constexpr uint8_t kArcNoCache
Definition: fst.h:451
void RelabelForReachable(FST *fst, const Data &data, std::string_view save_relabel_ipairs, std::string_view save_relabel_opairs)
virtual uint64_t Properties(uint64_t mask, bool test) const =0
void SetState(StateId s)
typename Arc::StateId StateId
bool Find(Label label)
bool LookAheadLabel(Label) const final
typename Arc::Label Label
TrivialLookAheadMatcher * Copy(bool safe=false) const override
const Arc & Value() const final
TrivialLookAheadMatcher(const TrivialLookAheadMatcher &lmatcher, bool safe=false)
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:61
typename Arc::StateId StateId
bool LookAheadPrefix(Arc *) const
typename Arc::Weight Weight
static void Relabel(MutableFst< Arc > *fst, const LFST &mfst, bool relabel_input)
bool LookAheadLabel(Label label) const
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:107
MatchType
Definition: fst.h:186
std::shared_ptr< MatcherData > GetSharedData() const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:64
const FST & GetFst() const override
bool WriteLabelPairs(std::string_view source, const std::vector< std::pair< Label, Label >> &pairs)
Definition: util.h:429
constexpr uint64_t kError
Definition: properties.h:52
const Arc & Value() const
ArcLookAheadMatcher * Copy(bool safe=false) const override
constexpr uint32_t kLookAheadFlags
MatchType Type(bool test) const override
typename Arc::Weight Weight
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:113
To down_cast(From *f)
Definition: compat.h:52
TrivialLookAheadMatcher(const FST &fst, MatchType match_type)
LabelLookAheadMatcher(const FST *fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr, std::unique_ptr< Accumulator > accumulator=nullptr)
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false) override
MatchType Type(bool test) const
void SetState(StateId s) final
virtual bool LookAheadFst(const Fst< Arc > &, StateId)=0
void SetState(StateId s) final
constexpr int kNoStateId
Definition: fst.h:195
bool Find(Label label) final
Label Relabel(Label label)
const Arc & Value() const final
const Arc & Value() const
Definition: fst.h:535
void RelabelPairs(std::vector< std::pair< Label, Label >> *pairs, bool avoid_collisions=false)
ssize_t Priority(StateId s) final
constexpr uint32_t kOutputLookAheadMatcher
#define FSTERROR()
Definition: util.h:57
bool Find(Label label) final
LabelLookAheadRelabeler(std::shared_ptr< Impl > *impl)
void SetFlags(uint8_t flags, uint8_t mask)
Definition: fst.h:569
const FST & GetFst() const override
MatchType Type(bool test) const override
LabelLookAheadMatcher(const LabelLookAheadMatcher &lmatcher, bool safe=false)
constexpr uint32_t kLookAheadEpsilons
constexpr uint32_t kLookAheadNonEpsilonPrefix
constexpr uint32_t kLookAheadKeepRelabelData
LookAheadMatcher(const FST *fst, MatchType match_type)
static void RelabelPairs(const LFST &mfst, std::vector< std::pair< Label, Label >> *pairs, bool avoid_collisions=false)
const FST & GetFst() const
constexpr uint32_t kLookAheadNonEpsilons
const MatcherData * GetData() const
Weight Final(StateId s) const final
void Seek(size_t a)
Definition: fst.h:555
MatchType Type(bool test) const override
DECLARE_string(save_relabel_ipairs)
typename Arc::Label Label
const FST & GetFst() const override
constexpr uint32_t kInputLookAheadMatcher
ssize_t Priority(StateId s)
typename Reachable::Data MatcherData
LookAheadMatcher * Copy(bool safe=false) const
constexpr uint32_t kLookAheadPrefix
void SetState(StateId s) final
bool LookAheadFst(const Fst< Arc > &fst, StateId s) final
bool LookAheadPrefix(Arc *arc) const
virtual const Fst< Arc > & GetFst() const =0
void SetLookAheadWeight(Weight weight)
ArcLookAheadMatcher(const ArcLookAheadMatcher &lmatcher, bool safe=false)
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false) override
uint64_t Properties(uint64_t inprops) const override
LookAheadMatcher(const LookAheadMatcher &matcher, bool safe=false)
LabelLookAheadMatcher * Copy(bool safe=false) const override
uint32_t Flags() const override
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false) override
TrivialLookAheadMatcher(const FST *fst, MatchType match_type)
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false)
uint32_t Flags() const override
const MatcherData * GetData() const
ArcLookAheadMatcher(const FST &fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr)
constexpr uint32_t kLookAheadWeight
virtual MatcherBase * Copy(bool safe=false) const =0
typename Arc::StateId StateId
typename Arc::StateId StateId
LookAheadMatcher(const FST &fst, MatchType match_type)
uint64_t Properties(uint64_t props) const override
virtual void InitLookAheadFst(const Fst< Arc > &, bool copy=false)=0
Weight Final(StateId s) const
typename Arc::Weight Weight
Definition: matcher.h:145
typename Arc::Label Label
Definition: matcher.h:143
virtual MatchType Type(bool) const =0
Weight Final(StateId s) const final
Weight LookAheadWeight() const
typename Arc::Weight Weight
Weight Final(StateId s) const final
ssize_t Priority(StateId s) final
std::shared_ptr< MatcherData > GetSharedData() const
uint32_t Flags() const override
bool LookAheadFst(const Fst< Arc > &, StateId) final
const Arc & Value() const final
uint64_t Properties(uint64_t props) const override
uint32_t Flags() const
LookAheadMatcher(MatcherBase< Arc > *base)
typename FST::Arc Arc
void InitLookAheadFst(const LFST &fst, bool copy=false)
LabelLookAheadMatcher(const FST &fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr, std::unique_ptr< Accumulator > accumulator=nullptr)
bool Find(Label label) final
bool LookAheadLabel(Label label) const final
bool LookAheadFst(const Fst< Arc > &, StateId) final
typename Arc::StateId StateId
Definition: matcher.h:144