48 #include <unordered_map> 49 #include <string_view> 80 const std::string &
ArcType()
const {
return arctype_; }  // Accessor: returns the stored arctype_ string by const reference.
// Accessor: returns the stored flags_ byte by value.
82 uint8_t
Flags()
const {
return flags_; }
// Accessor: returns the stored size_ value.
84 size_t Size()
const {
return size_; }
89 arctype_ = std::string(arctype);
// Mutator: overwrites flags_ with the given value; no validation is performed.
92 void SetFlags(uint8_t flags) { flags_ = flags; }
// Mutator: overwrites size_ with the given value; no validation is performed.
94 void SetSize(
size_t size) { size_ = size; }
98 bool Read(std::istream &strm, std::string_view source);
100 bool Write(std::ostream &strm, std::string_view source)
const;
103 std::string arctype_;
125 : ilabel(ilabel), olabel(olabel), weight(std::move(weight)) {}
129 : ilabel(arc.ilabel),
130 olabel(flags & kEncodeLabels ? arc.olabel : 0),
131 weight(flags & kEncodeWeights ? arc.weight :
Weight::One()) {}
133 static std::unique_ptr<Triple>
Read(std::istream &strm) {
134 auto triple = std::make_unique<Triple>();
141 void Write(std::ostream &strm)
const {
149 return (ilabel == other.
ilabel && olabel == other.
olabel &&
169 size_t hash = triple->
ilabel;
170 static constexpr
int lshift = 5;
171 static constexpr
int rshift = CHAR_BIT *
sizeof(size_t) - 5;
172 if (flags_ & kEncodeLabels) {
173 hash = hash << lshift ^ hash >> rshift ^ triple->
olabel;
175 if (flags_ & kEncodeWeights) {
176 hash = hash << lshift ^ hash >> rshift ^ triple->
weight.Hash();
186 : flags_(flags), triple2label_(1024,
TripleHash(flags)) {}
193 if (arc.nextstate ==
kNoStateId && (flags_ & kEncodeWeights)) {
196 return Encode(std::make_unique<Triple>(arc, flags_));
202 if (label < 1 || label > triples_.size()) {
203 LOG(ERROR) <<
"EncodeTable::Decode: Unknown decode label: " << label;
206 return triples_[label - 1].get();
// Returns the number of entries currently held in triples_.
209 size_t Size()
const {
return triples_.size(); }
211 static EncodeTable *Read(std::istream &strm, std::string_view source);
213 bool Write(std::ostream &strm, std::string_view source)
const;
225 isymbols_.reset(syms->
Copy());
235 osymbols_.reset(syms->
Copy());
246 triple2label_.emplace(triple.get(), triples_.size() + 1);
247 if (insert_result.second) triples_.push_back(std::move(triple));
248 return insert_result.first->second;
252 std::vector<std::unique_ptr<Triple>> triples_;
253 std::unordered_map<const Triple *, Label, TripleHash, TripleEqual>
255 std::unique_ptr<SymbolTable> isymbols_;
256 std::unique_ptr<SymbolTable> osymbols_;
264 std::string_view source) {
266 if (!hdr.
Read(strm, source))
return nullptr;
267 const auto flags = hdr.
Flags();
268 const auto size = hdr.
Size();
269 auto table = std::make_unique<EncodeTable>(flags);
270 for (int64_t i = 0; i < size; ++i) {
271 table->triples_.emplace_back(std::move(Triple::Read(strm)));
272 table->triple2label_[table->triples_.back().get()] = table->triples_.size();
281 LOG(ERROR) <<
"EncodeTable::Read: Read failed: " << source;
284 return table.release();
289 std::string_view source)
const {
294 if (!hdr.
Write(strm, source))
return false;
295 for (
const auto &triple : triples_) triple->Write(strm);
300 LOG(ERROR) <<
"EncodeTable::Write: Write failed: " << source;
326 using Label =
typename Arc::Label;
327 using Weight =
typename Arc::Weight;
333 table_(std::make_shared<internal::EncodeTable<Arc>>(flags)),
337 : flags_(mapper.flags_),
339 table_(mapper.table_),
344 : flags_(mapper.flags_),
346 table_(mapper.table_),
347 error_(mapper.error_) {}
349 Arc operator()(
const Arc &arc);
352 return (type_ ==
ENCODE && (flags_ & kEncodeWeights))
// Accessor: returns this object's flags_ byte by value.
365 uint8_t
Flags()
const {
return flags_; }
368 uint64_t outprops = inprops;
369 if (error_) outprops |=
kError;
371 if (flags_ & kEncodeLabels) {
374 if (flags_ & kEncodeWeights) {
382 if (flags_ & kEncodeLabels) {
385 if (flags_ & kEncodeWeights) {
397 return table ?
new EncodeMapper(table->Flags(), type, table) :
nullptr;
402 std::ifstream strm(std::string(source),
403 std::ios_base::in | std::ios_base::binary);
405 LOG(ERROR) <<
"EncodeMapper: Can't open file: " << source;
408 return Read(strm, source, type);
411 bool Write(std::ostream &strm, std::string_view source)
const {
412 return table_->Write(strm, source);
415 bool Write(std::string_view source)
const {
416 std::ofstream strm(std::string(source),
417 std::ios_base::out | std::ios_base::binary);
419 LOG(ERROR) <<
"EncodeMapper: Can't open file: " << source;
422 return Write(strm, source);
430 table_->SetInputSymbols(syms);
434 table_->SetOutputSymbols(syms);
440 std::shared_ptr<internal::EncodeTable<Arc>> table_;
445 : flags_(flags), type_(type), table_(table), error_(
false) {}
457 ((!(flags_ & kEncodeWeights) ||
458 ((flags_ & kEncodeWeights) && arc.weight == Weight::Zero())))) {
461 const auto label = table_->Encode(arc);
462 return Arc(label, flags_ & kEncodeLabels ? label : arc.olabel,
463 flags_ & kEncodeWeights ? Weight::One() : arc.weight,
470 if (arc.ilabel == 0)
return arc;
471 if (flags_ & kEncodeLabels && arc.ilabel != arc.olabel) {
472 FSTERROR() <<
"EncodeMapper: Label-encoded arc has different " 473 "input and output labels";
476 if (flags_ & kEncodeWeights && arc.weight != Weight::One()) {
477 FSTERROR() <<
"EncodeMapper: Weight-encoded arc has non-trivial weight";
480 const auto triple = table_->Decode(arc.ilabel);
482 FSTERROR() <<
"EncodeMapper: Decode failed";
485 }
else if (triple->ilabel ==
kNoLabel) {
487 return Arc(0, 0, triple->weight, arc.nextstate);
489 return Arc(triple->ilabel,
490 flags_ & kEncodeLabels ? triple->olabel : arc.olabel,
491 flags_ & kEncodeWeights ? triple->weight : arc.weight,
546 FSTERROR() <<
"EncodeFst::Copy(true): Not allowed";
596 :
public StateIterator<ArcMapFst<Arc, Arc, EncodeMapper<Arc>>> {
605 :
public ArcIterator<ArcMapFst<Arc, Arc, EncodeMapper<Arc>>> {
614 :
public StateIterator<ArcMapFst<Arc, Arc, EncodeMapper<Arc>>> {
623 :
public ArcIterator<ArcMapFst<Arc, Arc, EncodeMapper<Arc>>> {
637 #endif // FST_ENCODE_H_
void ArcMap(MutableFst< A > *fst, C *mapper)
bool Write(std::ostream &strm, std::string_view source) const
typename ArcMapFst< Arc, Arc, EncodeMapper< Arc > >::Arc Arc
constexpr uint64_t kWeightInvariantProperties
const SymbolTable * OutputSymbols() const
constexpr uint8_t kEncodeFlags
const SymbolTable * OutputSymbols() const
bool Write(std::string_view source) const
void Encode(MutableFst< Arc > *fst, EncodeMapper< Arc > *mapper)
TripleHash(uint8_t flags)
constexpr int32_t kEncodeDeprecatedMagicNumber
ArcIterator(const DecodeFst< Arc > &fst, typename Arc::StateId s)
constexpr uint8_t kEncodeHasISymbols
EncodeFst(const Fst< Arc > &fst, Mapper *encoder)
virtual SymbolTable * Copy() const
DecodeFst(const Fst< Arc > &fst, const Mapper &encoder)
const SymbolTable * InputSymbols() const override=0
void SetInputSymbols(const SymbolTable *syms)
constexpr uint64_t kError
constexpr uint64_t kRmSuperFinalProperties
virtual void SetInputSymbols(const SymbolTable *isyms)=0
Triple(Label ilabel, Label olabel, Weight weight)
void RmFinalEpsilon(MutableFst< Arc > *fst)
MapFinalAction FinalAction() const
static EncodeTable * Read(std::istream &strm, std::string_view source)
DecodeFst(const DecodeFst &fst, bool safe=false)
constexpr uint64_t kUnweightedCycles
constexpr MapSymbolsAction InputSymbolsAction() const
typename ArcMapFst< Arc, Arc, EncodeMapper< Arc > >::Arc Arc
void SetOutputSymbols(const SymbolTable *syms)
static EncodeMapper * Read(std::istream &strm, std::string_view source, EncodeType type=ENCODE)
DecodeFst * Copy(bool safe=false) const override
constexpr uint8_t kEncodeHasOSymbols
std::ostream & WriteType(std::ostream &strm, const T t)
EncodeFst(const Fst< Arc > &fst, const Mapper &encoder)
EncodeMapper(const EncodeMapper &mapper)
const SymbolTable * OutputSymbols() const override=0
EncodeMapper(const EncodeMapper &mapper, EncodeType type)
EncodeTable(uint8_t flags)
EncodeMapper(uint8_t flags, EncodeType type=ENCODE)
ArcIterator(const EncodeFst< Arc > &fst, typename Arc::StateId s)
void Write(std::ostream &strm) const
constexpr int32_t kEncodeMagicNumber
constexpr MapSymbolsAction OutputSymbolsAction() const
Arc operator()(const Arc &arc)
StateIterator(const DecodeFst< Arc > &fst)
constexpr uint8_t kEncodeLabels
void SetInputSymbols(const SymbolTable *syms)
const Triple * Decode(Label label) const
constexpr uint64_t kIDeterministic
constexpr uint64_t kOLabelInvariantProperties
StateIterator(const EncodeFst< Arc > &fst)
const SymbolTable * InputSymbols() const
Label Encode(const Arc &arc)
constexpr uint64_t kILabelInvariantProperties
constexpr uint8_t kEncodeWeights
constexpr uint64_t kFstProperties
constexpr uint64_t kUnweighted
virtual const SymbolTable * InputSymbols() const =0
Triple(const Arc &arc, uint8_t flags)
static std::unique_ptr< Triple > Read(std::istream &strm)
bool operator()(const Triple *x, const Triple *y) const
static EncodeMapper * Read(std::string_view source, EncodeType type=ENCODE)
const SymbolTable * InputSymbols() const
bool Write(std::ostream &strm, std::string_view source) const
std::istream & ReadType(std::istream &strm, T *t)
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
size_t operator()(const Triple *triple) const
uint64_t Properties(uint64_t inprops)
typename Arc::Label Label
EncodeFst(const EncodeFst &fst, bool copy=false)
typename Arc::Weight Weight
static SymbolTable * Read(std::istream &strm, std::string_view source)
void Decode(MutableFst< Arc > *fst, const EncodeMapper< Arc > &mapper)
void SetOutputSymbols(const SymbolTable *syms)
constexpr uint64_t kAddSuperFinalProperties
bool operator==(const Triple &other) const
constexpr uint64_t kAcceptor
EncodeFst * Copy(bool safe=false) const override
virtual const SymbolTable * OutputSymbols() const =0