20 #ifndef FST_VECTOR_FST_H_ 21 #define FST_VECTOR_FST_H_ 39 template <
class A,
class S>
42 template <
class F,
class G>
43 void Cast(
const F &, G *);
47 template <
class A,
class M >
59 : final_weight_(
Weight::Zero()),
65 : final_weight_(state.
Final()),
68 arcs_(state.arcs_.begin(), state.arcs_.end(), alloc) {}
71 final_weight_ = Weight::Zero();
// Reports how many arcs are currently stored in arcs_.
83 size_t NumArcs() const {
  const size_t arc_count = arcs_.size();
  return arc_count;
}
// Gives read-only access to the arc at position n. The index is passed
// straight to arcs_[n]; no range check is made here.
85 const Arc &GetArc(size_t n) const {
  const Arc &selected = arcs_[n];
  return selected;
}
// Returns the address of the first stored arc, or nullptr when arcs_ is
// empty.
87 const Arc *Arcs() const {
  if (arcs_.empty()) return nullptr;
  return &arcs_[0];
}
// Body of an arc-appending routine (its signature is not visible in this
// listing): the epsilon counters are updated from the arc, then a copy of the
// arc is appended to arcs_.
100 IncrementNumEpsilons(arc);
101 arcs_.push_back(arc);
// Same sequence for a second routine, except the arc is moved into arcs_.
// The counters are updated before the move, while arc is still intact.
105 IncrementNumEpsilons(arc);
106 arcs_.push_back(std::move(arc));
// Variadic template header; the function signature line that follows it is
// not visible in this listing.
109 template <
class... T>
// Constructs the new arc in place at the end of arcs_ from the forwarded
// arguments, then updates the epsilon counters from that new last element.
111 arcs_.emplace_back(std::forward<T>(ctor_args)...);
112 IncrementNumEpsilons(arcs_.back());
// Removes the contribution of the arc currently at position n from the
// epsilon counters, then adds the contribution of the incoming arc.
// NOTE(review): the statement that stores arc into arcs_[n] is not visible in
// this listing.
116 if (arcs_[n].ilabel == 0) --niepsilons_;
117 if (arcs_[n].olabel == 0) --noepsilons_;
118 IncrementNumEpsilons(arc);
// Runs n iterations; each one removes the last arc's contribution to the
// epsilon counters.
// NOTE(review): the statement that actually drops the last arc (so that
// arcs_.back() advances between iterations) is not visible in this listing —
// presumably a pop_back; confirm against the full file.
129 for (
size_t i = 0; i < n; ++i) {
130 if (arcs_.back().ilabel == 0) --niepsilons_;
131 if (arcs_.back().olabel == 0) --noepsilons_;
138 return alloc->allocate(1);
// Runs the state's destructor explicitly, then returns its storage (one
// element) to the allocator.
144 state->~VectorState<A, M>();
145 alloc->deallocate(state, 1);
// Increments niepsilons_ when arc's input label is 0 and noepsilons_ when its
// output label is 0. Both counters are bumped for an arc with both labels 0.
151 void IncrementNumEpsilons(
const Arc &arc) {
152 if (arc.ilabel == 0) ++niepsilons_;
153 if (arc.olabel == 0) ++noepsilons_;
159 std::vector<A, ArcAllocator> arcs_;
170 using Arc =
typename State::Arc;
// Initializer-list fragment: takes over impl's state vector by move and
// copies its start state id.
187 states_(std::move(impl.states_)),
188 start_(impl.start_) {
// Explicitly empty the moved-from vector rather than relying on its
// moved-from state.
189 impl.states_.clear();
194 for (
auto *state : states_) {
// Exchanges the two state vectors, so this object's previous states end up in
// impl; the start state id is copied, not swapped.
198 std::swap(states_, impl.states_);
199 start_ = impl.start_;
213 return GetState(state)->NumInputEpsilons();
217 return GetState(state)->NumOutputEpsilons();
223 states_[state]->SetFinal(std::move(weight));
227 states_.push_back(state);
228 return states_.size() - 1;
// Grows states_ by n slots past the current state count, then fills each new
// slot with a state obtained from CreateState().
234 const auto curr_num_states = NumStates();
235 states_.resize(n + curr_num_states);
236 std::generate(states_.begin() + curr_num_states, states_.end(),
237 [
this] {
return CreateState(); });
243 states_[state]->AddArc(std::move(arc));
246 template <
class... T>
248 states_[state]->EmplaceArc(std::forward<T>(ctor_args)...);
// newid maps each current state id to its id after deletion. Every entry
// starts at 0; entries for the states listed in dstates are set to
// kNoStateId.
252 std::vector<StateId> newid(states_.size(), 0);
253 for (
size_t i = 0; i < dstates.size(); ++i) newid[dstates[i]] =
kNoStateId;
// First pass over all states. The visible lines record the new id of a state
// and slide its pointer down to slot nstates.
// NOTE(review): the enclosing condition, the update of nstates and the
// handling of deleted states are not visible in this listing.
255 for (
StateId state = 0; state < states_.size(); ++state) {
257 newid[state] = nstates;
258 if (state != nstates) states_[nstates] = states_[state];
// Shrink the vector to the nstates entries written by the first pass.
264 states_.resize(nstates);
// Second pass: rewrite the arcs of every remaining state.
265 for (
StateId state = 0; state < states_.size(); ++state) {
266 auto *arcs = states_[state]->MutableArcs();
// Local copies of the epsilon counts; they are adjusted below and written
// back after the arc loop.
268 auto nieps = states_[state]->NumInputEpsilons();
269 auto noeps = states_[state]->NumOutputEpsilons();
270 for (
size_t i = 0; i < states_[state]->NumArcs(); ++i) {
// t is the remapped id of arc i's destination state.
271 const auto t = newid[arcs[i].nextstate];
// Retarget the arc and compact it towards index narcs.
273 arcs[i].nextstate = t;
274 if (i != narcs) arcs[narcs] = arcs[i];
// Decrement the local epsilon counts for arc i.
// NOTE(review): the branch these two lines belong to is not visible here —
// apparently the case where the destination was deleted; confirm.
277 if (arcs[i].ilabel == 0) --nieps;
278 if (arcs[i].olabel == 0) --noeps;
// Drop the trailing arcs beyond narcs, then store the adjusted epsilon
// counts on the state.
281 states_[state]->DeleteArcs(states_[state]->
NumArcs() - narcs);
282 states_[state]->SetNumInputEpsilons(nieps);
283 states_[state]->SetNumOutputEpsilons(noeps);
// Remap the start state when one is set.
285 if (Start() !=
kNoStateId) SetStart(newid[Start()]);
289 for (
size_t state = 0; state < states_.size(); ++state) {
// Fills the iterator data: no base iterator object is supplied (base is
// null), and the state count is taken from states_.
312 data->
base =
nullptr;
313 data->
nstates = states_.size();
// Fills the iterator data for one state: no base iterator object is supplied
// (base is null); the arc count and the arc pointer come from that state.
318 data->
base =
nullptr;
319 data->
narcs = states_[state]->NumArcs();
320 data->
arcs = states_[state]->Arcs();
// Builds a new State with the allocator-taking form of new (passing
// &state_alloc_), constructing it with arc_alloc_, and returns the pointer.
325 State *CreateState() {
  State *fresh_state = new (&state_alloc_) State(arc_alloc_);
  return fresh_state;
}
327 std::vector<State *> states_;
329 typename State::StateAllocator state_alloc_;
330 typename State::ArcAllocator arc_alloc_;
339 using Arc =
typename State::Arc;
369 BaseImpl::SetStart(state);
375 const auto properties =
377 BaseImpl::SetFinal(state, std::move(weight));
378 SetProperties(properties);
382 const auto state = BaseImpl::AddState();
388 BaseImpl::AddStates(n);
// Bodies of three arc-adding routines (signatures not visible in this
// listing). Each delegates to the BaseImpl counterpart and then calls
// UpdatePropertiesAfterAddArc for the same state.
// Copying form:
393 BaseImpl::AddArc(state, arc);
394 UpdatePropertiesAfterAddArc(state);
// Moving form:
398 BaseImpl::AddArc(state, std::move(arc));
399 UpdatePropertiesAfterAddArc(state);
// In-place construction form, forwarding the constructor arguments:
402 template <
class... T>
404 BaseImpl::EmplaceArc(state, std::forward<T>(ctor_args)...);
405 UpdatePropertiesAfterAddArc(state);
409 BaseImpl::DeleteStates(dstates);
414 BaseImpl::DeleteStates();
419 BaseImpl::DeleteArcs(state, n);
424 BaseImpl::DeleteArcs(state);
// Looks up the state's last arc and the one before it. How they are used is
// not visible in this listing; several lines of this function are missing.
432 void UpdatePropertiesAfterAddArc(
StateId state) {
433 auto *vstate = GetState(state);
434 const size_t num_arcs{vstate->NumArcs()};
// The last arc in the state.
436 const auto &arc = vstate->GetArc(num_arcs - 1);
// Address of the second-to-last arc, or nullptr when the state holds fewer
// than two arcs.
438 (num_arcs < 2) ?
nullptr : &(vstate->GetArc(num_arcs - 2));
444 static constexpr
int kMinFileVersion = 2;
// Copies the start state of the source FST.
452 BaseImpl::SetStart(fst.
Start());
// For each state yielded by siter: add a state and copy the source state's
// final weight. (The loop headers are not visible in this listing.)
457 const auto state = siter.Value();
458 BaseImpl::AddState();
459 BaseImpl::SetFinal(state, fst.
Final(state));
// For each arc yielded by aiter: add it to the current state.
462 const auto &arc = aiter.Value();
463 BaseImpl::AddArc(state, arc);
// The impl under construction is held by a unique_ptr and is handed to the
// caller only by the release() on listing line 508.
472 auto impl = std::make_unique<VectorFstImpl>();
// Give up with nullptr when the header cannot be read.
474 if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr))
return nullptr;
475 impl->BaseImpl::SetStart(hdr.Start());
// Reserve state storage up front when the header carries a state count.
476 if (hdr.NumStates() !=
kNoStateId) impl->ReserveStates(hdr.NumStates());
// Reads states until the header's count is reached or, when the count is
// kNoStateId, until a final-weight read fails (the break below).
478 for (; hdr.NumStates() ==
kNoStateId || state < hdr.NumStates(); ++state) {
480 if (!weight.Read(strm))
break;
// Add the state and store the final weight that was just read.
481 impl->BaseImpl::AddState();
482 auto *vstate = impl->GetState(state);
483 vstate->SetFinal(weight);
487 LOG(ERROR) <<
"VectorFst::Read: Read failed: " << opts.
source;
// Reserve room for narcs arcs on this state, then read them one at a time.
490 impl->ReserveArcs(state, narcs);
491 for (int64_t i = 0; i < narcs; ++i) {
495 arc.weight.Read(strm);
498 LOG(ERROR) <<
"VectorFst::Read: Read failed: " << opts.
source;
501 impl->BaseImpl::AddArc(state, std::move(arc));
// A known state count that was not reached is reported as an unexpected end
// of file.
504 if (hdr.NumStates() !=
kNoStateId && state != hdr.NumStates()) {
505 LOG(ERROR) <<
"VectorFst::Read: Unexpected end of file: " << opts.
source;
// Transfer ownership of the impl to the caller as a raw pointer.
508 return impl.release();
519 template <
class A,
class S >
532 template <
class F,
class G>
533 friend void Cast(
const F &, G *);
555 if (
this != &fst) SetImpl(std::make_shared<Impl>(fst));
559 template <
class... T>
562 GetMutableImpl()->EmplaceArc(state, std::forward<T>(ctor_args)...);
// Reads the implementation from the stream. When that yields a non-null
// impl, it is placed in a shared_ptr and wrapped in a newly allocated
// VectorFst; otherwise nullptr is returned.
567 auto *impl = Impl::Read(strm, opts);
568 return impl ?
new VectorFst(std::shared_ptr<Impl>(impl)) :
nullptr;
// Same wrapping step for a second routine; the line that obtains impl there
// is not visible in this listing.
575 return impl ?
new VectorFst(std::shared_ptr<Impl>(impl)) :
nullptr;
579 return WriteFst(*
this, strm, opts);
582 bool Write(
const std::string &source)
const override {
587 static bool WriteFst(
const FST &
fst, std::ostream &strm,
591 GetImpl()->InitStateIterator(data);
595 GetImpl()->InitArcIterator(s, data);
598 inline void InitMutableArcIterator(
StateId s,
610 explicit VectorFst(std::shared_ptr<Impl> impl)
614 template <
class Arc,
class State>
617 template <
class Arc,
class State>
623 template <
class Arc,
class State>
// Version number passed to the header-writing call below.
627 static constexpr
int file_version = 2;
// update_header starts as true and is cleared on listing line 636.
// NOTE(review): the condition guarding that assignment is only partly
// visible in this listing.
628 bool update_header =
true;
632 std::streampos start_offset = 0;
// Visible tail of a condition: records the current write position in
// start_offset and tests that tellp() did not return -1.
634 (start_offset = strm.tellp()) != -1) {
636 update_header =
false;
638 const auto properties =
641 "vector", properties, &hdr);
// Per state: write the final weight, take the arc count, then write each
// arc's weight. (Loop headers and the writes of the other fields are not
// visible in this listing.)
644 const auto s = siter.Value();
645 fst.Final(s).Write(strm);
646 const int64_t narcs = fst.NumArcs(s);
649 const auto &arc = aiter.Value();
652 arc.weight.Write(strm);
659 LOG(ERROR) <<
"VectorFst::Write: Write failed: " << opts.
source;
// Argument list of a later call that takes the same version, type string,
// properties and header; the callee name is not visible in this listing.
665 fst, strm, opts, file_version,
"vector", properties, &hdr,
669 LOG(ERROR) <<
"Inconsistent number of states observed during write";
678 template <
class Arc,
class State>
684 : nstates_(fst.GetImpl()->NumStates()), s_(0) {}
// True once the current state index s_ is no longer below nstates_.
686 bool Done() const {
  const bool has_more = s_ < nstates_;
  return !has_more;
}
701 template <
class Arc,
class State>
707 : arcs_(fst.GetImpl()->GetState(s)->
Arcs()),
708 narcs_(fst.GetImpl()->GetState(s)->
NumArcs()),
// True once the current arc index i_ is no longer below narcs_.
711 bool Done() const {
  const bool has_more = i_ < narcs_;
  return !has_more;
}
// Repositions the iterator: the current arc index becomes a. The value is
// stored as given, without any range check here.
719 void Seek(size_t a) {
  i_ = a;
}
735 template <
class Arc,
class State>
745 properties_ = &fst->
GetImpl()->properties_;
// True once the current arc index i_ is no longer below the state's arc
// count. The count is re-read from state_ on every call.
748 bool Done() const final {
  const bool has_more = i_ < state_->NumArcs();
  return !has_more;
}
// Returns a const reference to the arc at the current position i_ of state_.
750 const Arc &Value() const final {
  const Arc &current = state_->GetArc(i_);
  return current;
}
// Repositions the iterator: the current arc index becomes a. The value is
// stored as given, without any range check here.
758 void Seek(size_t a) final {
  i_ = a;
}
// oarc is the arc currently stored at position i_. It is examined before the
// new arc is stored on listing line 772.
761 const auto &oarc = state_->GetArc(i_);
// Work on a local copy of the property bits, loaded with relaxed ordering.
762 uint64_t properties = properties_->load(std::memory_order_relaxed);
// Clear the property bits that the old arc may have been responsible for.
763 if (oarc.ilabel != oarc.olabel) properties &= ~
kNotAcceptor;
764 if (oarc.ilabel == 0) {
766 if (oarc.olabel == 0) properties &= ~
kEpsilons;
768 if (oarc.olabel == 0) properties &= ~
kOEpsilons;
769 if (oarc.weight != Weight::Zero() && oarc.weight != Weight::One()) {
// Store the replacement arc at the current position.
772 state_->SetArc(arc, i_);
// The same tests applied to the new arc. The bodies of these branches are
// not visible in this listing.
773 if (arc.ilabel != arc.olabel) {
777 if (arc.ilabel == 0) {
780 if (arc.olabel == 0) {
785 if (arc.olabel == 0) {
789 if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) {
// Write the updated bits back with relaxed ordering.
796 properties_->store(properties, std::memory_order_relaxed);
805 std::atomic<uint64_t> *properties_;
810 template <
class Arc,
class State>
814 std::make_unique<MutableArcIterator<VectorFst<Arc, State>>>(
this, s);
822 #endif // FST_VECTOR_FST_H_ typename std::allocator_traits< ArcAllocator >::template rebind_alloc< VectorState< Arc, M >> StateAllocator
void AddArc(StateId state, const Arc &arc)
typename Arc::Weight Weight
void SetState(StateId state, State *vstate)
uint64_t AddArcProperties(uint64_t inprops, typename A::StateId s, const A &arc, const A *prev_arc)
void SetFlags(uint8_t, uint8_t) final
void DeleteArcs(StateId state)
constexpr uint64_t kMutable
constexpr uint8_t kArcValueFlags
void AddArc(StateId state, Arc &&arc)
uint64_t SetStartProperties(uint64_t inprops)
void Cast(const F &, G *)
virtual uint64_t Properties(uint64_t mask, bool test) const =0
constexpr uint64_t kOEpsilons
void DeleteStates(const std::vector< StateId > &dstates)
std::unique_ptr< MutableArcIteratorBase< Arc > > base
virtual size_t NumArcs(StateId) const =0
VectorFstBaseImpl & operator=(VectorFstBaseImpl &&impl) noexcept
VectorState(const ArcAllocator &alloc)
void DeleteArcs(StateId state, size_t n)
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
void ReserveArcs(size_t n)
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
uint64_t AddStateProperties(uint64_t inprops)
StateIterator(const VectorFst< Arc, State > &fst)
uint64_t DeleteAllStatesProperties(uint64_t inprops, uint64_t staticProps)
uint64_t DeleteStatesProperties(uint64_t inprops)
bool Write(std::ostream &strm, const FstWriteOptions &opts) const override
void SetNumOutputEpsilons(size_t n)
size_t NumArcs(StateId state) const
void SetFinal(StateId state, Weight weight)
virtual Weight Final(StateId) const =0
const Arc & Value() const final
static VectorFst * Read(const std::string &source)
static VectorFst * Read(std::istream &strm, const FstReadOptions &opts)
typename Arc::StateId StateId
constexpr uint64_t kEpsilons
uint64_t SetFinalProperties(uint64_t inprops, const Weight &old_weight, const Weight &new_weight)
const Arc & Value() const
void InitMutableArcIterator(StateId s, MutableArcIteratorData< Arc > *) override
typename Arc::Label Label
void DeleteStates(const std::vector< StateId > &dstates)
void AddArc(StateId state, Arc &&arc)
typename Arc::StateId StateId
VectorFst(const Fst< Arc > &fst)
std::ostream & WriteType(std::ostream &strm, const T t)
void InitStateIterator(StateIteratorData< Arc > *data) const
static void WriteFstHeader(const Fst< Arc > &fst, std::ostream &strm, const FstWriteOptions &opts, int version, std::string_view type, uint64_t properties, FstHeader *hdr)
Weight Final(StateId state) const
const State * GetState(StateId state) const
MutableArcIterator(VectorFst< Arc, State > *fst, StateId s)
State * GetState(StateId state)
typename Arc::StateId StateId
void SetFinal(Weight weight)
void EmplaceArc(StateId state, T &&...ctor_args)
static void Destroy(VectorState< A, M > *state, StateAllocator *alloc)
constexpr uint64_t kNoOEpsilons
StateId NumStates() const
void InitStateIterator(StateIteratorData< Arc > *data) const override
constexpr uint64_t kNotAcceptor
constexpr uint64_t kCopyProperties
VectorFst & operator=(const Fst< Arc > &fst) override
void AddArc(const Arc &arc)
constexpr uint64_t kNoEpsilons
ArcIterator(const VectorFst< Arc, State > &fst, StateId s)
static bool UpdateFstHeader(const Fst< Arc > &fst, std::ostream &strm, const FstWriteOptions &opts, int version, std::string_view type, uint64_t properties, FstHeader *hdr, size_t header_offset)
void SetStart(StateId state)
typename Arc::Weight Weight
VectorFst * Copy(bool safe=false) const override
virtual StateId Start() const =0
uint8_t Flags() const final
std::unique_ptr< StateIteratorBase< Arc > > base
void DeleteArcs(size_t n)
void ReserveArcs(StateId state, size_t n)
void AddArc(StateId state, const Arc &arc)
size_t NumOutputEpsilons() const
VectorFst & operator=(const VectorFst &)=default
void EmplaceArc(StateId state, T &&...ctor_args)
typename Arc::StateId StateId
constexpr uint64_t kNullProperties
std::unique_ptr< ArcIteratorBase< Arc > > base
void ReserveStates(size_t n)
~VectorFstBaseImpl() override
constexpr uint64_t kIEpsilons
size_t NumInputEpsilons(StateId state) const
VectorState(const VectorState< A, M > &state, const ArcAllocator &alloc)
size_t NumOutputEpsilons(StateId state) const
bool WriteFile(const std::string &source) const
constexpr uint8_t Flags() const
size_t Position() const final
void SetNumInputEpsilons(size_t n)
bool Write(const std::string &source) const override
constexpr uint64_t kUnweighted
void EmplaceArc(StateId state, T &&...ctor_args)
const Arc & GetArc(size_t n) const
void Seek(size_t a) final
virtual const SymbolTable * InputSymbols() const =0
Arc::StateId CountStates(const Fst< Arc > &fst)
typename Arc::Weight Weight
VectorFst(const VectorFst &fst, bool unused_safe=false)
typename Arc::StateId StateId
void SetFlags(uint8_t, uint8_t)
void DeleteArcs(StateId state, size_t n)
void SetStart(StateId state)
void SetValue(const Arc &arc) final
VectorFstBaseImpl(VectorFstBaseImpl &&impl) noexcept
std::istream & ReadType(std::istream &strm, T *t)
uint64_t DeleteArcsProperties(uint64_t inprops)
typename Arc::StateId StateId
void DeleteArcs(StateId state)
size_t NumInputEpsilons() const
void InitArcIterator(StateId state, ArcIteratorData< Arc > *data) const
internal::VectorFstImpl< S > * GetMutableImpl() const
constexpr uint64_t kWeighted
void EmplaceArc(T &&...ctor_args)
static bool WriteFst(const FST &fst, std::ostream &strm, const FstWriteOptions &opts)
constexpr uint64_t kExpanded
void SetFinal(StateId state, Weight weight)
void SetArc(const Arc &arc, size_t n)
constexpr uint64_t kNoIEpsilons
constexpr uint64_t kSetArcProperties
void Destroy(ArcIterator< FST > *aiter, MemoryPool< ArcIterator< FST >> *pool)
StateId AddState(State *state)
const internal::VectorFstImpl< S > * GetImpl() const
constexpr uint64_t kAcceptor
virtual const SymbolTable * OutputSymbols() const =0