20 #ifndef FST_EXTENSIONS_LINEAR_LINEAR_FST_H_ 21 #define FST_EXTENSIONS_LINEAR_LINEAR_FST_H_ 70 typedef typename A::Label
Label;
90 :
CacheImpl<A>(opts), data_(data), delay_(data->MaxFutureSize()) {
100 :
CacheImpl<A>(impl), data_(impl.data_), delay_(impl.delay_) {
110 StateId start = FindStartState();
119 FillState(s, &state_stub_);
120 if (CanBeFinal(state_stub_))
121 SetFinal(s, data_->FinalWeight(InternalBegin(state_stub_),
122 InternalEnd(state_stub_)));
155 void MatchInput(StateId s, Label ilabel, std::vector<Arc> *arcs);
166 LOG(ERROR) <<
"LinearTaggerFst::Write: Write failed: " << opts.
source;
173 static constexpr
int kMinFileVersion = 1;
174 static constexpr
int kFileVersion = 1;
187 typename std::vector<Label>::const_iterator BufferBegin(
188 const std::vector<Label> &state)
const {
189 return state.begin();
192 typename std::vector<Label>::const_iterator BufferEnd(
193 const std::vector<Label> &state)
const {
194 return state.begin() + delay_;
197 typename std::vector<Label>::const_iterator InternalBegin(
198 const std::vector<Label> &state)
const {
199 return state.begin() + delay_;
202 typename std::vector<Label>::const_iterator InternalEnd(
203 const std::vector<Label> &state)
const {
208 void ReserveStubSpace() {
209 state_stub_.reserve(delay_ + data_->NumGroups());
210 next_stub_.reserve(delay_ + data_->NumGroups());
214 StateId FindStartState() {
219 data_->EncodeStartState(&state_stub_);
220 return FindState(state_stub_);
224 bool IsEmptyBuffer(
typename std::vector<Label>::const_iterator begin,
225 typename std::vector<Label>::const_iterator end)
const {
237 bool CanBeFinal(
const std::vector<Label> &state) {
238 return IsEmptyBuffer(BufferBegin(state), BufferEnd(state));
243 StateId FindState(
const std::vector<Label> &ngram) {
244 StateId sparse = ngrams_.
FindId(ngram,
true);
245 StateId dense = condensed_.
FindId(sparse,
true);
251 void FillState(StateId s, std::vector<Label> *output) {
253 for (NGramIterator it = ngrams_.
FindSet(s); !it.Done(); it.Next()) {
254 Label label = it.Element();
255 output->push_back(label);
265 Label ShiftBuffer(
const std::vector<Label> &state, Label ilabel,
266 std::vector<Label> *next_stub_);
270 Arc MakeArc(
const std::vector<Label> &state, Label ilabel, Label olabel,
271 std::vector<Label> *next_stub_);
276 void ExpandArcs(StateId s,
const std::vector<Label> &state, Label ilabel,
277 std::vector<Label> *next_stub_);
282 void AppendArcs(StateId s,
const std::vector<Label> &state, Label ilabel,
283 std::vector<Label> *next_stub_, std::vector<Arc> *arcs);
285 std::shared_ptr<const LinearFstData<A>> data_;
292 std::vector<Label> state_stub_;
293 std::vector<Label> next_stub_;
300 const std::vector<Label> &state,
Label ilabel,
301 std::vector<Label> *next_stub_) {
307 (*next_stub_)[BufferEnd(*next_stub_) - next_stub_->begin() - 1] = ilabel;
308 return *BufferBegin(state);
315 std::vector<Label> *next_stub_) {
318 Weight weight(Weight::One());
319 data_->TakeTransition(BufferEnd(state), InternalBegin(state),
320 InternalEnd(state), ilabel, olabel, next_stub_,
322 StateId nextstate = FindState(*next_stub_);
324 next_stub_->resize(delay_);
333 const std::vector<Label> &state,
335 std::vector<Label> *next_stub_) {
339 Label obs_ilabel = ShiftBuffer(state, ilabel, next_stub_);
345 std::pair<typename std::vector<typename A::Label>::const_iterator,
346 typename std::vector<typename A::Label>::const_iterator>
347 range = data_->PossibleOutputLabels(obs_ilabel);
348 for (
typename std::vector<typename A::Label>::const_iterator it =
350 it != range.second; ++it)
351 PushArc(s, MakeArc(state, ilabel, *it, next_stub_));
358 const std::vector<Label> &state,
360 std::vector<Label> *next_stub_,
361 std::vector<Arc> *arcs) {
365 Label obs_ilabel = ShiftBuffer(state, ilabel, next_stub_);
371 std::pair<typename std::vector<typename A::Label>::const_iterator,
372 typename std::vector<typename A::Label>::const_iterator>
373 range = data_->PossibleOutputLabels(obs_ilabel);
374 for (
typename std::vector<typename A::Label>::const_iterator it =
376 it != range.second; ++it)
377 arcs->push_back(MakeArc(state, ilabel, *it, next_stub_));
383 VLOG(3) <<
"Expand " << s;
385 FillState(s, &state_stub_);
390 next_stub_.resize(delay_);
392 std::copy(BufferBegin(state_stub_) + 1, BufferEnd(state_stub_),
396 if (!IsEmptyBuffer(BufferBegin(state_stub_), BufferEnd(state_stub_)))
402 for (
Label ilabel = data_->MinInputLabel();
403 ilabel <= data_->MaxInputLabel(); ++ilabel) {
404 ExpandArcs(s, state_stub_, ilabel, &next_stub_);
413 std::vector<Arc> *arcs) {
415 FillState(s, &state_stub_);
420 next_stub_.resize(delay_);
422 std::copy(BufferBegin(state_stub_) + 1, BufferEnd(state_stub_),
427 if (!IsEmptyBuffer(BufferBegin(state_stub_), BufferEnd(state_stub_)))
434 AppendArcs(s, state_stub_, ilabel, &next_stub_, arcs);
443 if (!impl->ReadHeader(strm, opts, kMinFileVersion, &header)) {
450 impl->delay_ = impl->data_->MaxFutureSize();
451 impl->ReserveStubSpace();
452 return impl.release();
484 LOG(FATAL) <<
"LinearTaggerFst: no constructor from arbitrary FST.";
499 GetMutableImpl()->InitArcIterator(s, data);
507 if (!source.empty()) {
508 std::ifstream strm(source,
509 std::ios_base::in | std::ios_base::binary);
511 LOG(ERROR) <<
"LinearTaggerFst::Read: Can't open file: " << source;
522 auto *impl = Impl::Read(in, opts);
526 bool Write(
const std::string &source)
const override {
527 if (!source.empty()) {
528 std::ofstream strm(source,
529 std::ios_base::out | std::ios_base::binary);
531 LOG(ERROR) <<
"LinearTaggerFst::Write: Can't open file: " << source;
541 return GetImpl()->Write(strm, opts);
568 using StateId =
typename Arc::StateId;
579 data->
base = std::make_unique<StateIterator<LinearTaggerFst<Arc>>>(*this);
625 num_classes_(num_classes),
626 num_groups_(data_->NumGroups() / num_classes_) {
638 num_classes_(impl.num_classes_),
639 num_groups_(impl.num_groups_) {
649 StateId start = FindStartState();
658 FillState(s, &state_stub_);
659 SetFinal(s, FinalWeight(state_stub_));
690 void MatchInput(StateId s, Label ilabel, std::vector<Arc> *arcs);
702 LOG(ERROR) <<
"LinearClassifierFst::Write: Write failed: " << opts.
source;
709 static constexpr
int kMinFileVersion = 0;
710 static constexpr
int kFileVersion = 0;
721 Label &Prediction(std::vector<Label> &state) {
return state[0]; }
722 Label Prediction(
const std::vector<Label> &state)
const {
return state[0]; }
724 Label &InternalAt(std::vector<Label> &state,
int index) {
725 return state[index + 1];
727 Label InternalAt(
const std::vector<Label> &state,
int index)
const {
728 return state[index + 1];
732 void ReserveStubSpace() {
733 size_t size = 1 + num_groups_;
734 state_stub_.reserve(size);
735 next_stub_.reserve(size);
739 StateId FindStartState() {
744 for (
size_t i = 0; i < num_groups_; ++i)
746 return FindState(state_stub_);
750 bool IsStartState(
const std::vector<Label> &state)
const {
755 int GroupId(Label pred,
int group)
const {
756 return group * num_classes_ + pred - 1;
761 Weight FinalWeight(
const std::vector<Label> &state)
const {
762 if (IsStartState(state)) {
763 return Weight::Zero();
765 Label pred = Prediction(state);
768 Weight final_weight = Weight::One();
769 for (
size_t group = 0; group < num_groups_; ++group) {
770 int group_id = GroupId(pred, group);
771 int trie_state = InternalAt(state, group);
773 Times(final_weight, data_->GroupFinalWeight(group_id, trie_state));
780 StateId FindState(
const std::vector<Label> &ngram) {
781 StateId sparse = ngrams_.FindId(ngram,
true);
782 StateId dense = condensed_.FindId(sparse,
true);
788 void FillState(StateId s, std::vector<Label> *output) {
789 s = condensed_.FindEntry(s);
790 for (NGramIterator it = ngrams_.FindSet(s); !it.Done(); it.Next()) {
791 output->emplace_back(it.Element());
795 std::shared_ptr<const LinearFstData<A>> data_;
805 std::vector<Label> state_stub_;
806 std::vector<Label> next_stub_;
813 VLOG(3) <<
"Expand " << s;
815 FillState(s, &state_stub_);
817 next_stub_.resize(1 + num_groups_);
819 if (IsStartState(state_stub_)) {
821 for (Label pred = 1; pred <= num_classes_; ++pred) {
822 Prediction(next_stub_) = pred;
823 for (
int i = 0; i < num_groups_; ++i)
824 InternalAt(next_stub_, i) = data_->GroupStartState(GroupId(pred, i));
825 PushArc(s, A(0, pred, Weight::One(), FindState(next_stub_)));
828 Label pred = Prediction(state_stub_);
831 for (Label ilabel = data_->MinInputLabel();
832 ilabel <= data_->MaxInputLabel(); ++ilabel) {
833 Prediction(next_stub_) = pred;
834 Weight weight = Weight::One();
835 for (
int i = 0; i < num_groups_; ++i)
836 InternalAt(next_stub_, i) =
837 data_->GroupTransition(GroupId(pred, i), InternalAt(state_stub_, i),
838 ilabel, pred, &weight);
839 PushArc(s, A(ilabel, 0, weight, FindState(next_stub_)));
848 std::vector<Arc> *arcs) {
850 FillState(s, &state_stub_);
852 next_stub_.resize(1 + num_groups_);
854 if (IsStartState(state_stub_)) {
857 for (Label pred = 1; pred <= num_classes_; ++pred) {
858 Prediction(next_stub_) = pred;
859 for (
int i = 0; i < num_groups_; ++i)
860 InternalAt(next_stub_, i) = data_->GroupStartState(GroupId(pred, i));
861 arcs->push_back(A(0, pred, Weight::One(), FindState(next_stub_)));
864 }
else if (ilabel != 0) {
865 Label pred = Prediction(state_stub_);
866 Weight weight = Weight::One();
867 Prediction(next_stub_) = pred;
868 for (
int i = 0; i < num_groups_; ++i)
869 InternalAt(next_stub_, i) = data_->GroupTransition(
870 GroupId(pred, i), InternalAt(state_stub_, i), ilabel, pred, &weight);
871 arcs->push_back(A(ilabel, 0, weight, FindState(next_stub_)));
878 std::unique_ptr<LinearClassifierFstImpl<A>> impl(
881 if (!impl->ReadHeader(strm, opts, kMinFileVersion, &header)) {
888 ReadType(strm, &impl->num_classes_);
892 impl->num_groups_ = impl->data_->NumGroups() / impl->num_classes_;
893 if (impl->num_groups_ * impl->num_classes_ != impl->data_->NumGroups()) {
894 FSTERROR() <<
"Total number of feature groups is not a multiple of the " 895 "number of classes: num groups = " 896 << impl->data_->NumGroups()
897 <<
", num classes = " << impl->num_classes_;
900 impl->ReserveStubSpace();
901 return impl.release();
910 :
public ImplToFst<internal::LinearClassifierFstImpl<A>> {
931 std::make_shared<
Impl>(data, num_classes, isyms, osyms, opts)) {}
935 LOG(FATAL) <<
"LinearClassifierFst: no constructor from arbitrary FST.";
950 GetMutableImpl()->InitArcIterator(s, data);
958 if (!source.empty()) {
959 std::ifstream strm(source,
960 std::ios_base::in | std::ios_base::binary);
962 LOG(ERROR) <<
"LinearClassifierFst::Read: Can't open file: " << source;
973 auto *impl = Impl::Read(in, opts);
978 bool Write(
const std::string &source)
const override {
979 if (!source.empty()) {
980 std::ofstream strm(source,
981 std::ios_base::out | std::ios_base::binary);
983 LOG(ERROR) <<
"ProdLmFst::Write: Can't open file: " << source;
993 return GetImpl()->Write(strm, opts);
1007 template <
class Arc>
1013 fst.GetMutableImpl()) {}
1017 template <
class Arc>
1021 using StateId =
typename Arc::StateId;
1029 template <
class Arc>
1033 std::make_unique<StateIterator<LinearClassifierFst<Arc>>>(*this);
1051 : owned_fst_(fst.Copy()),
1053 match_type_(match_type),
1055 current_loop_(false),
1059 switch (match_type_) {
1065 FSTERROR() <<
"LinearFstMatcherTpl: Bad match type";
1074 match_type_(match_type),
1076 current_loop_(false),
1080 switch (match_type_) {
1086 FSTERROR() <<
"LinearFstMatcherTpl: Bad match type";
1094 : owned_fst_(matcher.fst_.Copy(safe)),
1096 match_type_(matcher.match_type_),
1098 current_loop_(false),
1099 loop_(matcher.loop_),
1101 error_(matcher.error_) {}
1113 if (s_ == s)
return;
1117 FSTERROR() <<
"LinearFstMatcherTpl: Bad match type";
1120 loop_.nextstate = s;
1125 current_loop_ =
false;
1128 current_loop_ = label == 0;
1132 fst_.GetMutableImpl()->MatchInput(s_, label, &arcs_);
1133 return current_loop_ || !arcs_.empty();
1137 return !(current_loop_ || cur_arc_ < arcs_.size());
1141 return current_loop_ ? loop_ : arcs_[cur_arc_];
1146 current_loop_ =
false;
1153 const FST &
GetFst()
const override {
return fst_; }
1156 if (error_) props |=
kError;
1163 std::unique_ptr<const FST> owned_fst_;
1170 std::vector<Arc> arcs_;
1177 #endif // FST_EXTENSIONS_LINEAR_LINEAR_FST_H_ void InitStateIterator(StateIteratorData< A > *data) const override
void InitArcIterator(StateId s, ArcIteratorData< A > *data) const override
LinearClassifierFst(const Fst< A > &fst)
void SetProperties(uint64_t props)
constexpr ssize_t kRequirePriority
LinearTaggerFstImpl(const LinearFstData< Arc > *data, const SymbolTable *isyms, const SymbolTable *osyms, CacheOptions opts)
size_t NumInputEpsilons(StateId s) const
I FindId(const std::vector< T > &set, bool insert=true)
bool HasFinal(StateId s) const
size_t NumOutputEpsilons(StateId s)
size_t NumInputEpsilons(StateId s)
SetIterator FindSet(I id)
bool Write(std::ostream &strm, const FstWriteOptions &opts) const
DefaultCacheStore< A > Store
DefaultCacheStore< A > Store
void SetFinal(StateId s, Weight weight=Weight::One())
LinearFstMatcherTpl(const FST *fst, MatchType match_type)
uint64_t Properties(uint64_t props) const override
const SymbolTable * OutputSymbols() const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
constexpr uint64_t kError
const T & FindEntry(I s) const
bool Find(Label label) final
static LinearClassifierFst< A > * Read(std::istream &in, const FstReadOptions &opts)
LinearClassifierFstImpl(const LinearFstData< Arc > *data, size_t num_classes, const SymbolTable *isyms, const SymbolTable *osyms, CacheOptions opts)
LinearFstMatcherTpl(const LinearFstMatcherTpl< F > &matcher, bool safe=false)
uint32_t Flags() const override
void SetState(StateId s) final
void MatchInput(StateId s, Label ilabel, std::vector< Arc > *arcs)
void InitArcIterator(StateId s, ArcIteratorData< A > *data) const override
void SetOutputSymbols(const SymbolTable *osyms)
LinearClassifierFst(LinearFstData< A > *data, size_t num_classes, const SymbolTable *isyms=nullptr, const SymbolTable *osyms=nullptr, CacheOptions opts=CacheOptions())
Collection< StateId, Label >::SetIterator NGramIterator
LinearTaggerFst(LinearFstData< A > *data, const SymbolTable *isyms=nullptr, const SymbolTable *osyms=nullptr, CacheOptions opts=CacheOptions())
MatcherBase< A > * InitMatcher(MatchType match_type) const override
LinearClassifierFst(const LinearClassifierFst< A > &fst, bool safe=false)
size_t NumArcs(StateId s) const
static LinearTaggerFstImpl * Read(std::istream &strm, const FstReadOptions &opts)
std::ostream & WriteType(std::ostream &strm, const T t)
ssize_t Priority(StateId s) final
static LinearClassifierFst< A > * Read(const std::string &source)
StateIterator(const LinearTaggerFst< Arc > &fst)
size_t NumOutputEpsilons(StateId s)
virtual uint64_t Properties() const
LinearTaggerFst(const LinearTaggerFst< A > &fst, bool safe=false)
const FST & GetFst() const override
LinearTaggerFst(const Fst< A > &fst)
constexpr uint64_t kCopyProperties
MatchType Type(bool) const override
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
bool Write(const std::string &source) const override
LinearClassifierFst< A > * Copy(bool safe=false) const override
StateIterator(const LinearClassifierFst< Arc > &fst)
ArcIterator(const LinearClassifierFst< Arc > &fst, StateId s)
MatcherBase< A > * InitMatcher(MatchType match_type) const override
constexpr int kNoTrieNodeId
static LinearTaggerFst< A > * Read(const std::string &source)
void WriteHeader(std::ostream &strm, const FstWriteOptions &opts, int version, FstHeader *hdr) const
bool Write(std::ostream &strm, const FstWriteOptions &opts) const
void InitArcIterator(StateId s, ArcIteratorData< A > *data)
ArcIterator(const LinearTaggerFst< Arc > &fst, StateId s)
std::unique_ptr< StateIteratorBase< Arc > > base
size_t NumArcs(StateId s)
LinearClassifierFstImpl()
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const
LinearFstMatcherTpl(const FST &fst, MatchType match_type)
static LinearFstData< A > * Read(std::istream &strm)
void SetInputSymbols(const SymbolTable *isyms)
bool HasArcs(StateId s) const
bool Write(std::ostream &strm, const FstWriteOptions &opts) const override
constexpr uint64_t kILabelSorted
constexpr uint64_t kFstProperties
LinearClassifierFstImpl(const LinearClassifierFstImpl &impl)
void PushArc(StateId s, const Arc &arc)
const SymbolTable * InputSymbols() const
void InitArcIterator(StateId s, ArcIteratorData< A > *data)
void SetType(std::string_view type)
Weight Final(StateId s) const
Collection< StateId, Label >::SetIterator NGramIterator
void MatchInput(StateId s, Label ilabel, std::vector< Arc > *arcs)
LinearTaggerFstImpl(const LinearTaggerFstImpl &impl)
LinearTaggerFst< A > * Copy(bool safe=false) const override
std::istream & ReadType(std::istream &strm, T *t)
size_t NumInputEpsilons(StateId s)
Impl * GetMutableImpl() const
I FindId(const T &entry, bool insert=true)
size_t NumArcs(StateId s)
LinearFstMatcherTpl< F > * Copy(bool safe=false) const override
static LinearClassifierFstImpl< A > * Read(std::istream &strm, const FstReadOptions &opts)
void InitStateIterator(StateIteratorData< A > *data) const override
size_t NumOutputEpsilons(StateId s) const
const Arc & Value() const final
bool Write(const std::string &source) const override
static LinearTaggerFst< A > * Read(std::istream &in, const FstReadOptions &opts)
const Impl * GetImpl() const
constexpr uint32_t kRequireMatch
bool Write(std::ostream &strm, const FstWriteOptions &opts) const override