18 #ifndef FST_SCRIPT_FST_CLASS_H_ 19 #define FST_SCRIPT_FST_CLASS_H_ 29 #include <type_traits> 44 #include <string_view> 62 virtual const std::string &
ArcType()
const = 0;
64 virtual const std::string &
FstType()
const = 0;
66 virtual size_t NumArcs(int64_t)
const = 0;
70 virtual uint64_t
Properties(uint64_t,
bool)
const = 0;
71 virtual int64_t
Start()
const = 0;
72 virtual const std::string &
WeightType()
const = 0;
74 virtual bool Write(
const std::string &)
const = 0;
75 virtual bool Write(std::ostream &,
const std::string &)
const = 0;
82 virtual bool AddArc(int64_t,
const ArcClass &) = 0;
83 virtual int64_t AddState() = 0;
84 virtual void AddStates(
size_t) = 0;
86 virtual bool DeleteArcs(int64_t,
size_t) = 0;
87 virtual bool DeleteArcs(int64_t) = 0;
88 virtual bool DeleteStates(
const std::vector<int64_t> &) = 0;
89 virtual void DeleteStates() = 0;
92 virtual int64_t NumStates()
const = 0;
93 virtual bool ReserveArcs(int64_t,
size_t) = 0;
94 virtual void ReserveStates(int64_t) = 0;
95 virtual void SetInputSymbols(
const SymbolTable *) = 0;
96 virtual bool SetFinal(int64_t,
const WeightClass &) = 0;
97 virtual void SetOutputSymbols(
const SymbolTable *) = 0;
98 virtual void SetProperties(uint64_t, uint64_t) = 0;
99 virtual bool SetStart(int64_t) = 0;
112 : impl_(std::move(impl)) {}
122 Arc arc(ac.ilabel, ac.olabel, *ac.weight.GetWeight<
typename Arc::Weight>(),
138 const std::string &
ArcType() const final {
return Arc::Type(); }
158 for (
const auto &state : dstates)
162 std::vector<typename Arc::StateId> typed_dstates(dstates.size());
163 std::copy(dstates.begin(), dstates.end(), typed_dstates.begin());
179 const std::string &
FstType() const final {
return impl_->Type(); }
182 return impl_->InputSymbols();
198 : std::numeric_limits<size_t>::max();
204 : std::numeric_limits<size_t>::max();
210 : std::numeric_limits<size_t>::max();
219 return impl_->Properties(mask, test);
235 return impl_->OutputSymbols();
247 ->SetFinal(s, *weight.GetWeight<
typename Arc::Weight>());
268 int64_t
Start() const final {
return impl_->Start(); }
272 const auto num_states = impl_->NumStatesIfKnown();
273 if (!num_states.has_value()) {
274 FSTERROR() <<
"Cannot get number of states for unexpanded FST";
277 if (s < 0 || s >= *num_states) {
278 FSTERROR() <<
"State ID " << s <<
" not valid";
284 const std::string &
WeightType() const final {
return Arc::Weight::Type(); }
286 bool Write(
const std::string &source)
const final {
287 return impl_->Write(source);
290 bool Write(std::ostream &ostr,
const std::string &source)
const final {
292 return impl_->Write(ostr, opts);
300 std::unique_ptr<Fst<Arc>> impl_;
320 : impl_(other.impl_ == nullptr ? nullptr : other.impl_->Copy()) {}
323 impl_.reset(other.impl_ ==
nullptr ?
nullptr : other.impl_->Copy());
329 const std::string &
ArcType() const final {
return impl_->ArcType(); }
331 const std::string &
FstType() const final {
return impl_->FstType(); }
334 return impl_->InputSymbols();
337 size_t NumArcs(int64_t s)
const final {
return impl_->NumArcs(s); }
340 return impl_->NumInputEpsilons(s);
344 return impl_->NumOutputEpsilons(s);
348 return impl_->OutputSymbols();
353 if (!impl_)
return kError & mask;
354 return impl_->Properties(mask, test);
357 static std::unique_ptr<FstClass> Read(
358 const std::string &source);
360 static std::unique_ptr<FstClass> Read(
361 std::istream &istrm,
const std::string &source);
363 int64_t
Start() const final {
return impl_->Start(); }
365 bool ValidStateId(int64_t s)
const final {
return impl_->ValidStateId(s); }
367 const std::string &
WeightType() const final {
return impl_->WeightType(); }
373 std::string_view op_name)
const;
375 bool Write(
const std::string &source)
const final {
376 return impl_->Write(source);
379 bool Write(std::ostream &ostr,
const std::string &source)
const final {
380 return impl_->Write(ostr, source);
389 FSTERROR() <<
"Doesn't make sense to convert any class to type FstClass";
394 static std::unique_ptr<FstClassImplBase>
Create() {
395 FSTERROR() <<
"Doesn't make sense to create an FstClass with a " 396 <<
"particular arc type";
402 if (Arc::Type() !=
ArcType()) {
412 static std::unique_ptr<FstClass>
Read(std::istream &stream,
415 LOG(ERROR) <<
"FstClass::Read: Options header not specified";
420 return ReadTypedFst<MutableFstClass, MutableFst<Arc>>(stream, opts);
422 return ReadTypedFst<FstClass, Fst<Arc>>(stream, opts);
427 explicit FstClass(std::unique_ptr<FstClassImplBase> impl)
428 : impl_(std::move(impl)) {}
437 template <
class FstClassT,
class UnderlyingT>
440 std::unique_ptr<UnderlyingT> u(UnderlyingT::Read(stream, opts));
441 return u ? std::make_unique<FstClassT>(std::move(u)) :
nullptr;
445 std::unique_ptr<FstClassImplBase> impl_;
453 if (!WeightTypesMatch(ac.
weight,
"AddArc"))
return false;
454 return GetImpl()->AddArc(s, ac);
457 int64_t
AddState() {
return GetImpl()->AddState(); }
459 void AddStates(
size_t n) {
return GetImpl()->AddStates(n); }
461 bool DeleteArcs(int64_t s,
size_t n) {
return GetImpl()->DeleteArcs(s, n); }
463 bool DeleteArcs(int64_t s) {
return GetImpl()->DeleteArcs(s); }
466 return GetImpl()->DeleteStates(dstates);
472 return GetImpl()->MutableInputSymbols();
476 return GetImpl()->MutableOutputSymbols();
479 int64_t
NumStates()
const {
return GetImpl()->NumStates(); }
481 bool ReserveArcs(int64_t s,
size_t n) {
return GetImpl()->ReserveArcs(s, n); }
485 static std::unique_ptr<MutableFstClass> Read(
486 const std::string &source,
bool convert =
false);
489 GetImpl()->SetInputSymbols(isyms);
493 if (!WeightTypesMatch(weight,
"SetFinal"))
return false;
494 return GetImpl()->SetFinal(s, weight);
498 GetImpl()->SetOutputSymbols(osyms);
502 GetImpl()->SetProperties(props, mask);
505 bool SetStart(int64_t s) {
return GetImpl()->SetStart(s); }
521 FSTERROR() <<
"Doesn't make sense to convert any class to type " 522 <<
"MutableFstClass";
527 static std::unique_ptr<FstClassImplBase>
Create() {
528 FSTERROR() <<
"Doesn't make sense to create a MutableFstClass with a " 529 <<
"particular arc type";
541 static std::unique_ptr<MutableFstClass>
Read(std::istream &stream,
544 return mfst ? std::make_unique<MutableFstClass>(std::move(mfst)) :
nullptr;
561 static std::unique_ptr<VectorFstClass> Read(
562 const std::string &source);
565 static std::unique_ptr<VectorFstClass>
Read(std::istream &stream,
568 return vfst ? std::make_unique<VectorFstClass>(std::move(vfst)) :
nullptr;
583 return std::make_unique<FstClassImpl<Arc>>(
584 std::make_unique<VectorFst<Arc>>(*other.
GetFst<Arc>()));
588 static std::unique_ptr<FstClassImplBase>
Create() {
589 return std::make_unique<FstClassImpl<Arc>>(
590 std::make_unique<VectorFst<Arc>>());
598 template <
class Reader,
class Creator,
class Converter>
605 : reader(r), creator(cr), converter(co) {}
611 template <
class Reader,
class Creator,
class Converter>
614 FstClassRegEntry<Reader, Creator, Converter>,
615 FstClassIORegister<Reader, Creator, Converter>> {
618 return this->GetEntry(arc_type).reader;
622 return this->GetEntry(arc_type).creator;
626 return this->GetEntry(arc_type).converter;
631 std::string legal_type(key);
633 legal_type.append(
"-arc.so");
640 template <
class FstClassType>
642 using Reader = std::unique_ptr<FstClassType> (*)(std::istream &stream,
645 using Creator = std::unique_ptr<FstClassImplBase> (*)();
648 std::unique_ptr<FstClassImplBase> (*)(
const FstClass &other);
662 #define REGISTER_FST_CLASS(Class, Arc) \ 663 static FstClassIORegistration<Class>::Registerer Class##_##Arc##_registerer( \ 665 FstClassIORegistration<Class>::Entry( \ 666 Class::Read<Arc>, Class::Create<Arc>, Class::Convert<Arc>)) 668 #define REGISTER_FST_CLASSES(Arc) \ 669 REGISTER_FST_CLASS(FstClass, Arc); \ 670 REGISTER_FST_CLASS(MutableFstClass, Arc); \ 671 REGISTER_FST_CLASS(VectorFstClass, Arc); 676 #endif // FST_SCRIPT_FST_CLASS_H_ virtual const std::string & WeightType() const =0
static std::unique_ptr< FstClassImplBase > Convert(const FstClass &other)
int64_t NumStates() const
FstClass & operator=(const FstClass &other)
const Fst< Arc > * GetFst() const
constexpr uint64_t kMutable
void ConvertToLegalCSymbol(std::string *s)
FstClassImpl(std::unique_ptr< Fst< Arc >> impl)
Converter GetConverter(std::string_view arc_type) const
const std::string & ArcType() const final
VectorFstClass(const VectorFst< Arc > &fst)
std::unique_ptr< FstClassType >(*)(std::istream &stream, const FstReadOptions &opts) Reader
MutableFst< Arc > * GetMutableFst()
uint64_t Properties(uint64_t mask, bool test) const final
virtual size_t NumArcs(int64_t) const =0
size_t NumInputEpsilons(int64_t s) const final
FstClass(const FstClass &other)
virtual bool Write(const std::string &) const =0
FstClass(std::unique_ptr< FstClassImplBase > impl)
bool DeleteArcs(int64_t s)
const FstClassImplBase * GetImpl() const
bool ReserveArcs(int64_t s, size_t n)
int64_t Start() const final
constexpr uint64_t kError
void SetInputSymbols(const SymbolTable *isyms) final
const SymbolTable * InputSymbols() const final
bool Write(std::ostream &ostr, const std::string &source) const final
WeightClass Final(int64_t s) const final
virtual size_t NumInputEpsilons(int64_t) const =0
const std::string & ArcType() const final
const std::string & FstType() const final
bool SetFinal(int64_t s, const WeightClass &weight)
Reader GetReader(std::string_view arc_type) const
bool ValidStateId(int64_t s) const final
void ReserveStates(int64_t n) final
virtual const std::string & ArcType() const =0
FstClassImpl * Copy() final
bool DeleteArcs(int64_t s) final
virtual uint64_t Properties(uint64_t, bool) const =0
int64_t NumStates() const final
VectorFstClass(std::unique_ptr< FstClassImplBase > impl)
MutableFstClass(const MutableFst< Arc > &fst)
const std::string & WeightType() const final
void DeleteStates() final
bool DeleteArcs(int64_t s, size_t n)
WeightClass Final(int64_t s) const final
static std::unique_ptr< FstClassImplBase > Create()
bool SetStart(int64_t s) final
const std::string & FstType() const final
bool Write(const std::string &source) const final
virtual const SymbolTable * InputSymbols() const =0
MutableFstClass(std::unique_ptr< FstClassImplBase > impl)
bool DeleteArcs(int64_t s, size_t n) final
FstClassImpl(const Fst< Arc > &impl)
virtual int64_t Start() const =0
static std::unique_ptr< FstClass > Read(std::istream &stream, const FstReadOptions &opts)
FstClass(std::unique_ptr< Fst< Arc >> fst)
Fst< Arc > * GetImpl() const
void SetProperties(uint64_t props, uint64_t mask) final
size_t NumInputEpsilons(int64_t s) const final
std::unique_ptr< FstClassImplBase >(*)(const FstClass &other) Converter
virtual bool ValidStateId(int64_t) const =0
Creator GetCreator(std::string_view arc_type) const
bool SetFinal(int64_t s, const WeightClass &weight) final
bool Write(std::ostream &ostr, const std::string &source) const final
size_t NumOutputEpsilons(int64_t s) const final
constexpr To implicit_cast(typename internal::type_identity_t< To > to)
virtual ~FstClassBase()=default
FstClassRegEntry(Reader r, Creator cr, Converter co)
const std::string & WeightType() const final
static std::unique_ptr< FstClassImplBase > Create()
void AddStates(size_t n) final
MutableFstClass(std::unique_ptr< MutableFst< Arc >> fst)
size_t NumArcs(int64_t s) const final
bool Write(const std::string &source) const final
static std::unique_ptr< FstClassImplBase > Convert(const FstClass &other)
bool ReserveArcs(int64_t s, size_t n) final
static std::unique_ptr< MutableFstClass > Read(std::istream &stream, const FstReadOptions &opts)
void SetOutputSymbols(const SymbolTable *osyms)
static std::unique_ptr< VectorFstClass > Read(std::istream &stream, const FstReadOptions &opts)
void SetProperties(uint64_t props, uint64_t mask)
FstClassImplBase * GetImpl()
bool AddArc(int64_t s, const ArcClass &ac)
std::string ConvertKeyToSoFilename(std::string_view key) const final
void SetOutputSymbols(const SymbolTable *osyms) final
VectorFstClass(std::unique_ptr< VectorFst< Arc >> fst)
const SymbolTable * OutputSymbols() const final
static WeightClass NoWeight(std::string_view weight_type)
virtual const std::string & FstType() const =0
void ReserveStates(int64_t n)
void SetInputSymbols(const SymbolTable *isyms)
bool DeleteStates(const std::vector< int64_t > &dstates)
FstClass(const Fst< Arc > &fst)
SymbolTable * MutableInputSymbols() final
virtual const SymbolTable * OutputSymbols() const =0
int64_t Start() const final
bool ValidStateId(int64_t s) const final
SymbolTable * MutableOutputSymbols()
virtual WeightClass Final(int64_t) const =0
size_t NumArcs(int64_t s) const final
static std::unique_ptr< FstClassImplBase > Create()
static std::unique_ptr< FstClassImplBase > Convert(const FstClass &other)
const SymbolTable * InputSymbols() const final
bool DeleteStates(const std::vector< int64_t > &dstates) final
bool AddArc(int64_t s, const ArcClass &ac) final
virtual size_t NumOutputEpsilons(int64_t) const =0
uint64_t Properties(uint64_t mask, bool test) const final
size_t NumOutputEpsilons(int64_t s) const final
SymbolTable * MutableOutputSymbols() final
const SymbolTable * OutputSymbols() const final
static std::unique_ptr< FstClassT > ReadTypedFst(std::istream &stream, const FstReadOptions &opts)
SymbolTable * MutableInputSymbols()