21 #ifndef FST_LOOKAHEAD_MATCHER_H_ 22 #define FST_LOOKAHEAD_MATCHER_H_ 24 #include <sys/types.h> 42 #include <string_view> 189 using Arc =
typename FST::Arc;
196 : matcher_(fst, match_type) {}
200 : matcher_(fst, match_type) {}
205 : matcher_(lmatcher.matcher_, safe) {}
215 bool Find(
Label label)
final {
return matcher_.Find(label); }
217 bool Done() const final {
return matcher_.Done(); }
219 const Arc &
Value() const final {
return matcher_.Value(); }
221 void Next() final { matcher_.Next(); }
227 const FST &
GetFst()
const override {
return matcher_.GetFst(); }
230 return matcher_.Properties(props);
255 template <
class M, uint32_t flags = kLookAheadNonEpsilons | kLookAheadEpsilons |
256 kLookAheadWeight | kLookAheadPrefix>
260 using Arc =
typename FST::Arc;
273 static constexpr uint32_t kFlags = flags;
277 std::shared_ptr<MatcherData> data =
nullptr)
278 : matcher_(fst, match_type),
285 std::shared_ptr<MatcherData> data =
nullptr)
286 : matcher_(fst, match_type),
293 : matcher_(lmatcher.matcher_, safe),
295 lfst_(lmatcher.lfst_),
307 matcher_.SetState(s);
310 bool Find(
Label label)
final {
return matcher_.Find(label); }
312 bool Done() const final {
return matcher_.Done(); }
314 const Arc &
Value() const final {
return matcher_.Value(); }
316 void Next() final { matcher_.Next(); }
325 return matcher_.Properties(props);
329 return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher |
356 template <
class M, u
int32_t flags>
364 if (fst_.Final(state_) != Weight::Zero() &&
365 lfst_->Final(s) != Weight::Zero()) {
366 if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix)))
return true;
368 if (kFlags & kLookAheadWeight) {
375 if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix)))
return true;
377 if (kFlags & kLookAheadWeight) {
378 for (; !matcher_.Done(); matcher_.Next()) {
385 const auto &arc = aiter.Value();
387 switch (matcher_.Type(
false)) {
395 FSTERROR() <<
"ArcLookAheadMatcher::LookAheadFst: Bad match type";
399 if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix)))
return true;
400 if (!(kFlags & kLookAheadNonEpsilonPrefix)) ++nprefix;
401 if (kFlags & kLookAheadWeight) {
405 }
else if (matcher_.Find(label)) {
406 if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix)))
return true;
407 for (; !matcher_.Done(); matcher_.Next()) {
409 if (kFlags & kLookAheadWeight) {
411 Times(arc.weight, matcher_.Value().weight)));
413 if ((kFlags & kLookAheadPrefix) && nprefix == 1)
419 if (kFlags & kLookAheadPrefix) {
432 uint32_t flags = kLookAheadEpsilons | kLookAheadWeight |
433 kLookAheadPrefix | kLookAheadNonEpsilonPrefix |
445 using Arc =
typename FST::Arc;
458 static_assert(!(flags & kInputLookAheadMatcher) !=
459 !(flags & kOutputLookAheadMatcher),
460 "Must include precisely one of kInputLookAheadMatcher and " 461 "kOutputLookAheadMatcher");
462 static constexpr uint32_t kFlags = flags;
466 std::shared_ptr<MatcherData> data =
nullptr,
467 std::unique_ptr<Accumulator> accumulator =
nullptr)
468 : matcher_(fst, match_type),
472 Init(fst, match_type, data, std::move(accumulator));
477 std::shared_ptr<MatcherData> data =
nullptr,
478 std::unique_ptr<Accumulator> accumulator =
nullptr)
479 : matcher_(fst, match_type),
483 Init(*fst, match_type, data, std::move(accumulator));
489 : matcher_(lmatcher.matcher_, safe),
490 lfst_(lmatcher.lfst_),
491 label_reachable_(lmatcher.label_reachable_
492 ? new
Reachable(*lmatcher.label_reachable_, safe)
495 error_(lmatcher.error_) {}
504 if (state_ == s)
return;
506 match_set_state_ =
false;
507 reach_set_state_ =
false;
511 if (!match_set_state_) {
512 matcher_.SetState(state_);
513 match_set_state_ =
true;
515 return matcher_.Find(label);
518 bool Done() const final {
return matcher_.Done(); }
520 const Arc &
Value() const final {
return matcher_.Value(); }
522 void Next() final { matcher_.Next(); }
528 const FST &
GetFst()
const override {
return matcher_.GetFst(); }
531 auto outprops = matcher_.Properties(inprops);
532 if (error_ || (label_reachable_ && label_reachable_->Error())) {
539 if (label_reachable_ && label_reachable_->GetData()->ReachInput()) {
541 }
else if (label_reachable_ && !label_reachable_->GetData()->ReachInput()) {
544 return matcher_.Flags();
549 return label_reachable_ ? label_reachable_->GetData() :
nullptr;
553 return label_reachable_ ? label_reachable_->GetSharedData() :
nullptr;
557 template <
class LFST>
562 return LookAheadFst<Fst<Arc>>(fst, s);
567 if (label_reachable_) {
569 label_reachable_->ReachInit(fst, reach_input, copy);
573 template <
class LFST>
576 if (label_reachable_) {
578 label_reachable_->ReachInit(fst, reach_input, copy);
583 if (label == 0)
return true;
584 if (label_reachable_) {
585 if (!reach_set_state_) {
586 label_reachable_->SetState(state_);
587 reach_set_state_ =
true;
589 return label_reachable_->Reach(label);
597 std::shared_ptr<MatcherData> data,
598 std::unique_ptr<Accumulator> accumulator) {
599 const bool reach_input = match_type ==
MATCH_INPUT;
601 if (reach_input == data->ReachInput()) {
603 std::make_unique<Reachable>(data, std::move(accumulator));
605 }
else if ((reach_input && (kFlags & kInputLookAheadMatcher)) ||
608 std::make_unique<Reachable>(fst, reach_input, std::move(accumulator),
615 std::unique_ptr<Reachable> label_reachable_;
617 bool match_set_state_;
618 mutable bool reach_set_state_;
622 template <
class M, u
int32_t flags,
class Accumulator,
class Reachable>
623 template <
class LFST>
630 if (!label_reachable_)
return true;
631 label_reachable_->SetState(state_, s);
632 reach_set_state_ =
true;
637 const bool reach_arc = label_reachable_->Reach(
640 const bool reach_final =
641 lfinal != Weight::Zero() && label_reachable_->ReachFinal();
643 const auto begin = label_reachable_->ReachBegin();
644 const auto end = label_reachable_->ReachEnd();
645 if (kComputePrefix && end - begin == 1 && !reach_final) {
648 compute_weight =
false;
649 }
else if (compute_weight) {
653 if (reach_final && compute_weight) {
656 return reach_arc || reach_final;
663 template <
class Reachable,
class FST,
class Data>
665 std::string_view save_relabel_ipairs,
666 std::string_view save_relabel_opairs) {
667 using Label =
typename FST::Arc::Label;
668 if (data.First() !=
nullptr) {
669 Reachable reachable(data.SharedFirst());
670 reachable.Relabel(fst,
true);
671 if (!save_relabel_ipairs.empty()) {
672 std::vector<std::pair<Label, Label>> pairs;
673 reachable.RelabelPairs(&pairs,
true);
674 std::sort(pairs.begin(), pairs.end());
678 Reachable reachable(data.SharedSecond());
679 reachable.Relabel(fst,
false);
680 if (!save_relabel_opairs.empty()) {
681 std::vector<std::pair<Label, Label>> pairs;
682 reachable.RelabelPairs(&pairs,
true);
683 std::sort(pairs.begin(), pairs.end());
690 template <
class Arc,
class Data = LabelReachableData<
typename Arc::Label>>
697 template <
typename Impl>
701 template <
class LFST>
703 bool relabel_input) {
704 const auto *data = mfst.GetAddOn();
705 Reachable reachable(data->First() ? data->SharedFirst()
706 : data->SharedSecond());
707 reachable.
Relabel(fst, relabel_input);
713 template <
class LFST>
715 std::vector<std::pair<Label, Label>> *pairs,
716 bool avoid_collisions =
false) {
717 const auto *data = mfst.GetAddOn();
718 Reachable reachable(data->First() ? data->SharedFirst()
719 : data->SharedSecond());
724 template <
class Arc,
class Data>
725 template <
typename Impl>
727 std::shared_ptr<Impl> *impl) {
729 auto data = (*impl)->GetSharedAddOn();
730 const auto name = (*impl)->Type();
732 std::unique_ptr<MutableFst<Arc>> mfst;
739 mfst = std::make_unique<VectorFst<Arc>>(fst);
742 RelabelForReachable<Reachable>(mfst.get(), *data,
743 FST_FLAGS_save_relabel_ipairs,
744 FST_FLAGS_save_relabel_opairs);
750 *impl = std::make_shared<Impl>(*mfst, name);
751 (*impl)->SetAddOn(data);
761 using Arc =
typename FST::Arc;
769 : owned_fst_(fst.
Copy()),
770 base_(owned_fst_->InitMatcher(match_type)),
774 std::make_unique<SortedMatcher<FST>>(owned_fst_.get(), match_type);
779 : base_(fst->InitMatcher(match_type)), lookahead_(false) {
780 if (!base_) base_ = std::make_unique<SortedMatcher<FST>>(fst, match_type);
785 : base_(matcher.base_->
Copy(safe)), lookahead_(matcher.lookahead_) {}
789 : base_(base), lookahead_(false) {}
801 bool Done()
const {
return base_->Done(); }
803 const Arc &
Value()
const {
return base_->Value(); }
813 uint64_t
Properties(uint64_t props)
const {
return base_->Properties(props); }
815 uint32_t
Flags()
const {
return base_->Flags(); }
818 if (LookAheadCheck()) {
826 if (LookAheadCheck()) {
834 if (LookAheadCheck()) {
837 return Weight::One();
842 if (LookAheadCheck()) {
850 if (LookAheadCheck()) {
856 bool LookAheadCheck()
const {
861 FSTERROR() <<
"LookAheadMatcher: No look-ahead matcher defined";
867 std::unique_ptr<const FST> owned_fst_;
868 std::unique_ptr<MatcherBase<Arc>> base_;
869 mutable bool lookahead_;
876 #endif // FST_LOOKAHEAD_MATCHER_H_ typename Arc::Label Label
ssize_t Priority(StateId s) final
bool LookAheadFst(const Fst< Arc > &fst, StateId s)
ArcLookAheadMatcher(const FST *fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr)
virtual bool LookAheadLabel(Label) const =0
constexpr uint64_t kMutable
Weight LookAheadWeight() const
bool LookAheadLabel(Label label) const final
bool LookAheadPrefix(Arc *arc) const
uint64_t Properties(uint64_t props) const
constexpr uint8_t kArcNoCache
void RelabelForReachable(FST *fst, const Data &data, std::string_view save_relabel_ipairs, std::string_view save_relabel_opairs)
virtual uint64_t Properties(uint64_t mask, bool test) const =0
typename Arc::StateId StateId
bool LookAheadLabel(Label) const final
typename Arc::Label Label
TrivialLookAheadMatcher * Copy(bool safe=false) const override
const Arc & Value() const final
TrivialLookAheadMatcher(const TrivialLookAheadMatcher &lmatcher, bool safe=false)
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
typename Arc::StateId StateId
bool LookAheadPrefix(Arc *) const
typename Arc::Weight Weight
static void Relabel(MutableFst< Arc > *fst, const LFST &mfst, bool relabel_input)
bool LookAheadLabel(Label label) const
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
std::shared_ptr< MatcherData > GetSharedData() const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
const FST & GetFst() const override
bool WriteLabelPairs(std::string_view source, const std::vector< std::pair< Label, Label >> &pairs)
constexpr uint64_t kError
const Arc & Value() const
ArcLookAheadMatcher * Copy(bool safe=false) const override
void SetLookAheadPrefix(Arc arc)
constexpr uint32_t kLookAheadFlags
MatchType Type(bool test) const override
typename Arc::Weight Weight
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
TrivialLookAheadMatcher(const FST &fst, MatchType match_type)
LabelLookAheadMatcher(const FST *fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr, std::unique_ptr< Accumulator > accumulator=nullptr)
typename Arc::Label Label
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false) override
MatchType Type(bool test) const
void SetState(StateId s) final
virtual bool LookAheadFst(const Fst< Arc > &, StateId)=0
void SetState(StateId s) final
bool Find(Label label) final
Label Relabel(Label label)
const Arc & Value() const final
const Arc & Value() const
typename Arc::Label Label
void RelabelPairs(std::vector< std::pair< Label, Label >> *pairs, bool avoid_collisions=false)
ssize_t Priority(StateId s) final
constexpr uint32_t kOutputLookAheadMatcher
bool Find(Label label) final
LabelLookAheadRelabeler(std::shared_ptr< Impl > *impl)
void SetFlags(uint8_t flags, uint8_t mask)
const FST & GetFst() const override
MatchType Type(bool test) const override
LabelLookAheadMatcher(const LabelLookAheadMatcher &lmatcher, bool safe=false)
constexpr uint32_t kLookAheadEpsilons
constexpr uint32_t kLookAheadNonEpsilonPrefix
constexpr uint32_t kLookAheadKeepRelabelData
void ClearLookAheadPrefix()
LookAheadMatcher(const FST *fst, MatchType match_type)
static void RelabelPairs(const LFST &mfst, std::vector< std::pair< Label, Label >> *pairs, bool avoid_collisions=false)
const FST & GetFst() const
constexpr uint32_t kLookAheadNonEpsilons
const MatcherData * GetData() const
Weight Final(StateId s) const final
MatchType Type(bool test) const override
DECLARE_string(save_relabel_ipairs)
typename Arc::Label Label
const FST & GetFst() const override
constexpr uint32_t kInputLookAheadMatcher
ssize_t Priority(StateId s)
typename Reachable::Data MatcherData
LookAheadMatcher * Copy(bool safe=false) const
constexpr uint32_t kLookAheadPrefix
void SetState(StateId s) final
bool LookAheadFst(const Fst< Arc > &fst, StateId s) final
bool LookAheadPrefix(Arc *arc) const
virtual const Fst< Arc > & GetFst() const =0
void SetLookAheadWeight(Weight weight)
ArcLookAheadMatcher(const ArcLookAheadMatcher &lmatcher, bool safe=false)
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false) override
uint64_t Properties(uint64_t inprops) const override
typename Arc::Weight Weight
LookAheadMatcher(const LookAheadMatcher &matcher, bool safe=false)
LabelLookAheadMatcher * Copy(bool safe=false) const override
uint32_t Flags() const override
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false) override
TrivialLookAheadMatcher(const FST *fst, MatchType match_type)
void InitLookAheadFst(const Fst< Arc > &fst, bool copy=false)
uint32_t Flags() const override
const MatcherData * GetData() const
ArcLookAheadMatcher(const FST &fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr)
constexpr uint32_t kLookAheadWeight
virtual MatcherBase * Copy(bool safe=false) const =0
typename Arc::StateId StateId
typename Arc::StateId StateId
LookAheadMatcher(const FST &fst, MatchType match_type)
uint64_t Properties(uint64_t props) const override
virtual void InitLookAheadFst(const Fst< Arc > &, bool copy=false)=0
void ClearLookAheadWeight()
Weight Final(StateId s) const
typename Arc::Weight Weight
typename Arc::Label Label
virtual MatchType Type(bool) const =0
Weight LookAheadWeight() const
Weight Final(StateId s) const final
Weight LookAheadWeight() const
typename Arc::Weight Weight
Weight Final(StateId s) const final
ssize_t Priority(StateId s) final
std::shared_ptr< MatcherData > GetSharedData() const
uint32_t Flags() const override
bool LookAheadFst(const Fst< Arc > &, StateId) final
const Arc & Value() const final
uint64_t Properties(uint64_t props) const override
LookAheadMatcher(MatcherBase< Arc > *base)
void InitLookAheadFst(const LFST &fst, bool copy=false)
LabelLookAheadMatcher(const FST &fst, MatchType match_type, std::shared_ptr< MatcherData > data=nullptr, std::unique_ptr< Accumulator > accumulator=nullptr)
typename Arc::Label Label
bool Find(Label label) final
bool LookAheadLabel(Label label) const final
bool LookAheadFst(const Fst< Arc > &, StateId) final
typename Arc::StateId StateId