21 #ifndef FST_LOOKAHEAD_MATCHER_H_ 22 #define FST_LOOKAHEAD_MATCHER_H_ 24 #include <sys/types.h> 44 #include <string_view> 191 using Arc =
typename FST::Arc;
198 : matcher_(fst, match_type) {}
202 : matcher_(fst, match_type) {}
207 : matcher_(lmatcher.matcher_, safe) {}
217 bool Find(
Label label)
final {
return matcher_.Find(label); }
219 bool Done() const final {
return matcher_.Done(); }
221 const Arc &
Value() const final {
return matcher_.Value(); }
223 void Next() final { matcher_.Next(); }
229 const FST &
GetFst()
const override {
return matcher_.GetFst(); }
232 return matcher_.Properties(props);
257 template <
class M, uint32_t flags = kLookAheadNonEpsilons | kLookAheadEpsilons |
258 kLookAheadWeight | kLookAheadPrefix>
262 using Arc =
typename FST::Arc;
275 static constexpr uint32_t kFlags = flags;
279 std::shared_ptr<MatcherData> data =
nullptr)
280 : matcher_(fst, match_type),
287 std::shared_ptr<MatcherData> data =
nullptr)
288 : matcher_(fst, match_type),
295 : matcher_(lmatcher.matcher_, safe),
297 lfst_(lmatcher.lfst_),
309 matcher_.SetState(s);
312 bool Find(
Label label)
final {
return matcher_.Find(label); }
314 bool Done() const final {
return matcher_.Done(); }
316 const Arc &
Value() const final {
return matcher_.Value(); }
318 void Next() final { matcher_.Next(); }
327 return matcher_.Properties(props);
331 return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher |
358 template <
class M, u
int32_t flags>
366 if (fst_.Final(state_) != Weight::Zero() &&
367 lfst_->Final(s) != Weight::Zero()) {
368 if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix)))
return true;
370 if (kFlags & kLookAheadWeight) {
377 if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix)))
return true;
379 if (kFlags & kLookAheadWeight) {
380 for (; !matcher_.Done(); matcher_.Next()) {
387 const auto &arc = aiter.Value();
389 switch (matcher_.Type(
false)) {
397 FSTERROR() <<
"ArcLookAheadMatcher::LookAheadFst: Bad match type";
401 if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix)))
return true;
402 if (!(kFlags & kLookAheadNonEpsilonPrefix)) ++nprefix;
403 if (kFlags & kLookAheadWeight) {
407 }
else if (matcher_.Find(label)) {
408 if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix)))
return true;
409 for (; !matcher_.Done(); matcher_.Next()) {
411 if (kFlags & kLookAheadWeight) {
413 Times(arc.weight, matcher_.Value().weight)));
415 if ((kFlags & kLookAheadPrefix) && nprefix == 1)
421 if (kFlags & kLookAheadPrefix) {
434 uint32_t flags = kLookAheadEpsilons | kLookAheadWeight |
435 kLookAheadPrefix | kLookAheadNonEpsilonPrefix |
447 using Arc =
typename FST::Arc;
460 static_assert(!(flags & kInputLookAheadMatcher) !=
461 !(flags & kOutputLookAheadMatcher),
462 "Must include precisely one of kInputLookAheadMatcher and " 463 "kOutputLookAheadMatcher");
464 static constexpr uint32_t kFlags = flags;
468 std::shared_ptr<MatcherData> data =
nullptr,
469 std::unique_ptr<Accumulator> accumulator =
nullptr)
470 : matcher_(fst, match_type),
474 Init(fst, match_type, data, std::move(accumulator));
479 std::shared_ptr<MatcherData> data =
nullptr,
480 std::unique_ptr<Accumulator> accumulator =
nullptr)
481 : matcher_(fst, match_type),
485 Init(*fst, match_type, data, std::move(accumulator));
491 : matcher_(lmatcher.matcher_, safe),
492 lfst_(lmatcher.lfst_),
493 label_reachable_(lmatcher.label_reachable_
494 ? new
Reachable(*lmatcher.label_reachable_, safe)
497 error_(lmatcher.error_) {}
506 if (state_ == s)
return;
508 match_set_state_ =
false;
509 reach_set_state_ =
false;
513 if (!match_set_state_) {
514 matcher_.SetState(state_);
515 match_set_state_ =
true;
517 return matcher_.Find(label);
520 bool Done() const final {
return matcher_.Done(); }
522 const Arc &
Value() const final {
return matcher_.Value(); }
524 void Next() final { matcher_.Next(); }
530 const FST &
GetFst()
const override {
return matcher_.GetFst(); }
533 auto outprops = matcher_.Properties(inprops);
534 if (error_ || (label_reachable_ && label_reachable_->Error())) {
541 if (label_reachable_ && label_reachable_->GetData()->ReachInput()) {
543 }
else if (label_reachable_ && !label_reachable_->GetData()->ReachInput()) {
546 return matcher_.Flags();
551 return label_reachable_ ? label_reachable_->GetData() :
nullptr;
555 return label_reachable_ ? label_reachable_->GetSharedData() :
nullptr;
559 template <
class LFST>
564 return LookAheadFst<Fst<Arc>>(fst, s);
569 if (label_reachable_) {
571 label_reachable_->ReachInit(fst, reach_input, copy);
575 template <
class LFST>
578 if (label_reachable_) {
580 label_reachable_->ReachInit(fst, reach_input, copy);
585 if (label == 0)
return true;
586 if (label_reachable_) {
587 if (!reach_set_state_) {
588 label_reachable_->SetState(state_);
589 reach_set_state_ =
true;
591 return label_reachable_->Reach(label);
599 std::shared_ptr<MatcherData> data,
600 std::unique_ptr<Accumulator> accumulator) {
601 const bool reach_input = match_type ==
MATCH_INPUT;
603 if (reach_input == data->ReachInput()) {
605 std::make_unique<Reachable>(data, std::move(accumulator));
607 }
else if ((reach_input && (kFlags & kInputLookAheadMatcher)) ||
610 std::make_unique<Reachable>(fst, reach_input, std::move(accumulator),
617 std::unique_ptr<Reachable> label_reachable_;
619 bool match_set_state_;
620 mutable bool reach_set_state_;
624 template <
class M, u
int32_t flags,
class Accumulator,
class Reachable>
625 template <
class LFST>
632 if (!label_reachable_)
return true;
633 label_reachable_->SetState(state_, s);
634 reach_set_state_ =
true;
639 const bool reach_arc = label_reachable_->Reach(
642 const bool reach_final =
643 lfinal != Weight::Zero() && label_reachable_->ReachFinal();
645 const auto begin = label_reachable_->ReachBegin();
646 const auto end = label_reachable_->ReachEnd();
647 if (kComputePrefix && end - begin == 1 && !reach_final) {
650 compute_weight =
false;
651 }
else if (compute_weight) {
655 if (reach_final && compute_weight) {
658 return reach_arc || reach_final;
665 template <
class Reachable,
class FST,
class Data>
667 std::string_view save_relabel_ipairs,
668 std::string_view save_relabel_opairs) {
669 using Label =
typename FST::Arc::Label;
670 if (data.First() !=
nullptr) {
671 Reachable reachable(data.SharedFirst());
672 reachable.Relabel(fst,
true);
673 if (!save_relabel_ipairs.empty()) {
674 std::vector<std::pair<Label, Label>> pairs;
675 reachable.RelabelPairs(&pairs,
true);
679 Reachable reachable(data.SharedSecond());
680 reachable.Relabel(fst,
false);
681 if (!save_relabel_opairs.empty()) {
682 std::vector<std::pair<Label, Label>> pairs;
683 reachable.RelabelPairs(&pairs,
true);
690 template <
class Arc,
class Data = LabelReachableData<
typename Arc::Label>>
697 template <
typename Impl>
701 template <
class LFST>
703 bool relabel_input) {
704 const auto *data = mfst.GetAddOn();
705 Reachable reachable(data->First() ? data->SharedFirst()
706 : data->SharedSecond());
707 reachable.
Relabel(fst, relabel_input);
713 template <
class LFST>
715 std::vector<std::pair<Label, Label>> *pairs,
716 bool avoid_collisions =
false) {
717 const auto *data = mfst.GetAddOn();
718 Reachable reachable(data->First() ? data->SharedFirst()
719 : data->SharedSecond());
724 template <
class Arc,
class Data>
725 template <
typename Impl>
727 std::shared_ptr<Impl> *impl) {
729 auto data = (*impl)->GetSharedAddOn();
730 const auto name = (*impl)->Type();
732 std::unique_ptr<MutableFst<Arc>> mfst;
739 mfst = std::make_unique<VectorFst<Arc>>(fst);
742 RelabelForReachable<Reachable>(mfst.get(), *data,
743 FST_FLAGS_save_relabel_ipairs,
744 FST_FLAGS_save_relabel_opairs);
750 *impl = std::make_shared<Impl>(*mfst, name);
751 (*impl)->SetAddOn(data);
761 using Arc =
typename FST::Arc;
769 : owned_fst_(fst.
Copy()),
770 base_(owned_fst_->InitMatcher(match_type)),
774 std::make_unique<SortedMatcher<FST>>(owned_fst_.get(), match_type);
779 : base_(fst->InitMatcher(match_type)), lookahead_(false) {
780 if (!base_) base_ = std::make_unique<SortedMatcher<FST>>(fst, match_type);
785 : base_(matcher.base_->
Copy(safe)), lookahead_(matcher.lookahead_) {}
789 : base_(base), lookahead_(false) {}
801 bool Done()
const {
return base_->Done(); }
803 const Arc &
Value()
const {
return base_->Value(); }
813 uint64_t
Properties(uint64_t props)
const {
return base_->Properties(props); }
815 uint32_t
Flags()
const {
return base_->Flags(); }
818 if (LookAheadCheck()) {
826 if (LookAheadCheck()) {
834 if (LookAheadCheck()) {
837 return Weight::One();
842 if (LookAheadCheck()) {
850 if (LookAheadCheck()) {
856 bool LookAheadCheck()
const {
861 FSTERROR() <<
"LookAheadMatcher: No look-ahead matcher defined";
867 std::unique_ptr<const FST> owned_fst_;
868 std::unique_ptr<MatcherBase<Arc>> base_;
869 mutable bool lookahead_;
876 #endif // FST_LOOKAHEAD_MATCHER_H_ typename Arc::Label Label
ssize_t Priority(StateId s) final
bool LookAheadFst(const Fst< Arc > &fst, StateId s)
ArcLookAheadMatcher(const FST *fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr)
virtual bool LookAheadLabel(Label) const =0
constexpr uint64_t kMutable
Weight LookAheadWeight() const
bool LookAheadLabel(Label label) const final
bool LookAheadPrefix(Arc *arc) const
uint64_t Properties(uint64_t props) const
constexpr uint8_t kArcNoCache
void RelabelForReachable(FST *fst, const Data &data, std::string_view save_relabel_ipairs, std::string_view save_relabel_opairs)
virtual uint64_t Properties(uint64_t mask, bool test) const =0
typename Arc::StateId StateId
bool LookAheadLabel(Label) const final
typename Arc::Label Label
TrivialLookAheadMatcher * Copy(bool safe=false) const override
const Arc & Value() const final
TrivialLookAheadMatcher(const TrivialLookAheadMatcher &lmatcher, bool safe=false)
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
typename Arc::StateId StateId
bool LookAheadPrefix(Arc *) const
typename Arc::Weight Weight
static void Relabel(MutableFst< Arc > *fst, const LFST &mfst, bool relabel_input)
bool LookAheadLabel(Label label) const
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
std::shared_ptr< MatcherData > GetSharedData() const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
const FST & GetFst() const override
bool WriteLabelPairs(std::string_view source, const std::vector< std::pair< Label, Label >> &pairs)
constexpr uint64_t kError
const Arc & Value() const
ArcLookAheadMatcher * Copy(bool safe=false) const override
void SetLookAheadPrefix(Arc arc)
constexpr uint32_t kLookAheadFlags
MatchType Type(bool test) const override
typename Arc::Weight Weight
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
TrivialLookAheadMatcher(const FST &fst, MatchType match_type)
LabelLookAheadMatcher(const FST *fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr, std::unique_ptr< Accumulator > accumulator=nullptr)
typename Arc::Label Label
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false) override
MatchType Type(bool test) const
void SetState(StateId s) final
virtual bool LookAheadFst(const Fst< Arc > &, StateId)=0
void SetState(StateId s) final
bool Find(Label label) final
Label Relabel(Label label)
const Arc & Value() const final
const Arc & Value() const
typename Arc::Label Label
void RelabelPairs(std::vector< std::pair< Label, Label >> *pairs, bool avoid_collisions=false)
ssize_t Priority(StateId s) final
constexpr uint32_t kOutputLookAheadMatcher
bool Find(Label label) final
LabelLookAheadRelabeler(std::shared_ptr< Impl > *impl)
void SetFlags(uint8_t flags, uint8_t mask)
const FST & GetFst() const override
MatchType Type(bool test) const override
LabelLookAheadMatcher(const LabelLookAheadMatcher &lmatcher, bool safe=false)
constexpr uint32_t kLookAheadEpsilons
constexpr uint32_t kLookAheadNonEpsilonPrefix
constexpr uint32_t kLookAheadKeepRelabelData
void ClearLookAheadPrefix()
LookAheadMatcher(const FST *fst, MatchType match_type)
static void RelabelPairs(const LFST &mfst, std::vector< std::pair< Label, Label >> *pairs, bool avoid_collisions=false)
const FST & GetFst() const
constexpr uint32_t kLookAheadNonEpsilons
const MatcherData * GetData() const
Weight Final(StateId s) const final
MatchType Type(bool test) const override
DECLARE_string(save_relabel_ipairs)
typename Arc::Label Label
const FST & GetFst() const override
constexpr uint32_t kInputLookAheadMatcher
ssize_t Priority(StateId s)
typename Reachable::Data MatcherData
LookAheadMatcher * Copy(bool safe=false) const
constexpr uint32_t kLookAheadPrefix
void SetState(StateId s) final
bool LookAheadFst(const Fst< Arc > &fst, StateId s) final
bool LookAheadPrefix(Arc *arc) const
virtual const Fst< Arc > & GetFst() const =0
void SetLookAheadWeight(Weight weight)
ArcLookAheadMatcher(const ArcLookAheadMatcher &lmatcher, bool safe=false)
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false) override
uint64_t Properties(uint64_t inprops) const override
typename Arc::Weight Weight
LookAheadMatcher(const LookAheadMatcher &matcher, bool safe=false)
LabelLookAheadMatcher * Copy(bool safe=false) const override
uint32_t Flags() const override
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false) override
TrivialLookAheadMatcher(const FST *fst, MatchType match_type)
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false)
uint32_t Flags() const override
const MatcherData * GetData() const
ArcLookAheadMatcher(const FST &fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr)
constexpr uint32_t kLookAheadWeight
virtual MatcherBase * Copy(bool safe=false) const =0
typename Arc::StateId StateId
typename Arc::StateId StateId
LookAheadMatcher(const FST &fst, MatchType match_type)
uint64_t Properties(uint64_t props) const override
virtual void InitLookAheadFst(const Fst< Arc > &, bool copy=false)=0
void ClearLookAheadWeight()
Weight Final(StateId s) const
typename Arc::Weight Weight
typename Arc::Label Label
virtual MatchType Type(bool) const =0
Weight LookAheadWeight() const
Weight Final(StateId s) const final
Weight LookAheadWeight() const
typename Arc::Weight Weight
Weight Final(StateId s) const final
ssize_t Priority(StateId s) final
std::shared_ptr< MatcherData > GetSharedData() const
uint32_t Flags() const override
bool LookAheadFst(const Fst< Arc > &, StateId) final
const Arc & Value() const final
uint64_t Properties(uint64_t props) const override
LookAheadMatcher(MatcherBase< Arc > *base)
void InitLookAheadFst(const LFST &fst, bool copy=false)
LabelLookAheadMatcher(const FST &fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr, std::unique_ptr< Accumulator > accumulator=nullptr)
typename Arc::Label Label
bool Find(Label label) final
bool LookAheadLabel(Label label) const final
bool LookAheadFst(const Fst< Arc > &, StateId) final
typename Arc::StateId StateId