20 #ifndef FST_EXTENSIONS_LINEAR_LINEAR_FST_H_ 21 #define FST_EXTENSIONS_LINEAR_LINEAR_FST_H_ 23 #include <sys/types.h> 82 typedef typename A::Label
Label;
102 :
CacheImpl<A>(opts), data_(data), delay_(data->MaxFutureSize()) {
112 :
CacheImpl<A>(impl), data_(impl.data_), delay_(impl.delay_) {
122 StateId start = FindStartState();
131 FillState(s, &state_stub_);
132 if (CanBeFinal(state_stub_))
133 SetFinal(s, data_->FinalWeight(InternalBegin(state_stub_),
134 InternalEnd(state_stub_)));
167 void MatchInput(StateId s, Label ilabel, std::vector<Arc> *arcs);
178 LOG(ERROR) <<
"LinearTaggerFst::Write: Write failed: " << opts.
source;
185 static constexpr
int kMinFileVersion = 1;
186 static constexpr
int kFileVersion = 1;
199 typename std::vector<Label>::const_iterator BufferBegin(
200 const std::vector<Label> &state)
const {
201 return state.begin();
204 typename std::vector<Label>::const_iterator BufferEnd(
205 const std::vector<Label> &state)
const {
206 return state.begin() + delay_;
209 typename std::vector<Label>::const_iterator InternalBegin(
210 const std::vector<Label> &state)
const {
211 return state.begin() + delay_;
214 typename std::vector<Label>::const_iterator InternalEnd(
215 const std::vector<Label> &state)
const {
220 void ReserveStubSpace() {
221 state_stub_.reserve(delay_ + data_->NumGroups());
222 next_stub_.reserve(delay_ + data_->NumGroups());
226 StateId FindStartState() {
231 data_->EncodeStartState(&state_stub_);
232 return FindState(state_stub_);
236 bool IsEmptyBuffer(
typename std::vector<Label>::const_iterator begin,
237 typename std::vector<Label>::const_iterator end)
const {
249 bool CanBeFinal(
const std::vector<Label> &state) {
250 return IsEmptyBuffer(BufferBegin(state), BufferEnd(state));
255 StateId FindState(
const std::vector<Label> &ngram) {
256 StateId sparse = ngrams_.
FindId(ngram,
true);
257 StateId dense = condensed_.
FindId(sparse,
true);
263 void FillState(StateId s, std::vector<Label> *output) {
265 for (NGramIterator it = ngrams_.
FindSet(s); !it.Done(); it.Next()) {
266 Label label = it.Element();
267 output->push_back(label);
277 Label ShiftBuffer(
const std::vector<Label> &state, Label ilabel,
278 std::vector<Label> *next_stub_);
282 Arc MakeArc(
const std::vector<Label> &state, Label ilabel, Label olabel,
283 std::vector<Label> *next_stub_);
288 void ExpandArcs(StateId s,
const std::vector<Label> &state, Label ilabel,
289 std::vector<Label> *next_stub_);
294 void AppendArcs(StateId s,
const std::vector<Label> &state, Label ilabel,
295 std::vector<Label> *next_stub_, std::vector<Arc> *arcs);
297 std::shared_ptr<const LinearFstData<A>> data_;
304 std::vector<Label> state_stub_;
305 std::vector<Label> next_stub_;
312 const std::vector<Label> &state,
Label ilabel,
313 std::vector<Label> *next_stub_) {
319 (*next_stub_)[BufferEnd(*next_stub_) - next_stub_->begin() - 1] = ilabel;
320 return *BufferBegin(state);
327 std::vector<Label> *next_stub_) {
330 Weight weight(Weight::One());
331 data_->TakeTransition(BufferEnd(state), InternalBegin(state),
332 InternalEnd(state), ilabel, olabel, next_stub_,
334 StateId nextstate = FindState(*next_stub_);
336 next_stub_->resize(delay_);
345 const std::vector<Label> &state,
347 std::vector<Label> *next_stub_) {
351 Label obs_ilabel = ShiftBuffer(state, ilabel, next_stub_);
357 std::pair<typename std::vector<typename A::Label>::const_iterator,
358 typename std::vector<typename A::Label>::const_iterator>
359 range = data_->PossibleOutputLabels(obs_ilabel);
360 for (
typename std::vector<typename A::Label>::const_iterator it =
362 it != range.second; ++it)
363 PushArc(s, MakeArc(state, ilabel, *it, next_stub_));
370 const std::vector<Label> &state,
372 std::vector<Label> *next_stub_,
373 std::vector<Arc> *arcs) {
377 Label obs_ilabel = ShiftBuffer(state, ilabel, next_stub_);
383 std::pair<typename std::vector<typename A::Label>::const_iterator,
384 typename std::vector<typename A::Label>::const_iterator>
385 range = data_->PossibleOutputLabels(obs_ilabel);
386 for (
typename std::vector<typename A::Label>::const_iterator it =
388 it != range.second; ++it)
389 arcs->push_back(MakeArc(state, ilabel, *it, next_stub_));
395 VLOG(3) <<
"Expand " << s;
397 FillState(s, &state_stub_);
402 next_stub_.resize(delay_);
404 std::copy(BufferBegin(state_stub_) + 1, BufferEnd(state_stub_),
408 if (!IsEmptyBuffer(BufferBegin(state_stub_), BufferEnd(state_stub_)))
414 for (
Label ilabel = data_->MinInputLabel();
415 ilabel <= data_->MaxInputLabel(); ++ilabel) {
416 ExpandArcs(s, state_stub_, ilabel, &next_stub_);
425 std::vector<Arc> *arcs) {
427 FillState(s, &state_stub_);
432 next_stub_.resize(delay_);
434 std::copy(BufferBegin(state_stub_) + 1, BufferEnd(state_stub_),
439 if (!IsEmptyBuffer(BufferBegin(state_stub_), BufferEnd(state_stub_)))
446 AppendArcs(s, state_stub_, ilabel, &next_stub_, arcs);
455 if (!impl->ReadHeader(strm, opts, kMinFileVersion, &header)) {
462 impl->delay_ = impl->data_->MaxFutureSize();
463 impl->ReserveStubSpace();
464 return impl.release();
496 LOG(FATAL) <<
"LinearTaggerFst: no constructor from arbitrary FST.";
511 GetMutableImpl()->InitArcIterator(s, data);
519 if (!source.empty()) {
520 std::ifstream strm(source,
521 std::ios_base::in | std::ios_base::binary);
523 LOG(ERROR) <<
"LinearTaggerFst::Read: Can't open file: " << source;
534 auto *impl = Impl::Read(in, opts);
538 bool Write(
const std::string &source)
const override {
539 if (!source.empty()) {
540 std::ofstream strm(source,
541 std::ios_base::out | std::ios_base::binary);
543 LOG(ERROR) <<
"LinearTaggerFst::Write: Can't open file: " << source;
553 return GetImpl()->Write(strm, opts);
580 using StateId =
typename Arc::StateId;
591 data->
base = std::make_unique<StateIterator<LinearTaggerFst<Arc>>>(*this);
637 num_classes_(num_classes),
638 num_groups_(data_->NumGroups() / num_classes_) {
650 num_classes_(impl.num_classes_),
651 num_groups_(impl.num_groups_) {
661 StateId start = FindStartState();
670 FillState(s, &state_stub_);
671 SetFinal(s, FinalWeight(state_stub_));
702 void MatchInput(StateId s, Label ilabel, std::vector<Arc> *arcs);
714 LOG(ERROR) <<
"LinearClassifierFst::Write: Write failed: " << opts.
source;
721 static constexpr
int kMinFileVersion = 0;
722 static constexpr
int kFileVersion = 0;
// Mutable access to the prediction slot, i.e. element 0 of `state`.
// `state` must be non-empty.
733 Label &Prediction(std::vector<Label> &state) {
  return state.front();
}
// Read-only counterpart: returns a copy of element 0 of `state` (the
// prediction slot). `state` must be non-empty.
734 Label Prediction(const std::vector<Label> &state) const {
  return state.front();
}
736 Label &InternalAt(std::vector<Label> &state,
int index) {
737 return state[index + 1];
739 Label InternalAt(
const std::vector<Label> &state,
int index)
const {
740 return state[index + 1];
744 void ReserveStubSpace() {
745 size_t size = 1 + num_groups_;
746 state_stub_.reserve(size);
747 next_stub_.reserve(size);
751 StateId FindStartState() {
756 for (
size_t i = 0; i < num_groups_; ++i)
758 return FindState(state_stub_);
762 bool IsStartState(
const std::vector<Label> &state)
const {
767 int GroupId(Label pred,
int group)
const {
768 return group * num_classes_ + pred - 1;
773 Weight FinalWeight(
const std::vector<Label> &state)
const {
774 if (IsStartState(state)) {
775 return Weight::Zero();
777 Label pred = Prediction(state);
780 Weight final_weight = Weight::One();
781 for (
size_t group = 0; group < num_groups_; ++group) {
782 int group_id = GroupId(pred, group);
783 int trie_state = InternalAt(state, group);
785 Times(final_weight, data_->GroupFinalWeight(group_id, trie_state));
792 StateId FindState(
const std::vector<Label> &ngram) {
793 StateId sparse = ngrams_.FindId(ngram,
true);
794 StateId dense = condensed_.FindId(sparse,
true);
800 void FillState(StateId s, std::vector<Label> *output) {
801 s = condensed_.FindEntry(s);
802 for (NGramIterator it = ngrams_.FindSet(s); !it.Done(); it.Next()) {
803 output->emplace_back(it.Element());
807 std::shared_ptr<const LinearFstData<A>> data_;
817 std::vector<Label> state_stub_;
818 std::vector<Label> next_stub_;
825 VLOG(3) <<
"Expand " << s;
827 FillState(s, &state_stub_);
829 next_stub_.resize(1 + num_groups_);
831 if (IsStartState(state_stub_)) {
833 for (Label pred = 1; pred <= num_classes_; ++pred) {
834 Prediction(next_stub_) = pred;
835 for (
int i = 0; i < num_groups_; ++i)
836 InternalAt(next_stub_, i) = data_->GroupStartState(GroupId(pred, i));
837 PushArc(s, A(0, pred, Weight::One(), FindState(next_stub_)));
840 Label pred = Prediction(state_stub_);
843 for (Label ilabel = data_->MinInputLabel();
844 ilabel <= data_->MaxInputLabel(); ++ilabel) {
845 Prediction(next_stub_) = pred;
846 Weight weight = Weight::One();
847 for (
int i = 0; i < num_groups_; ++i)
848 InternalAt(next_stub_, i) =
849 data_->GroupTransition(GroupId(pred, i), InternalAt(state_stub_, i),
850 ilabel, pred, &weight);
851 PushArc(s, A(ilabel, 0, weight, FindState(next_stub_)));
860 std::vector<Arc> *arcs) {
862 FillState(s, &state_stub_);
864 next_stub_.resize(1 + num_groups_);
866 if (IsStartState(state_stub_)) {
869 for (Label pred = 1; pred <= num_classes_; ++pred) {
870 Prediction(next_stub_) = pred;
871 for (
int i = 0; i < num_groups_; ++i)
872 InternalAt(next_stub_, i) = data_->GroupStartState(GroupId(pred, i));
873 arcs->push_back(A(0, pred, Weight::One(), FindState(next_stub_)));
876 }
else if (ilabel != 0) {
877 Label pred = Prediction(state_stub_);
878 Weight weight = Weight::One();
879 Prediction(next_stub_) = pred;
880 for (
int i = 0; i < num_groups_; ++i)
881 InternalAt(next_stub_, i) = data_->GroupTransition(
882 GroupId(pred, i), InternalAt(state_stub_, i), ilabel, pred, &weight);
883 arcs->push_back(A(ilabel, 0, weight, FindState(next_stub_)));
890 std::unique_ptr<LinearClassifierFstImpl<A>> impl(
893 if (!impl->ReadHeader(strm, opts, kMinFileVersion, &header)) {
900 ReadType(strm, &impl->num_classes_);
904 impl->num_groups_ = impl->data_->NumGroups() / impl->num_classes_;
905 if (impl->num_groups_ * impl->num_classes_ != impl->data_->NumGroups()) {
906 FSTERROR() <<
"Total number of feature groups is not a multiple of the " 907 "number of classes: num groups = " 908 << impl->data_->NumGroups()
909 <<
", num classes = " << impl->num_classes_;
912 impl->ReserveStubSpace();
913 return impl.release();
922 :
public ImplToFst<internal::LinearClassifierFstImpl<A>> {
943 std::make_shared<
Impl>(data, num_classes, isyms, osyms, opts)) {}
947 LOG(FATAL) <<
"LinearClassifierFst: no constructor from arbitrary FST.";
962 GetMutableImpl()->InitArcIterator(s, data);
970 if (!source.empty()) {
971 std::ifstream strm(source,
972 std::ios_base::in | std::ios_base::binary);
974 LOG(ERROR) <<
"LinearClassifierFst::Read: Can't open file: " << source;
985 auto *impl = Impl::Read(in, opts);
990 bool Write(
const std::string &source)
const override {
991 if (!source.empty()) {
992 std::ofstream strm(source,
993 std::ios_base::out | std::ios_base::binary);
// Prefix the message with this class's name, matching the
// "LinearClassifierFst::Read: Can't open file: " message used on the read
// path, so the log points at the right FST type.
995 LOG(ERROR) <<
"LinearClassifierFst::Write: Can't open file: " << source;
1005 return GetImpl()->Write(strm, opts);
1019 template <
class Arc>
1025 fst.GetMutableImpl()) {}
1029 template <
class Arc>
1033 using StateId =
typename Arc::StateId;
1041 template <
class Arc>
1044 data->
base = std::make_unique<StateIterator<LinearClassifierFst<Arc>>>(*this);
1062 : owned_fst_(fst.Copy()),
1064 match_type_(match_type),
1066 current_loop_(false),
1070 switch (match_type_) {
1076 FSTERROR() <<
"LinearFstMatcherTpl: Bad match type";
1085 match_type_(match_type),
1087 current_loop_(false),
1091 switch (match_type_) {
1097 FSTERROR() <<
"LinearFstMatcherTpl: Bad match type";
1105 : owned_fst_(matcher.fst_.Copy(safe)),
1107 match_type_(matcher.match_type_),
1109 current_loop_(false),
1110 loop_(matcher.loop_),
1112 error_(matcher.error_) {}
1124 if (s_ == s)
return;
1128 FSTERROR() <<
"LinearFstMatcherTpl: Bad match type";
1131 loop_.nextstate = s;
1136 current_loop_ =
false;
1139 current_loop_ = label == 0;
1143 fst_.GetMutableImpl()->MatchInput(s_, label, &arcs_);
1144 return current_loop_ || !arcs_.empty();
1148 return !(current_loop_ || cur_arc_ < arcs_.size());
1152 return current_loop_ ? loop_ : arcs_[cur_arc_];
1157 current_loop_ =
false;
// Accessor overriding the base-class virtual: returns a const reference to
// the `fst_` member.
1164 const FST &
GetFst()
const override {
return fst_; }
1167 if (error_) props |=
kError;
1174 std::unique_ptr<const FST> owned_fst_;
1181 std::vector<Arc> arcs_;
1188 #endif // FST_EXTENSIONS_LINEAR_LINEAR_FST_H_ void InitStateIterator(StateIteratorData< A > *data) const override
void InitArcIterator(StateId s, ArcIteratorData< A > *data) const override
LinearClassifierFst(const Fst< A > &fst)
void SetProperties(uint64_t props)
constexpr ssize_t kRequirePriority
LinearTaggerFstImpl(const LinearFstData< Arc > *data, const SymbolTable *isyms, const SymbolTable *osyms, CacheOptions opts)
size_t NumInputEpsilons(StateId s) const
I FindId(const std::vector< T > &set, bool insert=true)
bool HasFinal(StateId s) const
size_t NumOutputEpsilons(StateId s)
size_t NumInputEpsilons(StateId s)
SetIterator FindSet(I id)
bool Write(std::ostream &strm, const FstWriteOptions &opts) const
DefaultCacheStore< A > Store
DefaultCacheStore< A > Store
void SetFinal(StateId s, Weight weight=Weight::One())
LinearFstMatcherTpl(const FST *fst, MatchType match_type)
uint64_t Properties(uint64_t props) const override
const SymbolTable * OutputSymbols() const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
constexpr uint64_t kError
const T & FindEntry(I s) const
bool Find(Label label) final
static LinearClassifierFst< A > * Read(std::istream &in, const FstReadOptions &opts)
LinearClassifierFstImpl(const LinearFstData< Arc > *data, size_t num_classes, const SymbolTable *isyms, const SymbolTable *osyms, CacheOptions opts)
LinearFstMatcherTpl(const LinearFstMatcherTpl< F > &matcher, bool safe=false)
uint32_t Flags() const override
void SetState(StateId s) final
void MatchInput(StateId s, Label ilabel, std::vector< Arc > *arcs)
void InitArcIterator(StateId s, ArcIteratorData< A > *data) const override
void SetOutputSymbols(const SymbolTable *osyms)
LinearClassifierFst(LinearFstData< A > *data, size_t num_classes, const SymbolTable *isyms=nullptr, const SymbolTable *osyms=nullptr, CacheOptions opts=CacheOptions())
Collection< StateId, Label >::SetIterator NGramIterator
LinearTaggerFst(LinearFstData< A > *data, const SymbolTable *isyms=nullptr, const SymbolTable *osyms=nullptr, CacheOptions opts=CacheOptions())
MatcherBase< A > * InitMatcher(MatchType match_type) const override
LinearClassifierFst(const LinearClassifierFst< A > &fst, bool safe=false)
size_t NumArcs(StateId s) const
static LinearTaggerFstImpl * Read(std::istream &strm, const FstReadOptions &opts)
std::ostream & WriteType(std::ostream &strm, const T t)
ssize_t Priority(StateId s) final
static LinearClassifierFst< A > * Read(const std::string &source)
StateIterator(const LinearTaggerFst< Arc > &fst)
size_t NumOutputEpsilons(StateId s)
virtual uint64_t Properties() const
LinearTaggerFst(const LinearTaggerFst< A > &fst, bool safe=false)
const FST & GetFst() const override
LinearTaggerFst(const Fst< A > &fst)
constexpr uint64_t kCopyProperties
MatchType Type(bool) const override
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
bool Write(const std::string &source) const override
LinearClassifierFst< A > * Copy(bool safe=false) const override
StateIterator(const LinearClassifierFst< Arc > &fst)
ArcIterator(const LinearClassifierFst< Arc > &fst, StateId s)
MatcherBase< A > * InitMatcher(MatchType match_type) const override
constexpr int kNoTrieNodeId
static LinearTaggerFst< A > * Read(const std::string &source)
void WriteHeader(std::ostream &strm, const FstWriteOptions &opts, int version, FstHeader *hdr) const
bool Write(std::ostream &strm, const FstWriteOptions &opts) const
void InitArcIterator(StateId s, ArcIteratorData< A > *data)
ArcIterator(const LinearTaggerFst< Arc > &fst, StateId s)
std::unique_ptr< StateIteratorBase< Arc > > base
size_t NumArcs(StateId s)
LinearClassifierFstImpl()
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const
LinearFstMatcherTpl(const FST &fst, MatchType match_type)
static LinearFstData< A > * Read(std::istream &strm)
void SetInputSymbols(const SymbolTable *isyms)
bool HasArcs(StateId s) const
bool Write(std::ostream &strm, const FstWriteOptions &opts) const override
constexpr uint64_t kILabelSorted
constexpr uint64_t kFstProperties
LinearClassifierFstImpl(const LinearClassifierFstImpl &impl)
void PushArc(StateId s, const Arc &arc)
const SymbolTable * InputSymbols() const
void InitArcIterator(StateId s, ArcIteratorData< A > *data)
void SetType(std::string_view type)
Weight Final(StateId s) const
Collection< StateId, Label >::SetIterator NGramIterator
void MatchInput(StateId s, Label ilabel, std::vector< Arc > *arcs)
LinearTaggerFstImpl(const LinearTaggerFstImpl &impl)
LinearTaggerFst< A > * Copy(bool safe=false) const override
std::istream & ReadType(std::istream &strm, T *t)
size_t NumInputEpsilons(StateId s)
Impl * GetMutableImpl() const
I FindId(const T &entry, bool insert=true)
size_t NumArcs(StateId s)
LinearFstMatcherTpl< F > * Copy(bool safe=false) const override
static LinearClassifierFstImpl< A > * Read(std::istream &strm, const FstReadOptions &opts)
void InitStateIterator(StateIteratorData< A > *data) const override
size_t NumOutputEpsilons(StateId s) const
const Arc & Value() const final
bool Write(const std::string &source) const override
static LinearTaggerFst< A > * Read(std::istream &in, const FstReadOptions &opts)
const Impl * GetImpl() const
constexpr uint32_t kRequireMatch
bool Write(std::ostream &strm, const FstWriteOptions &opts) const override