21 #ifndef FST_CONST_FST_H_ 22 #define FST_CONST_FST_H_ 39 template <
class A,
class Un
signed>
42 template <
class F,
class G>
43 void Cast(
const F &, G *);
49 template <
class A,
class Un
signed>
63 std::string type =
"const";
64 if (
sizeof(Unsigned) !=
sizeof(uint32_t)) {
65 type += std::to_string(CHAR_BIT *
sizeof(Unsigned));
98 data->
arcs = arcs_ + states_[s].pos;
99 data->
narcs = states_[s].narcs;
115 ConstState() : final_weight(
Weight::Zero()) {}
119 static constexpr uint64_t kStaticProperties =
kExpanded;
122 static constexpr
int kFileVersion = 2;
124 static constexpr
int kAlignedFileVersion = 1;
126 static constexpr
int kMinFileVersion = 1;
128 std::unique_ptr<MappedFile> states_region_;
129 std::unique_ptr<MappedFile> arcs_region_;
130 ConstState *states_ =
nullptr;
131 Arc *arcs_ =
nullptr;
140 template <
class Arc,
class Un
signed>
142 std::string type =
"const";
143 if (
sizeof(Unsigned) !=
sizeof(uint32_t)) {
144 type += std::to_string(CHAR_BIT *
sizeof(Unsigned));
149 start_ = fst.
Start();
153 narcs_ += fst.
NumArcs(siter.Value());
155 states_region_.reset(MappedFile::AllocateType<ConstState>(nstates_));
156 arcs_region_.reset(MappedFile::AllocateType<Arc>(narcs_));
157 states_ =
static_cast<ConstState *
>(states_region_->mutable_data());
158 arcs_ =
static_cast<Arc *
>(arcs_region_->mutable_data());
160 for (
StateId s = 0; s < nstates_; ++s) {
161 states_[s].final_weight = fst.
Final(s);
162 states_[s].pos = pos;
163 states_[s].narcs = 0;
164 states_[s].niepsilons = 0;
165 states_[s].noepsilons = 0;
167 const auto &arc = aiter.Value();
169 if (arc.ilabel == 0) ++states_[s].niepsilons;
170 if (arc.olabel == 0) ++states_[s].noepsilons;
184 template <
class Arc,
class Un
signed>
187 auto impl = std::make_unique<ConstFstImpl>();
189 if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr))
return nullptr;
190 impl->start_ = hdr.Start();
191 impl->nstates_ = hdr.NumStates();
192 impl->narcs_ = hdr.NumArcs();
194 if (hdr.Version() == kAlignedFileVersion) {
198 LOG(ERROR) <<
"ConstFst::Read: Alignment failed: " << opts.
source;
201 size_t b = impl->nstates_ *
sizeof(ConstState);
202 impl->states_region_.reset(
204 if (!strm || !impl->states_region_) {
205 LOG(ERROR) <<
"ConstFst::Read: Read failed: " << opts.
source;
209 static_cast<ConstState *
>(impl->states_region_->mutable_data());
211 LOG(ERROR) <<
"ConstFst::Read: Alignment failed: " << opts.
source;
214 b = impl->narcs_ *
sizeof(
Arc);
215 impl->arcs_region_.reset(
217 if (!strm || !impl->arcs_region_) {
218 LOG(ERROR) <<
"ConstFst::Read: Read failed: " << opts.
source;
221 impl->arcs_ =
static_cast<Arc *
>(impl->arcs_region_->mutable_data());
222 return impl.release();
233 template <
class A,
class Un
signed>
245 template <
class F,
class G>
246 void friend Cast(
const F &, G *);
263 auto *impl = Impl::Read(strm, opts);
264 return impl ?
new ConstFst(std::shared_ptr<Impl>(impl)) :
nullptr;
271 return impl ?
new ConstFst(std::shared_ptr<Impl>(impl)) :
nullptr;
275 return WriteFst(*
this, strm, opts);
278 bool Write(
const std::string &source)
const override {
283 static bool WriteFst(
const FST &
fst, std::ostream &strm,
287 GetImpl()->InitStateIterator(data);
291 GetImpl()->InitArcIterator(s, data);
295 explicit ConstFst(std::shared_ptr<Impl> impl)
301 static const Impl *GetImplIfConstFst(
const ConstFst &const_fst) {
306 template <
typename FST>
307 static Impl *GetImplIfConstFst(
const FST &fst) {
316 template <
class Arc,
class Un
signed>
320 const auto file_version =
324 size_t num_states = 0;
325 std::streamoff start_offset = 0;
326 bool update_header =
true;
327 if (
const auto *impl = GetImplIfConstFst(fst)) {
328 num_arcs = impl->narcs_;
329 num_states = impl->nstates_;
330 update_header =
false;
331 }
else if (opts.
stream_write || (start_offset = strm.tellp()) == -1) {
336 num_arcs += fst.NumArcs(siter.Value());
339 update_header =
false;
345 std::string type =
"const";
346 if (
sizeof(Unsigned) !=
sizeof(uint32_t)) {
347 type += std::to_string(CHAR_BIT *
sizeof(Unsigned));
349 const auto properties =
355 LOG(ERROR) <<
"Could not align file during write after header";
362 const auto s = siter.Value();
363 state.final_weight = fst.Final(s);
365 state.narcs = fst.NumArcs(s);
366 state.niepsilons = fst.NumInputEpsilons(s);
367 state.noepsilons = fst.NumOutputEpsilons(s);
368 strm.write(reinterpret_cast<const char *>(&state),
sizeof(state));
375 LOG(ERROR) <<
"Could not align file during write after writing states";
380 const auto &arc = aiter.Value();
381 strm.write(reinterpret_cast<const char *>(&arc),
sizeof(arc));
386 LOG(ERROR) <<
"ConstFst::WriteFst: Write failed: " << opts.
source;
391 fst, strm, opts, file_version, type, properties, &hdr, start_offset);
394 LOG(ERROR) <<
"Inconsistent number of states observed during write";
397 if (hdr.
NumArcs() != num_arcs) {
398 LOG(ERROR) <<
"Inconsistent number of arcs observed during write";
407 template <
class Arc,
class Un
signed>
413 : nstates_(fst.GetImpl()->
NumStates()), s_(0) {}
415 bool Done()
const {
return s_ >= nstates_; }
430 template <
class Arc,
class Un
signed>
436 : arcs_(fst.GetImpl()->
Arcs(s)),
437 narcs_(fst.GetImpl()->
NumArcs(s)),
440 bool Done()
const {
return i_ >= narcs_; }
450 void Seek(
size_t a) { i_ = a; }
467 #endif // FST_CONST_FST_H_
static Impl * Read(std::istream &strm, const FstReadOptions &opts)
typename Arc::Weight Weight
void SetProperties(uint64_t props)
bool AlignInput(std::istream &strm, size_t align=MappedFile::kArchAlignment)
constexpr uint64_t kMutable
constexpr uint64_t kWeightedCycles
constexpr uint8_t kArcValueFlags
void Cast(const F &, G *)
virtual uint64_t Properties(uint64_t mask, bool test) const =0
virtual size_t NumArcs(StateId) const =0
ConstFst * Copy(bool safe=false) const override
static MappedFile * Map(std::istream &istrm, bool memorymap, const std::string &source, size_t size)
virtual Weight Final(StateId) const =0
void SetOutputSymbols(const SymbolTable *osyms)
constexpr uint64_t kUnweightedCycles
bool AlignOutput(std::ostream &strm, size_t align=MappedFile::kArchAlignment)
static ConstFst * Read(std::istream &strm, const FstReadOptions &opts)
typename Arc::StateId StateId
static bool WriteFst(const FST &fst, std::ostream &strm, const FstWriteOptions &opts)
ConstFst(const Fst< Arc > &fst)
ArcIterator(const ConstFst< Arc, Unsigned > &fst, StateId s)
void SetFlags(uint8_t, uint8_t)
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
static void WriteFstHeader(const Fst< Arc > &fst, std::ostream &strm, const FstWriteOptions &opts, int version, std::string_view type, uint64_t properties, FstHeader *hdr)
const Arc * Arcs(StateId s) const
constexpr uint8_t Flags() const
constexpr uint64_t kCopyProperties
typename Arc::StateId StateId
typename Impl::ConstState ConstState
bool Write(const std::string &source) const override
uint64_t CheckProperties(const Fst< Arc > &fst, uint64_t check_mask, uint64_t test_mask)
static bool UpdateFstHeader(const Fst< Arc > &fst, std::ostream &strm, const FstWriteOptions &opts, int version, std::string_view type, uint64_t properties, FstHeader *hdr, size_t header_offset)
size_t NumArcs(StateId s) const
size_t NumInputEpsilons(StateId s) const
virtual StateId Start() const =0
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const
std::unique_ptr< StateIteratorBase< Arc > > base
StateIterator(const ConstFst< Arc, Unsigned > &fst)
constexpr uint64_t kNullProperties
std::unique_ptr< ArcIteratorBase< Arc > > base
void SetInputSymbols(const SymbolTable *isyms)
bool WriteFile(const std::string &source) const
ConstFst(const ConstFst &fst, bool unused_safe=false)
typename Arc::StateId StateId
bool Write(std::ostream &strm, const FstWriteOptions &opts) const override
virtual const SymbolTable * InputSymbols() const =0
typename Arc::StateId StateId
void SetType(std::string_view type)
size_t NumOutputEpsilons(StateId s) const
void InitStateIterator(StateIteratorData< Arc > *data) const
StateId NumStates() const
const Arc & Value() const
static ConstFst * Read(const std::string &source)
constexpr uint64_t kExpanded
static ConstFstImpl * Read(std::istream &strm, const FstReadOptions &opts)
Weight Final(StateId s) const
void InitStateIterator(StateIteratorData< Arc > *data) const override
const Impl * GetImpl() const
virtual const SymbolTable * OutputSymbols() const =0