21 #ifndef FST_CONST_FST_H_ 22 #define FST_CONST_FST_H_ 45 #include <string_view> 49 template <
class A,
class Un
signed>
52 template <
class F,
class G>
53 void Cast(
const F &, G *);
59 template <
class A,
class Un
signed>
73 std::string type =
"const";
74 if (
sizeof(Unsigned) !=
sizeof(uint32_t)) {
75 type += std::to_string(CHAR_BIT *
sizeof(Unsigned));
101 data->
base =
nullptr;
107 data->
base =
nullptr;
108 data->
arcs = arcs_ + states_[s].pos;
109 data->
narcs = states_[s].narcs;
125 ConstState() : final_weight(
Weight::Zero()) {}
129 static constexpr uint64_t kStaticProperties =
kExpanded;
132 static constexpr
int kFileVersion = 2;
134 static constexpr
int kAlignedFileVersion = 1;
136 static constexpr
int kMinFileVersion = 1;
138 std::unique_ptr<MappedFile> states_region_;
139 std::unique_ptr<MappedFile> arcs_region_;
140 ConstState *states_ =
nullptr;
141 Arc *arcs_ =
nullptr;
150 template <
class Arc,
class Un
signed>
152 std::string type =
"const";
153 if (
sizeof(Unsigned) !=
sizeof(uint32_t)) {
154 type += std::to_string(CHAR_BIT *
sizeof(Unsigned));
159 start_ = fst.
Start();
163 narcs_ += fst.
NumArcs(siter.Value());
165 states_region_.reset(MappedFile::AllocateType<ConstState>(nstates_));
166 arcs_region_.reset(MappedFile::AllocateType<Arc>(narcs_));
167 states_ =
static_cast<ConstState *
>(states_region_->mutable_data());
168 arcs_ =
static_cast<Arc *
>(arcs_region_->mutable_data());
170 for (
StateId s = 0; s < nstates_; ++s) {
171 states_[s].final_weight = fst.
Final(s);
172 states_[s].pos = pos;
173 states_[s].narcs = 0;
174 states_[s].niepsilons = 0;
175 states_[s].noepsilons = 0;
177 const auto &arc = aiter.Value();
179 if (arc.ilabel == 0) ++states_[s].niepsilons;
180 if (arc.olabel == 0) ++states_[s].noepsilons;
194 template <
class Arc,
class Un
signed>
197 auto impl = std::make_unique<ConstFstImpl>();
199 if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr))
return nullptr;
200 impl->start_ = hdr.Start();
201 impl->nstates_ = hdr.NumStates();
202 impl->narcs_ = hdr.NumArcs();
204 if (hdr.Version() == kAlignedFileVersion) {
208 LOG(ERROR) <<
"ConstFst::Read: Alignment failed: " << opts.
source;
211 size_t b = impl->nstates_ *
sizeof(ConstState);
212 impl->states_region_.reset(
214 if (!strm || !impl->states_region_) {
215 LOG(ERROR) <<
"ConstFst::Read: Read failed: " << opts.
source;
219 static_cast<ConstState *
>(impl->states_region_->mutable_data());
221 LOG(ERROR) <<
"ConstFst::Read: Alignment failed: " << opts.
source;
224 b = impl->narcs_ *
sizeof(
Arc);
225 impl->arcs_region_.reset(
227 if (!strm || !impl->arcs_region_) {
228 LOG(ERROR) <<
"ConstFst::Read: Read failed: " << opts.
source;
231 impl->arcs_ =
static_cast<Arc *
>(impl->arcs_region_->mutable_data());
232 return impl.release();
243 template <
class A,
class Un
signed>
257 template <
class F,
class G>
258 void friend Cast(
const F &, G *);
265 : Base(fst.GetSharedImpl()) {}
274 auto *impl = Impl::Read(strm, opts);
275 return impl ?
new ConstFst(std::shared_ptr<Impl>(impl)) :
nullptr;
281 auto *impl = Base::Read(source);
282 return impl ?
new ConstFst(std::shared_ptr<Impl>(impl)) :
nullptr;
286 return WriteFst(*
this, strm, opts);
289 bool Write(
const std::string &source)
const override {
294 static bool WriteFst(
const FST &
fst, std::ostream &strm,
298 GetImpl()->InitStateIterator(data);
302 GetImpl()->InitArcIterator(s, data);
306 explicit ConstFst(std::shared_ptr<Impl> impl) : Base(impl) {}
311 static const Impl *GetImplIfConstFst(
const ConstFst &const_fst) {
316 template <
typename FST>
317 static Impl *GetImplIfConstFst(
const FST &fst) {
326 template <
class Arc,
class Un
signed>
330 const auto file_version =
334 size_t num_states = 0;
335 std::streamoff start_offset = 0;
336 bool update_header =
true;
337 if (
const auto *impl = GetImplIfConstFst(fst)) {
338 num_arcs = impl->narcs_;
339 num_states = impl->nstates_;
340 update_header =
false;
341 }
else if (opts.
stream_write || (start_offset = strm.tellp()) == -1) {
346 num_arcs += fst.NumArcs(siter.Value());
349 update_header =
false;
355 std::string type =
"const";
356 if (
sizeof(Unsigned) !=
sizeof(uint32_t)) {
357 type += std::to_string(CHAR_BIT *
sizeof(Unsigned));
359 const auto properties =
365 LOG(ERROR) <<
"Could not align file during write after header";
372 const auto s = siter.Value();
373 state.final_weight = fst.Final(s);
375 state.narcs = fst.NumArcs(s);
376 state.niepsilons = fst.NumInputEpsilons(s);
377 state.noepsilons = fst.NumOutputEpsilons(s);
378 strm.write(reinterpret_cast<const char *>(&state),
sizeof(state));
385 LOG(ERROR) <<
"Could not align file during write after writing states";
390 const auto &arc = aiter.Value();
391 strm.write(reinterpret_cast<const char *>(&arc),
sizeof(arc));
396 LOG(ERROR) <<
"ConstFst::WriteFst: Write failed: " << opts.
source;
401 fst, strm, opts, file_version, type, properties, &hdr, start_offset);
404 LOG(ERROR) <<
"Inconsistent number of states observed during write";
407 if (hdr.
NumArcs() != num_arcs) {
408 LOG(ERROR) <<
"Inconsistent number of arcs observed during write";
417 template <
class Arc,
class Un
signed>
423 : nstates_(fst.GetImpl()->
NumStates()), s_(0) {}
425 bool Done()
const {
return s_ >= nstates_; }
440 template <
class Arc,
class Un
signed>
446 : arcs_(fst.GetImpl()->
Arcs(s)),
447 narcs_(fst.GetImpl()->
NumArcs(s)),
450 bool Done()
const {
return i_ >= narcs_; }
460 void Seek(
size_t a) { i_ = a; }
477 #endif // FST_CONST_FST_H_
typename Arc::Weight Weight
void SetProperties(uint64_t props)
bool AlignInput(std::istream &strm, size_t align=MappedFile::kArchAlignment)
constexpr uint64_t kMutable
constexpr uint64_t kWeightedCycles
constexpr uint8_t kArcValueFlags
void Cast(const F &, G *)
virtual uint64_t Properties(uint64_t mask, bool test) const =0
virtual size_t NumArcs(StateId) const =0
ConstFst * Copy(bool safe=false) const override
static MappedFile * Map(std::istream &istrm, bool memorymap, const std::string &source, size_t size)
virtual Weight Final(StateId) const =0
void SetOutputSymbols(const SymbolTable *osyms)
constexpr uint64_t kUnweightedCycles
bool AlignOutput(std::ostream &strm, size_t align=MappedFile::kArchAlignment)
static ConstFst * Read(std::istream &strm, const FstReadOptions &opts)
typename Arc::StateId StateId
static bool WriteFst(const FST &fst, std::ostream &strm, const FstWriteOptions &opts)
static ConstFst * Read(std::string_view source)
ConstFst(const Fst< Arc > &fst)
ArcIterator(const ConstFst< Arc, Unsigned > &fst, StateId s)
void SetFlags(uint8_t, uint8_t)
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
static void WriteFstHeader(const Fst< Arc > &fst, std::ostream &strm, const FstWriteOptions &opts, int version, std::string_view type, uint64_t properties, FstHeader *hdr)
const Impl * GetImpl() const
const Arc * Arcs(StateId s) const
constexpr uint8_t Flags() const
constexpr uint64_t kCopyProperties
typename Arc::StateId StateId
typename Impl::ConstState ConstState
bool Write(const std::string &source) const override
uint64_t CheckProperties(const Fst< Arc > &fst, uint64_t check_mask, uint64_t test_mask)
static bool UpdateFstHeader(const Fst< Arc > &fst, std::ostream &strm, const FstWriteOptions &opts, int version, std::string_view type, uint64_t properties, FstHeader *hdr, size_t header_offset)
size_t NumArcs(StateId s) const
size_t NumInputEpsilons(StateId s) const
virtual StateId Start() const =0
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const
std::unique_ptr< StateIteratorBase< Arc > > base
StateIterator(const ConstFst< Arc, Unsigned > &fst)
constexpr uint64_t kNullProperties
std::unique_ptr< ArcIteratorBase< Arc > > base
void SetInputSymbols(const SymbolTable *isyms)
bool WriteFile(const std::string &source) const
ConstFst(const ConstFst &fst, bool unused_safe=false)
typename Arc::StateId StateId
bool Write(std::ostream &strm, const FstWriteOptions &opts) const override
virtual const SymbolTable * InputSymbols() const =0
typename Arc::StateId StateId
void SetType(std::string_view type)
size_t NumOutputEpsilons(StateId s) const
void InitStateIterator(StateIteratorData< Arc > *data) const
StateId NumStates() const
const Arc & Value() const
constexpr uint64_t kExpanded
static ConstFstImpl * Read(std::istream &strm, const FstReadOptions &opts)
Weight Final(StateId s) const
void InitStateIterator(StateIteratorData< Arc > *data) const override
virtual const SymbolTable * OutputSymbols() const =0