20 #ifndef FST_VECTOR_FST_H_ 21 #define FST_VECTOR_FST_H_ 46 #include <string_view> 50 template <
class A,
class S>
53 template <
class F,
class G>
54 void Cast(
const F &, G *);
58 template <
class A,
class M >
72 : final_weight_(state.
Final()),
75 arcs_(state.arcs_.begin(), state.arcs_.end(), alloc) {}
78 final_weight_ = Weight::Zero();
// Returns the number of arcs currently held in arcs_.
90 size_t NumArcs()
const {
return arcs_.size(); }
// Returns a const reference to the n-th arc in arcs_. The lookup uses
// operator[], so n is not range-checked here.
92 const Arc &
GetArc(
size_t n)
const {
return arcs_[n]; }
// Returns a pointer to the first stored arc, or nullptr when arcs_ is empty.
94 const Arc *
Arcs()
const {
return !arcs_.empty() ? &arcs_[0] :
nullptr; }
107 IncrementNumEpsilons(arc);
108 arcs_.push_back(arc);
112 IncrementNumEpsilons(arc);
113 arcs_.push_back(std::move(arc));
116 template <
class... T>
118 arcs_.emplace_back(std::forward<T>(ctor_args)...);
119 IncrementNumEpsilons(arcs_.back());
123 if (arcs_[n].ilabel == 0) --niepsilons_;
124 if (arcs_[n].olabel == 0) --noepsilons_;
125 IncrementNumEpsilons(arc);
136 for (
size_t i = 0; i < n; ++i) {
137 if (arcs_.back().ilabel == 0) --niepsilons_;
138 if (arcs_.back().olabel == 0) --noepsilons_;
145 return alloc->allocate(1);
151 state->~VectorState<A, M>();
152 alloc->deallocate(state, 1);
158 void IncrementNumEpsilons(
const Arc &arc) {
159 if (arc.ilabel == 0) ++niepsilons_;
160 if (arc.olabel == 0) ++noepsilons_;
163 Weight final_weight_ = Weight::Zero();
164 size_t niepsilons_ = 0;
165 size_t noepsilons_ = 0;
166 std::vector<A, ArcAllocator> arcs_;
177 using Arc =
typename State::Arc;
194 states_(std::move(impl.states_)),
195 start_(impl.start_) {
196 impl.states_.clear();
201 for (
auto *state : states_) {
205 std::swap(states_, impl.states_);
206 start_ = impl.start_;
220 return GetState(state)->NumInputEpsilons();
224 return GetState(state)->NumOutputEpsilons();
230 states_[state]->SetFinal(std::move(weight));
234 states_.push_back(state);
235 return states_.size() - 1;
241 const auto curr_num_states = NumStates();
242 states_.resize(n + curr_num_states);
243 std::generate(states_.begin() + curr_num_states, states_.end(),
244 [
this] {
return CreateState(); });
250 states_[state]->AddArc(std::move(arc));
253 template <
class... T>
255 states_[state]->EmplaceArc(std::forward<T>(ctor_args)...);
259 std::vector<StateId> newid(states_.size(), 0);
260 for (
size_t i = 0; i < dstates.size(); ++i) newid[dstates[i]] =
kNoStateId;
262 for (
StateId state = 0; state < states_.size(); ++state) {
264 newid[state] = nstates;
265 if (state != nstates) states_[nstates] = states_[state];
271 states_.resize(nstates);
272 for (
StateId state = 0; state < states_.size(); ++state) {
273 auto *arcs = states_[state]->MutableArcs();
275 auto nieps = states_[state]->NumInputEpsilons();
276 auto noeps = states_[state]->NumOutputEpsilons();
277 for (
size_t i = 0; i < states_[state]->NumArcs(); ++i) {
278 const auto t = newid[arcs[i].nextstate];
280 arcs[i].nextstate = t;
281 if (i != narcs) arcs[narcs] = arcs[i];
284 if (arcs[i].ilabel == 0) --nieps;
285 if (arcs[i].olabel == 0) --noeps;
288 states_[state]->DeleteArcs(states_[state]->
NumArcs() - narcs);
289 states_[state]->SetNumInputEpsilons(nieps);
290 states_[state]->SetNumOutputEpsilons(noeps);
292 if (Start() !=
kNoStateId) SetStart(newid[Start()]);
296 for (
size_t state = 0; state < states_.size(); ++state) {
319 data->
base =
nullptr;
320 data->
nstates = states_.size();
325 data->
base =
nullptr;
326 data->
narcs = states_[state]->NumArcs();
327 data->
arcs = states_[state]->Arcs();
// Creates a new State, passing arc_alloc_ to its constructor, and returns the
// raw pointer to it. The new-expression is given &state_alloc_ as its
// placement argument.
// NOTE(review): presumably this selects a State-specific operator new that
// allocates through the state allocator — confirm against the State class.
332 State *CreateState() {
return new (&state_alloc_)
State(arc_alloc_); }
334 std::vector<State *> states_;
336 typename State::StateAllocator state_alloc_;
337 typename State::ArcAllocator arc_alloc_;
346 using Arc =
typename State::Arc;
376 BaseImpl::SetStart(state);
382 const auto properties =
384 BaseImpl::SetFinal(state, std::move(weight));
385 SetProperties(properties);
389 const auto state = BaseImpl::AddState();
395 BaseImpl::AddStates(n);
400 BaseImpl::AddArc(state, arc);
401 UpdatePropertiesAfterAddArc(state);
405 BaseImpl::AddArc(state, std::move(arc));
406 UpdatePropertiesAfterAddArc(state);
409 template <
class... T>
411 BaseImpl::EmplaceArc(state, std::forward<T>(ctor_args)...);
412 UpdatePropertiesAfterAddArc(state);
416 BaseImpl::DeleteStates(dstates);
421 BaseImpl::DeleteStates();
426 BaseImpl::DeleteArcs(state, n);
431 BaseImpl::DeleteArcs(state);
439 void UpdatePropertiesAfterAddArc(
StateId state) {
440 auto *vstate = GetState(state);
441 const size_t num_arcs{vstate->NumArcs()};
443 const auto &arc = vstate->GetArc(num_arcs - 1);
445 (num_arcs < 2) ?
nullptr : &(vstate->GetArc(num_arcs - 2));
451 static constexpr
int kMinFileVersion = 2;
459 BaseImpl::SetStart(fst.
Start());
461 BaseImpl::ReserveStates(*num_states);
464 const auto state = siter.Value();
465 BaseImpl::AddState();
466 BaseImpl::SetFinal(state, fst.
Final(state));
469 const auto &arc = aiter.Value();
470 BaseImpl::AddArc(state, arc);
479 auto impl = std::make_unique<VectorFstImpl>();
481 if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr))
return nullptr;
482 impl->BaseImpl::SetStart(hdr.Start());
483 if (hdr.NumStates() !=
kNoStateId) impl->ReserveStates(hdr.NumStates());
485 for (; hdr.NumStates() ==
kNoStateId || state < hdr.NumStates(); ++state) {
487 if (!weight.Read(strm))
break;
488 impl->BaseImpl::AddState();
489 auto *vstate = impl->GetState(state);
490 vstate->SetFinal(weight);
494 LOG(ERROR) <<
"VectorFst::Read: Read failed: " << opts.
source;
497 impl->ReserveArcs(state, narcs);
498 for (int64_t i = 0; i < narcs; ++i) {
502 arc.weight.Read(strm);
505 LOG(ERROR) <<
"VectorFst::Read: Read failed: " << opts.
source;
508 impl->BaseImpl::AddArc(state, std::move(arc));
511 if (hdr.NumStates() !=
kNoStateId && state != hdr.NumStates()) {
512 LOG(ERROR) <<
"VectorFst::Read: Unexpected end of file: " << opts.
source;
515 return impl.release();
526 template <
class A,
class S >
539 template <
class F,
class G>
540 friend void Cast(
const F &, G *);
562 if (
this != &fst) SetImpl(std::make_shared<Impl>(fst));
566 template <
class... T>
569 GetMutableImpl()->EmplaceArc(state, std::forward<T>(ctor_args)...);
574 auto *impl = Impl::Read(strm, opts);
575 return impl ?
new VectorFst(std::shared_ptr<Impl>(impl)) :
nullptr;
582 return impl ?
new VectorFst(std::shared_ptr<Impl>(impl)) :
nullptr;
586 return WriteFst(*
this, strm, opts);
589 bool Write(
const std::string &source)
const override {
594 static bool WriteFst(
const FST &
fst, std::ostream &strm,
598 GetImpl()->InitStateIterator(data);
602 GetImpl()->InitArcIterator(s, data);
605 inline void InitMutableArcIterator(
StateId s,
617 explicit VectorFst(std::shared_ptr<Impl> impl)
621 template <
class Arc,
class State>
624 template <
class Arc,
class State>
630 template <
class Arc,
class State>
634 static constexpr
int file_version = 2;
635 bool update_header =
true;
639 std::streampos start_offset = 0;
641 (start_offset = strm.tellp()) != -1) {
643 update_header =
false;
645 const auto properties =
648 "vector", properties, &hdr);
651 const auto s = siter.Value();
652 fst.Final(s).Write(strm);
653 const int64_t narcs = fst.NumArcs(s);
656 const auto &arc = aiter.Value();
659 arc.weight.Write(strm);
666 LOG(ERROR) <<
"VectorFst::Write: Write failed: " << opts.
source;
672 fst, strm, opts, file_version,
"vector", properties, &hdr,
676 LOG(ERROR) <<
"Inconsistent number of states observed during write";
685 template <
class Arc,
class State>
691 : nstates_(fst.GetImpl()->NumStates()) {}
// True once the current state index s_ has reached or passed nstates_.
693 bool Done()
const {
return s_ >= nstates_; }
708 template <
class Arc,
class State>
714 : arcs_(fst.GetImpl()->GetState(s)->
Arcs()),
715 narcs_(fst.GetImpl()->GetState(s)->
NumArcs()) {}
// True once the current arc position i_ has reached or passed narcs_.
717 bool Done()
const {
return i_ >= narcs_; }
// Sets the current arc position i_ to a. The value is stored as given; no
// range check is performed here.
725 void Seek(
size_t a) { i_ = a; }
741 template <
class Arc,
class State>
751 properties_ = &fst->
GetImpl()->properties_;
// True once the position i_ has reached or passed the arc count reported by
// state_->NumArcs(). The count is queried on every call rather than cached.
754 bool Done() const final {
return i_ >= state_->NumArcs(); }
// Returns a const reference to the arc at the current position i_, obtained
// from state_->GetArc(i_).
756 const Arc &
Value() const final {
return state_->GetArc(i_); }
// Sets the current position i_ to a. The value is stored as given; no range
// check is performed here.
764 void Seek(
size_t a)
final { i_ = a; }
767 const auto &oarc = state_->GetArc(i_);
768 uint64_t properties = properties_->load(std::memory_order_relaxed);
769 if (oarc.ilabel != oarc.olabel) properties &= ~
kNotAcceptor;
770 if (oarc.ilabel == 0) {
772 if (oarc.olabel == 0) properties &= ~
kEpsilons;
774 if (oarc.olabel == 0) properties &= ~
kOEpsilons;
775 if (oarc.weight != Weight::Zero() && oarc.weight != Weight::One()) {
778 state_->SetArc(arc, i_);
779 if (arc.ilabel != arc.olabel) {
783 if (arc.ilabel == 0) {
786 if (arc.olabel == 0) {
791 if (arc.olabel == 0) {
795 if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) {
802 properties_->store(properties, std::memory_order_relaxed);
811 std::atomic<uint64_t> *properties_;
816 template <
class Arc,
class State>
820 std::make_unique<MutableArcIterator<VectorFst<Arc, State>>>(
this, s);
828 #endif // FST_VECTOR_FST_H_
typename std::allocator_traits< ArcAllocator >::template rebind_alloc< VectorState< Arc, M >> StateAllocator
void AddArc(StateId state, const Arc &arc)
typename Arc::Weight Weight
void SetState(StateId state, State *vstate)
uint64_t AddArcProperties(uint64_t inprops, typename A::StateId s, const A &arc, const A *prev_arc)
void SetFlags(uint8_t, uint8_t) final
void DeleteArcs(StateId state)
constexpr uint64_t kMutable
constexpr uint8_t kArcValueFlags
void AddArc(StateId state, Arc &&arc)
uint64_t SetStartProperties(uint64_t inprops)
void Cast(const F &, G *)
virtual uint64_t Properties(uint64_t mask, bool test) const =0
constexpr uint64_t kOEpsilons
void DeleteStates(const std::vector< StateId > &dstates)
std::unique_ptr< MutableArcIteratorBase< Arc > > base
virtual size_t NumArcs(StateId) const =0
VectorFstBaseImpl & operator=(VectorFstBaseImpl &&impl) noexcept
VectorState(const ArcAllocator &alloc)
void DeleteArcs(StateId state, size_t n)
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
void ReserveArcs(size_t n)
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
uint64_t AddStateProperties(uint64_t inprops)
StateIterator(const VectorFst< Arc, State > &fst)
uint64_t DeleteAllStatesProperties(uint64_t inprops, uint64_t staticProps)
uint64_t DeleteStatesProperties(uint64_t inprops)
bool Write(std::ostream &strm, const FstWriteOptions &opts) const override
void SetNumOutputEpsilons(size_t n)
size_t NumArcs(StateId state) const
void SetFinal(StateId state, Weight weight)
virtual Weight Final(StateId) const =0
const Arc & Value() const final
static VectorFst * Read(std::istream &strm, const FstReadOptions &opts)
typename Arc::StateId StateId
constexpr uint64_t kEpsilons
uint64_t SetFinalProperties(uint64_t inprops, const Weight &old_weight, const Weight &new_weight)
const Arc & Value() const
void InitMutableArcIterator(StateId s, MutableArcIteratorData< Arc > *) override
typename Arc::Label Label
void DeleteStates(const std::vector< StateId > &dstates)
void AddArc(StateId state, Arc &&arc)
typename Arc::StateId StateId
VectorFst(const Fst< Arc > &fst)
std::ostream & WriteType(std::ostream &strm, const T t)
void InitStateIterator(StateIteratorData< Arc > *data) const
static void WriteFstHeader(const Fst< Arc > &fst, std::ostream &strm, const FstWriteOptions &opts, int version, std::string_view type, uint64_t properties, FstHeader *hdr)
Weight Final(StateId state) const
const State * GetState(StateId state) const
MutableArcIterator(VectorFst< Arc, State > *fst, StateId s)
State * GetState(StateId state)
typename Arc::StateId StateId
void SetFinal(Weight weight)
void EmplaceArc(StateId state, T &&...ctor_args)
static void Destroy(VectorState< A, M > *state, StateAllocator *alloc)
constexpr uint64_t kNoOEpsilons
StateId NumStates() const
void InitStateIterator(StateIteratorData< Arc > *data) const override
constexpr uint64_t kNotAcceptor
constexpr uint64_t kCopyProperties
VectorFst & operator=(const Fst< Arc > &fst) override
void AddArc(const Arc &arc)
constexpr uint64_t kNoEpsilons
ArcIterator(const VectorFst< Arc, State > &fst, StateId s)
static bool UpdateFstHeader(const Fst< Arc > &fst, std::ostream &strm, const FstWriteOptions &opts, int version, std::string_view type, uint64_t properties, FstHeader *hdr, size_t header_offset)
void SetStart(StateId state)
typename Arc::Weight Weight
VectorFst * Copy(bool safe=false) const override
virtual StateId Start() const =0
uint8_t Flags() const final
std::unique_ptr< StateIteratorBase< Arc > > base
void DeleteArcs(size_t n)
void ReserveArcs(StateId state, size_t n)
void AddArc(StateId state, const Arc &arc)
size_t NumOutputEpsilons() const
VectorFst & operator=(const VectorFst &)=default
void EmplaceArc(StateId state, T &&...ctor_args)
typename Arc::StateId StateId
constexpr uint64_t kNullProperties
std::unique_ptr< ArcIteratorBase< Arc > > base
void ReserveStates(size_t n)
~VectorFstBaseImpl() override
constexpr uint64_t kIEpsilons
size_t NumInputEpsilons(StateId state) const
virtual std::optional< StateId > NumStatesIfKnown() const
VectorState(const VectorState< A, M > &state, const ArcAllocator &alloc)
size_t NumOutputEpsilons(StateId state) const
bool WriteFile(const std::string &source) const
constexpr uint8_t Flags() const
size_t Position() const final
void SetNumInputEpsilons(size_t n)
bool Write(const std::string &source) const override
constexpr uint64_t kUnweighted
void EmplaceArc(StateId state, T &&...ctor_args)
const Arc & GetArc(size_t n) const
void Seek(size_t a) final
virtual const SymbolTable * InputSymbols() const =0
Arc::StateId CountStates(const Fst< Arc > &fst)
typename Arc::Weight Weight
VectorFst(const VectorFst &fst, bool unused_safe=false)
typename Arc::StateId StateId
void SetFlags(uint8_t, uint8_t)
void DeleteArcs(StateId state, size_t n)
void SetStart(StateId state)
void SetValue(const Arc &arc) final
VectorFstBaseImpl(VectorFstBaseImpl &&impl) noexcept
static VectorFst * Read(std::string_view source)
std::istream & ReadType(std::istream &strm, T *t)
uint64_t DeleteArcsProperties(uint64_t inprops)
typename Arc::StateId StateId
void DeleteArcs(StateId state)
size_t NumInputEpsilons() const
void InitArcIterator(StateId state, ArcIteratorData< Arc > *data) const
internal::VectorFstImpl< S > * GetMutableImpl() const
constexpr uint64_t kWeighted
void EmplaceArc(T &&...ctor_args)
static bool WriteFst(const FST &fst, std::ostream &strm, const FstWriteOptions &opts)
constexpr uint64_t kExpanded
void SetFinal(StateId state, Weight weight)
void SetArc(const Arc &arc, size_t n)
constexpr uint64_t kNoIEpsilons
constexpr uint64_t kSetArcProperties
void Destroy(ArcIterator< FST > *aiter, MemoryPool< ArcIterator< FST >> *pool)
StateId AddState(State *state)
const internal::VectorFstImpl< S > * GetImpl() const
constexpr uint64_t kAcceptor
virtual const SymbolTable * OutputSymbols() const =0