21 #ifndef FST_LOOKAHEAD_MATCHER_H_ 22 #define FST_LOOKAHEAD_MATCHER_H_ 183 using Arc =
typename FST::Arc;
190 : matcher_(fst, match_type) {}
194 : matcher_(fst, match_type) {}
199 : matcher_(lmatcher.matcher_, safe) {}
209 bool Find(
Label label)
final {
return matcher_.Find(label); }
211 bool Done() const final {
return matcher_.Done(); }
213 const Arc &
Value() const final {
return matcher_.Value(); }
215 void Next() final { matcher_.Next(); }
221 const FST &
GetFst()
const override {
return matcher_.GetFst(); }
224 return matcher_.Properties(props);
249 template <
class M, uint32_t flags = kLookAheadNonEpsilons | kLookAheadEpsilons |
250 kLookAheadWeight | kLookAheadPrefix>
254 using Arc =
typename FST::Arc;
267 static constexpr uint32_t kFlags = flags;
271 std::shared_ptr<MatcherData> data =
nullptr)
272 : matcher_(fst, match_type),
279 std::shared_ptr<MatcherData> data =
nullptr)
280 : matcher_(fst, match_type),
287 : matcher_(lmatcher.matcher_, safe),
289 lfst_(lmatcher.lfst_),
301 matcher_.SetState(s);
304 bool Find(
Label label)
final {
return matcher_.Find(label); }
306 bool Done() const final {
return matcher_.Done(); }
308 const Arc &
Value() const final {
return matcher_.Value(); }
310 void Next() final { matcher_.Next(); }
319 return matcher_.Properties(props);
323 return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher |
350 template <
class M, u
int32_t flags>
358 if (fst_.Final(state_) != Weight::Zero() &&
359 lfst_->Final(s) != Weight::Zero()) {
360 if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix)))
return true;
362 if (kFlags & kLookAheadWeight) {
369 if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix)))
return true;
371 if (kFlags & kLookAheadWeight) {
372 for (; !matcher_.Done(); matcher_.Next()) {
379 const auto &arc = aiter.Value();
381 switch (matcher_.Type(
false)) {
389 FSTERROR() <<
"ArcLookAheadMatcher::LookAheadFst: Bad match type";
393 if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix)))
return true;
394 if (!(kFlags & kLookAheadNonEpsilonPrefix)) ++nprefix;
395 if (kFlags & kLookAheadWeight) {
399 }
else if (matcher_.Find(label)) {
400 if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix)))
return true;
401 for (; !matcher_.Done(); matcher_.Next()) {
403 if (kFlags & kLookAheadWeight) {
405 Times(arc.weight, matcher_.Value().weight)));
407 if ((kFlags & kLookAheadPrefix) && nprefix == 1)
413 if (kFlags & kLookAheadPrefix) {
426 uint32_t flags = kLookAheadEpsilons | kLookAheadWeight |
427 kLookAheadPrefix | kLookAheadNonEpsilonPrefix |
439 using Arc =
typename FST::Arc;
452 static_assert(!(flags & kInputLookAheadMatcher) !=
453 !(flags & kOutputLookAheadMatcher),
454 "Must include precisely one of kInputLookAheadMatcher and " 455 "kOutputLookAheadMatcher");
456 static constexpr uint32_t kFlags = flags;
460 std::shared_ptr<MatcherData> data =
nullptr,
461 std::unique_ptr<Accumulator> accumulator =
nullptr)
462 : matcher_(fst, match_type),
466 Init(fst, match_type, data, std::move(accumulator));
471 std::shared_ptr<MatcherData> data =
nullptr,
472 std::unique_ptr<Accumulator> accumulator =
nullptr)
473 : matcher_(fst, match_type),
477 Init(*fst, match_type, data, std::move(accumulator));
483 : matcher_(lmatcher.matcher_, safe),
484 lfst_(lmatcher.lfst_),
485 label_reachable_(lmatcher.label_reachable_
486 ? new
Reachable(*lmatcher.label_reachable_, safe)
489 error_(lmatcher.error_) {}
498 if (state_ == s)
return;
500 match_set_state_ =
false;
501 reach_set_state_ =
false;
505 if (!match_set_state_) {
506 matcher_.SetState(state_);
507 match_set_state_ =
true;
509 return matcher_.Find(label);
512 bool Done() const final {
return matcher_.Done(); }
514 const Arc &
Value() const final {
return matcher_.Value(); }
516 void Next() final { matcher_.Next(); }
522 const FST &
GetFst()
const override {
return matcher_.GetFst(); }
525 auto outprops = matcher_.Properties(inprops);
526 if (error_ || (label_reachable_ && label_reachable_->Error())) {
533 if (label_reachable_ && label_reachable_->GetData()->ReachInput()) {
535 }
else if (label_reachable_ && !label_reachable_->GetData()->ReachInput()) {
538 return matcher_.Flags();
543 return label_reachable_ ? label_reachable_->GetData() :
nullptr;
547 return label_reachable_ ? label_reachable_->GetSharedData() :
nullptr;
551 template <
class LFST>
556 return LookAheadFst<Fst<Arc>>(fst, s);
561 if (label_reachable_) {
563 label_reachable_->ReachInit(fst, reach_input, copy);
567 template <
class LFST>
570 if (label_reachable_) {
572 label_reachable_->ReachInit(fst, reach_input, copy);
577 if (label == 0)
return true;
578 if (label_reachable_) {
579 if (!reach_set_state_) {
580 label_reachable_->SetState(state_);
581 reach_set_state_ =
true;
583 return label_reachable_->Reach(label);
591 std::shared_ptr<MatcherData> data,
592 std::unique_ptr<Accumulator> accumulator) {
593 const bool reach_input = match_type ==
MATCH_INPUT;
595 if (reach_input == data->ReachInput()) {
597 std::make_unique<Reachable>(data, std::move(accumulator));
599 }
else if ((reach_input && (kFlags & kInputLookAheadMatcher)) ||
602 std::make_unique<Reachable>(fst, reach_input, std::move(accumulator),
609 std::unique_ptr<Reachable> label_reachable_;
611 bool match_set_state_;
612 mutable bool reach_set_state_;
616 template <
class M, u
int32_t flags,
class Accumulator,
class Reachable>
617 template <
class LFST>
624 if (!label_reachable_)
return true;
625 label_reachable_->SetState(state_, s);
626 reach_set_state_ =
true;
631 const bool reach_arc = label_reachable_->Reach(
634 const bool reach_final =
635 lfinal != Weight::Zero() && label_reachable_->ReachFinal();
637 const auto begin = label_reachable_->ReachBegin();
638 const auto end = label_reachable_->ReachEnd();
639 if (compute_prefix && end - begin == 1 && !reach_final) {
642 compute_weight =
false;
643 }
else if (compute_weight) {
647 if (reach_final && compute_weight) {
650 return reach_arc || reach_final;
657 template <
class Reachable,
class FST,
class Data>
659 const std::string &save_relabel_ipairs,
660 const std::string &save_relabel_opairs) {
661 using Label =
typename FST::Arc::Label;
662 if (data.First() !=
nullptr) {
663 Reachable reachable(data.SharedFirst());
664 reachable.Relabel(fst,
true);
665 if (!save_relabel_ipairs.empty()) {
666 std::vector<std::pair<Label, Label>> pairs;
667 reachable.RelabelPairs(&pairs,
true);
671 Reachable reachable(data.SharedSecond());
672 reachable.Relabel(fst,
false);
673 if (!save_relabel_opairs.empty()) {
674 std::vector<std::pair<Label, Label>> pairs;
675 reachable.RelabelPairs(&pairs,
true);
682 template <
class Arc,
class Data = LabelReachableData<
typename Arc::Label>>
689 template <
typename Impl>
693 template <
class LFST>
695 bool relabel_input) {
696 const auto *data = mfst.GetAddOn();
697 Reachable reachable(data->First() ? data->SharedFirst()
698 : data->SharedSecond());
699 reachable.
Relabel(fst, relabel_input);
705 template <
class LFST>
707 std::vector<std::pair<Label, Label>> *pairs,
708 bool avoid_collisions =
false) {
709 const auto *data = mfst.GetAddOn();
710 Reachable reachable(data->First() ? data->SharedFirst()
711 : data->SharedSecond());
716 template <
class Arc,
class Data>
717 template <
typename Impl>
719 std::shared_ptr<Impl> *impl) {
721 auto data = (*impl)->GetSharedAddOn();
722 const auto name = (*impl)->Type();
724 std::unique_ptr<MutableFst<Arc>> mfst;
731 mfst = std::make_unique<VectorFst<Arc>>(fst);
734 RelabelForReachable<Reachable>(mfst.get(), *data,
735 FST_FLAGS_save_relabel_ipairs,
736 FST_FLAGS_save_relabel_opairs);
742 *impl = std::make_shared<Impl>(*mfst, name);
743 (*impl)->SetAddOn(data);
753 using Arc =
typename FST::Arc;
761 : owned_fst_(fst.
Copy()),
762 base_(owned_fst_->InitMatcher(match_type)),
766 std::make_unique<SortedMatcher<FST>>(owned_fst_.get(), match_type);
771 : base_(fst->InitMatcher(match_type)), lookahead_(false) {
772 if (!base_) base_ = std::make_unique<SortedMatcher<FST>>(fst, match_type);
777 : base_(matcher.base_->
Copy(safe)), lookahead_(matcher.lookahead_) {}
781 : base_(base), lookahead_(false) {}
793 bool Done()
const {
return base_->Done(); }
795 const Arc &
Value()
const {
return base_->Value(); }
805 uint64_t
Properties(uint64_t props)
const {
return base_->Properties(props); }
807 uint32_t
Flags()
const {
return base_->Flags(); }
810 if (LookAheadCheck()) {
818 if (LookAheadCheck()) {
826 if (LookAheadCheck()) {
829 return Weight::One();
834 if (LookAheadCheck()) {
842 if (LookAheadCheck()) {
848 bool LookAheadCheck()
const {
853 FSTERROR() <<
"LookAheadMatcher: No look-ahead matcher defined";
859 std::unique_ptr<const FST> owned_fst_;
860 std::unique_ptr<MatcherBase<Arc>> base_;
861 mutable bool lookahead_;
868 #endif // FST_LOOKAHEAD_MATCHER_H_ typename Arc::Label Label
ssize_t Priority(StateId s) final
bool LookAheadFst(const Fst< Arc > &fst, StateId s)
ArcLookAheadMatcher(const FST *fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr)
virtual bool LookAheadLabel(Label) const =0
constexpr uint64_t kMutable
Weight LookAheadWeight() const
bool LookAheadLabel(Label label) const final
bool LookAheadPrefix(Arc *arc) const
uint64_t Properties(uint64_t props) const
constexpr uint8_t kArcNoCache
virtual uint64_t Properties(uint64_t mask, bool test) const =0
typename Arc::StateId StateId
bool LookAheadLabel(Label) const final
typename Arc::Label Label
TrivialLookAheadMatcher * Copy(bool safe=false) const override
const Arc & Value() const final
TrivialLookAheadMatcher(const TrivialLookAheadMatcher &lmatcher, bool safe=false)
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
typename Arc::StateId StateId
bool LookAheadPrefix(Arc *) const
typename Arc::Weight Weight
static void Relabel(MutableFst< Arc > *fst, const LFST &mfst, bool relabel_input)
bool LookAheadLabel(Label label) const
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
std::shared_ptr< MatcherData > GetSharedData() const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
const FST & GetFst() const override
constexpr uint64_t kError
const Arc & Value() const
ArcLookAheadMatcher * Copy(bool safe=false) const override
void SetLookAheadPrefix(Arc arc)
constexpr uint32_t kLookAheadFlags
MatchType Type(bool test) const override
typename Arc::Weight Weight
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
TrivialLookAheadMatcher(const FST &fst, MatchType match_type)
LabelLookAheadMatcher(const FST *fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr, std::unique_ptr< Accumulator > accumulator=nullptr)
typename Arc::Label Label
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false) override
MatchType Type(bool test) const
void SetState(StateId s) final
virtual bool LookAheadFst(const Fst< Arc > &, StateId)=0
void SetState(StateId s) final
bool Find(Label label) final
Label Relabel(Label label)
const Arc & Value() const final
const Arc & Value() const
typename Arc::Label Label
void RelabelPairs(std::vector< std::pair< Label, Label >> *pairs, bool avoid_collisions=false)
ssize_t Priority(StateId s) final
constexpr uint32_t kOutputLookAheadMatcher
bool Find(Label label) final
LabelLookAheadRelabeler(std::shared_ptr< Impl > *impl)
void SetFlags(uint8_t flags, uint8_t mask)
const FST & GetFst() const override
MatchType Type(bool test) const override
LabelLookAheadMatcher(const LabelLookAheadMatcher &lmatcher, bool safe=false)
constexpr uint32_t kLookAheadEpsilons
constexpr uint32_t kLookAheadNonEpsilonPrefix
constexpr uint32_t kLookAheadKeepRelabelData
void ClearLookAheadPrefix()
LookAheadMatcher(const FST *fst, MatchType match_type)
static void RelabelPairs(const LFST &mfst, std::vector< std::pair< Label, Label >> *pairs, bool avoid_collisions=false)
const FST & GetFst() const
constexpr uint32_t kLookAheadNonEpsilons
const MatcherData * GetData() const
Weight Final(StateId s) const final
MatchType Type(bool test) const override
DECLARE_string(save_relabel_ipairs)
typename Arc::Label Label
const FST & GetFst() const override
constexpr uint32_t kInputLookAheadMatcher
ssize_t Priority(StateId s)
typename Reachable::Data MatcherData
LookAheadMatcher * Copy(bool safe=false) const
constexpr uint32_t kLookAheadPrefix
void SetState(StateId s) final
bool LookAheadFst(const Fst< Arc > &fst, StateId s) final
bool LookAheadPrefix(Arc *arc) const
virtual const Fst< Arc > & GetFst() const =0
void SetLookAheadWeight(Weight weight)
ArcLookAheadMatcher(const ArcLookAheadMatcher &lmatcher, bool safe=false)
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false) override
uint64_t Properties(uint64_t inprops) const override
typename Arc::Weight Weight
LookAheadMatcher(const LookAheadMatcher &matcher, bool safe=false)
LabelLookAheadMatcher * Copy(bool safe=false) const override
uint32_t Flags() const override
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false) override
TrivialLookAheadMatcher(const FST *fst, MatchType match_type)
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false)
uint32_t Flags() const override
const MatcherData * GetData() const
ArcLookAheadMatcher(const FST &fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr)
constexpr uint32_t kLookAheadWeight
virtual MatcherBase * Copy(bool safe=false) const =0
typename Arc::StateId StateId
typename Arc::StateId StateId
LookAheadMatcher(const FST &fst, MatchType match_type)
uint64_t Properties(uint64_t props) const override
virtual void InitLookAheadFst(const Fst< Arc > &, bool copy=false)=0
void ClearLookAheadWeight()
Weight Final(StateId s) const
typename Arc::Weight Weight
typename Arc::Label Label
virtual MatchType Type(bool) const =0
Weight LookAheadWeight() const
void RelabelForReachable(FST *fst, const Data &data, const std::string &save_relabel_ipairs, const std::string &save_relabel_opairs)
Weight Final(StateId s) const final
Weight LookAheadWeight() const
typename Arc::Weight Weight
Weight Final(StateId s) const final
bool WriteLabelPairs(const std::string &source, const std::vector< std::pair< Label, Label >> &pairs)
ssize_t Priority(StateId s) final
std::shared_ptr< MatcherData > GetSharedData() const
uint32_t Flags() const override
bool LookAheadFst(const Fst< Arc > &, StateId) final
const Arc & Value() const final
uint64_t Properties(uint64_t props) const override
LookAheadMatcher(MatcherBase< Arc > *base)
void InitLookAheadFst(const LFST &fst, bool copy=false)
LabelLookAheadMatcher(const FST &fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr, std::unique_ptr< Accumulator > accumulator=nullptr)
typename Arc::Label Label
bool Find(Label label) final
bool LookAheadLabel(Label label) const final
bool LookAheadFst(const Fst< Arc > &, StateId) final
typename Arc::StateId StateId