37 #include <unordered_map> 66 const std::string &
ArcType()
const {
return arctype_; }  // Accessor: hands back the stored arctype_ member by const reference.
// Accessor: returns the stored flags_ byte.
68 uint8_t
Flags()
const {
return flags_; }
// Accessor: returns the stored size_ value.
70 size_t Size()
const {
return size_; }
// Mutator: overwrites arctype_ with a copy of the given string.
74 void SetArcType(
const std::string &arctype) { arctype_ = arctype; }
// Mutator: overwrites flags_ with the given byte.
76 void SetFlags(uint8_t flags) { flags_ = flags; }
// Mutator: overwrites size_ with the given value.
78 void SetSize(
size_t size) { size_ = size; }
// Declaration only; the definition is not part of this listing.
// Takes an input stream plus a `source` string and reports a bool result.
// NOTE(review): `source` presumably names the stream for diagnostics, and a
// false result presumably signals failure — confirm against the definition.
82 bool Read(std::istream &strm,
const std::string &source);
// Declaration only; the definition is not part of this listing.
// Const member taking an output stream plus a `source` string; reports a bool.
// NOTE(review): a false result presumably signals failure — confirm against
// the definition.
84 bool Write(std::ostream &strm,
const std::string &source)
const;
// Member-initializer list (constructor header not visible here): copies both
// labels from the like-named arguments and moves the weight argument into the
// weight member. The constructor body is empty.
109 : ilabel(ilabel), olabel(olabel), weight(std::move(weight)) {}
// Member-initializer list (constructor header not visible here).
// The input label is always taken from the arc.
113 : ilabel(arc.ilabel),
// The output label is kept only when the kEncodeLabels bit is set; otherwise 0.
114 olabel(flags & kEncodeLabels ? arc.olabel : 0),
// The weight is kept only when the kEncodeWeights bit is set; otherwise
// Weight::One() is stored.
115 weight(flags & kEncodeWeights ? arc.weight :
Weight::One()) {}
// Static factory returning an owning pointer to a Triple.
// Only the opening of the body is visible in this listing: it starts from a
// default-constructed, heap-allocated Triple. How the fields are populated from
// `strm` is not shown here.
117 static std::unique_ptr<Triple>
Read(std::istream &strm) {
118 auto triple = std::make_unique<Triple>();
// Const member taking an output stream; the body is not visible in this
// listing. NOTE(review): presumably the counterpart of the static Read above —
// confirm against the full source.
125 void Write(std::ostream &strm)
const {
// Start of a conjunction: both labels must match those of `other`. The
// remaining operand(s) after the trailing && are not visible in this listing.
133 return (ilabel == other.
ilabel && olabel == other.
olabel &&
// Seed the hash with the input label.
153 size_t hash = triple->
ilabel;
// Shift amounts: 5 bits to the left, and (bit width of size_t - 5) to the right.
154 static constexpr
int lshift = 5;
155 static constexpr
int rshift = CHAR_BIT *
sizeof(size_t) - 5;
// Fold in the output label only when the kEncodeLabels bit is set.
156 if (flags_ & kEncodeLabels) {
// Shifts bind tighter than ^, so this is
//   (hash << lshift) ^ (hash >> rshift) ^ olabel.
// The two shifted parts occupy disjoint bits, i.e. hash rotated left by 5,
// then XORed with the label.
157 hash = hash << lshift ^ hash >> rshift ^ triple->
olabel;
// Fold in the weight's own hash only when the kEncodeWeights bit is set.
159 if (flags_ & kEncodeWeights) {
// Same rotate-and-XOR step, mixing in weight.Hash().
160 hash = hash << lshift ^ hash >> rshift ^ triple->
weight.Hash();
// Member-initializer list (constructor header not visible here): stores the
// flags and constructs triple2label_ from 1024 and a TripleHash built with the
// same flags. NOTE(review): if triple2label_ is the std::unordered_map declared
// among the data members, 1024 is its initial bucket count — confirm.
170 : flags_(flags), triple2label_(1024,
TripleHash(flags)) {}
// Special case: the arc has no next state (kNoStateId) and weights are being
// encoded. The body of this branch is not visible in this listing.
177 if (arc.nextstate ==
kNoStateId && (flags_ & kEncodeWeights)) {
// Builds a Triple from the arc and the table's flags, then hands ownership of
// it to the Encode overload that takes the owned triple.
180 return Encode(std::make_unique<Triple>(arc, flags_));
// Valid labels are 1-based: anything below 1 or above the number of stored
// triples is reported as unknown. The rest of this branch is not visible here.
186 if (label < 1 || label > triples_.size()) {
187 LOG(ERROR) <<
"EncodeTable::Decode: Unknown decode label: " << label;
// Label k lives in slot k - 1. The pointer returned is non-owning; the
// unique_ptr inside triples_ keeps ownership.
190 return triples_[label - 1].get();
// Number of triples currently stored in the table.
193 size_t Size()
const {
return triples_.size(); }
// Static reader returning a raw EncodeTable pointer; declared here and defined
// further down in this file.
// NOTE(review): the definition returns nullptr when the header read fails and
// otherwise releases a unique_ptr, so the caller presumably owns the result.
195 static EncodeTable *Read(std::istream &strm,
const std::string &source);
// Const writer taking an output stream and a `source` string; declared here and
// defined further down in this file. Returns bool.
197 bool Write(std::ostream &strm,
const std::string &source)
const;
// Replaces the owned input symbol table with the result of syms->Copy(); any
// previously held table is destroyed by reset().
209 isymbols_.reset(syms->
Copy());
// Replaces the owned output symbol table with the result of syms->Copy(); any
// previously held table is destroyed by reset().
219 osymbols_.reset(syms->
Copy());
// Attempts to key the triple under the next free label (current count + 1).
// NOTE(review): `insert_result` used below is presumably bound to this emplace
// result on a line not visible in this listing.
230 triple2label_.emplace(triple.get(), triples_.size() + 1);
// .second is true only for a fresh insertion; only then does the table take
// ownership of the triple.
231 if (insert_result.second) triples_.push_back(std::move(triple));
// Either way, return the label the map holds for this key.
232 return insert_result.first->second;
// Owning storage for the triples; Decode indexes this with label - 1.
236 std::vector<std::unique_ptr<Triple>> triples_;
// Map type from a non-owning Triple pointer to its label, with custom hash and
// equality functors. The member name follows on a line not visible here.
237 std::unordered_map<const Triple *, Label, TripleHash, TripleEqual>
// Owned symbol tables; replaced via reset() with copies supplied by callers.
239 std::unique_ptr<SymbolTable> isymbols_;
240 std::unique_ptr<SymbolTable> osymbols_;
// Definition of the static table reader (several lines are missing from this
// listing, including the declaration of `hdr`).
248 const std::string &source) {
// A failed header read aborts with nullptr.
250 if (!hdr.
Read(strm, source))
return nullptr;
251 const auto flags = hdr.
Flags();
252 const auto size = hdr.
Size();
// The table is held by unique_ptr while it is being filled.
253 auto table = std::make_unique<EncodeTable>(flags);
// NOTE(review): i is a signed int64_t while size comes from hdr.Size(); if that
// returns size_t this is a signed/unsigned comparison — confirm.
254 for (int64_t i = 0; i < size; ++i) {
// NOTE(review): Triple::Read returns a temporary unique_ptr, so the std::move
// is redundant.
255 table->triples_.emplace_back(std::move(Triple::Read(strm)));
// The label recorded for the triple just appended is the new element count,
// i.e. its index + 1 (labels are 1-based).
256 table->triple2label_[table->triples_.back().get()] = table->triples_.size();
265 LOG(ERROR) <<
"EncodeTable::Read: Read failed: " << source;
// Ownership leaves the unique_ptr here; the raw pointer goes to the caller.
268 return table.release();
// Definition of the table writer (several lines are missing from this listing,
// including the declaration and setup of `hdr`).
273 const std::string &source)
const {
// A failed header write aborts with false.
278 if (!hdr.
Write(strm, source))
return false;
// Each stored triple is written in storage order, i.e. in label order.
279 for (
const auto &triple : triples_) triple->Write(strm);
284 LOG(ERROR) <<
"EncodeTable::Write: Write failed: " << source;
// Convenience aliases for the arc's nested label and weight types.
310 using Label =
typename Arc::Label;
311 using Weight =
typename Arc::Weight;
// Initializer fragment: creates a fresh shared EncodeTable configured with the
// given flags.
317 table_(std::make_shared<internal::EncodeTable<Arc>>(flags)),
// Initializer fragment (other initializers are not visible here): copies the
// flags and copies the table_ handle from `mapper`.
// NOTE(review): table_ is a shared_ptr per its declaration, so both mappers
// then refer to the same table — confirm this sharing is intended.
321 : flags_(mapper.flags_),
323 table_(mapper.table_),
// Initializer fragment (one initializer is not visible here): copies the flags,
// the table_ handle and the error state from `mapper`. The body is empty.
328 : flags_(mapper.flags_),
330 table_(mapper.table_),
331 error_(mapper.error_) {}
// Maps one arc to another; declared here, with fragments of the definition
// appearing further down in this file.
333 Arc operator()(
const Arc &arc);
// Start of a return expression whose condition holds when the mapper is in
// ENCODE mode and the kEncodeWeights bit is set. The remainder of the
// expression is not visible in this listing.
336 return (type_ ==
ENCODE && (flags_ & kEncodeWeights))
// Accessor: returns the mapper's flags_ byte.
349 uint8_t
Flags()
const {
return flags_; }
// Start from the input property bits.
352 uint64_t outprops = inprops;
// Set the kError bit in the result when the mapper's error_ flag is set.
353 if (error_) outprops |=
kError;
// The bodies of the two flag-dependent branches, and the definition of `mask`,
// are not visible in this listing.
355 if (flags_ & kEncodeLabels) {
358 if (flags_ & kEncodeWeights) {
// Only bits that survive `mask` are reported.
364 return outprops & mask;
// A non-null table is wrapped in a newly allocated EncodeMapper that takes its
// flags from the table itself; a null table yields nullptr.
372 return table ?
new EncodeMapper(table->Flags(), type, table) :
nullptr;
// Opens `source` as a binary input file.
377 std::ifstream strm(source, std::ios_base::in | std::ios_base::binary);
// Open-failure diagnostic; the guarding condition is not visible in this
// listing.
379 LOG(ERROR) <<
"EncodeMapper: Can't open file: " << source;
// Delegates to the stream-based Read overload.
382 return Read(strm, source, type);
// Stream overload: forwards to the table's Write and returns its result.
385 bool Write(std::ostream &strm,
const std::string &source)
const {
386 return table_->Write(strm, source);
// File overload: opens `source` as a binary output file.
389 bool Write(
const std::string &source)
const {
390 std::ofstream strm(source,
391 std::ios_base::out | std::ios_base::binary);
// Open-failure diagnostic; the guarding condition is not visible in this
// listing.
393 LOG(ERROR) <<
"EncodeMapper: Can't open file: " << source;
// Delegates to the stream-based Write overload.
396 return Write(strm, source);
// Forwards the input symbol table to the underlying encode table.
404 table_->SetInputSymbols(syms);
// Forwards the output symbol table to the underlying encode table.
408 table_->SetOutputSymbols(syms);
// Handle to the encode table. It is a shared_ptr, and the copy constructors
// above copy the handle rather than the table.
414 std::shared_ptr<internal::EncodeTable<Arc>> table_;
// Member-initializer list (constructor header not visible here): stores the
// flags, type and table as given, and starts with the error flag cleared.
419 : flags_(flags), type_(type), table_(table), error_(
false) {}
// Tail of a condition (its beginning is not visible here): true when weights
// are not being encoded, or when they are and the arc weight is Weight::Zero().
// NOTE(review): the second (flags_ & kEncodeWeights) test is logically
// redundant: !A || (A && B) is equivalent to !A || B.
431 ((!(flags_ & kEncodeWeights) ||
432 ((flags_ & kEncodeWeights) && arc.weight == Weight::Zero())))) {
// Obtain the table's label for this arc.
435 const auto label = table_->Encode(arc);
// The encoded label becomes the input label. It also becomes the output label
// when labels are encoded; otherwise the original output label is kept.
436 return Arc(label, flags_ & kEncodeLabels ? label : arc.olabel,
// The weight collapses to Weight::One() when weights are encoded; otherwise the
// original weight is kept. Remaining constructor arguments are not visible here.
437 flags_ & kEncodeWeights ? Weight::One() : arc.weight,
// Arcs whose input label is 0 are returned unchanged.
444 if (arc.ilabel == 0)
return arc;
// & binds tighter than &&, so this reads (flags_ & kEncodeLabels) && (...).
// When labels were encoded, input and output labels are expected to match.
445 if (flags_ & kEncodeLabels && arc.ilabel != arc.olabel) {
446 FSTERROR() <<
"EncodeMapper: Label-encoded arc has different " 447 "input and output labels";
// When weights were encoded, the arc weight is expected to be Weight::One().
450 if (flags_ & kEncodeWeights && arc.weight != Weight::One()) {
451 FSTERROR() <<
"EncodeMapper: Weight-encoded arc has non-trivial weight";
// Look up the triple stored under this arc's input label.
454 const auto triple = table_->Decode(arc.ilabel);
// Lookup-failure diagnostic; the guarding condition is not visible in this
// listing.
456 FSTERROR() <<
"EncodeMapper: Decode failed";
459 }
// A stored input label of kNoLabel decodes to an arc with both labels 0 that
// carries the stored weight.
else if (triple->ilabel ==
kNoLabel) {
461 return Arc(0, 0, triple->weight, arc.nextstate);
// General case: the input label always comes from the triple. The output label
// and weight come from the triple only if they were encoded; otherwise the
// arc's own values are kept. Remaining arguments are not visible here.
463 return Arc(triple->ilabel,
464 flags_ & kEncodeLabels ? triple->olabel : arc.olabel,
465 flags_ & kEncodeWeights ? triple->weight : arc.weight,
// Diagnostic emitted for a Copy(true) request; the surrounding condition and
// handling are not visible in this listing.
520 FSTERROR() <<
"EncodeFst::Copy(true): Not allowed";
// Base-class clauses of four iterator specializations; the class heads and
// bodies are not visible in this listing. Each one derives publicly from the
// matching iterator over ArcMapFst<Arc, Arc, EncodeMapper<Arc>>.
// NOTE(review): these presumably belong to the StateIterator/ArcIterator
// specializations for EncodeFst and DecodeFst — confirm which is which against
// the full source.
570 :
public StateIterator<ArcMapFst<Arc, Arc, EncodeMapper<Arc>>> {
579 :
public ArcIterator<ArcMapFst<Arc, Arc, EncodeMapper<Arc>>> {
588 :
public StateIterator<ArcMapFst<Arc, Arc, EncodeMapper<Arc>>> {
597 :
public ArcIterator<ArcMapFst<Arc, Arc, EncodeMapper<Arc>>> {
611 #endif // FST_ENCODE_H_
bool Write(const std::string &source) const
void ArcMap(MutableFst< A > *fst, C *mapper)
typename ArcMapFst< Arc, Arc, EncodeMapper< Arc > >::Arc Arc
constexpr uint64_t kWeightInvariantProperties
const SymbolTable * OutputSymbols() const
constexpr uint8_t kEncodeFlags
const SymbolTable * OutputSymbols() const
bool Write(std::ostream &strm, const std::string &source) const
void Encode(MutableFst< Arc > *fst, EncodeMapper< Arc > *mapper)
static EncodeMapper * Read(const std::string &source, EncodeType type=ENCODE)
TripleHash(uint8_t flags)
ArcIterator(const DecodeFst< Arc > &fst, typename Arc::StateId s)
constexpr uint8_t kEncodeHasISymbols
EncodeFst(const Fst< Arc > &fst, Mapper *encoder)
virtual SymbolTable * Copy() const
DecodeFst(const Fst< Arc > &fst, const Mapper &encoder)
bool Write(std::ostream &strm, const std::string &source) const
const SymbolTable * InputSymbols() const override=0
void SetInputSymbols(const SymbolTable *syms)
constexpr uint64_t kError
constexpr uint64_t kRmSuperFinalProperties
virtual void SetInputSymbols(const SymbolTable *isyms)=0
Triple(Label ilabel, Label olabel, Weight weight)
void RmFinalEpsilon(MutableFst< Arc > *fst)
MapFinalAction FinalAction() const
DecodeFst(const DecodeFst &fst, bool safe=false)
constexpr MapSymbolsAction InputSymbolsAction() const
typename ArcMapFst< Arc, Arc, EncodeMapper< Arc > >::Arc Arc
void SetOutputSymbols(const SymbolTable *syms)
DecodeFst * Copy(bool safe=false) const override
static EncodeMapper * Read(std::istream &strm, const std::string &source, EncodeType type=ENCODE)
constexpr uint8_t kEncodeHasOSymbols
std::ostream & WriteType(std::ostream &strm, const T t)
EncodeFst(const Fst< Arc > &fst, const Mapper &encoder)
EncodeMapper(const EncodeMapper &mapper)
const SymbolTable * OutputSymbols() const override=0
EncodeMapper(const EncodeMapper &mapper, EncodeType type)
EncodeTable(uint8_t flags)
EncodeMapper(uint8_t flags, EncodeType type=ENCODE)
ArcIterator(const EncodeFst< Arc > &fst, typename Arc::StateId s)
void Write(std::ostream &strm) const
constexpr int32_t kEncodeMagicNumber
constexpr MapSymbolsAction OutputSymbolsAction() const
Arc operator()(const Arc &arc)
StateIterator(const DecodeFst< Arc > &fst)
constexpr uint8_t kEncodeLabels
static EncodeTable * Read(std::istream &strm, const std::string &source)
void SetInputSymbols(const SymbolTable *syms)
const Triple * Decode(Label label) const
static SymbolTable * Read(std::istream &strm, const std::string &source)
constexpr uint64_t kIDeterministic
constexpr uint64_t kOLabelInvariantProperties
StateIterator(const EncodeFst< Arc > &fst)
const SymbolTable * InputSymbols() const
Label Encode(const Arc &arc)
constexpr uint64_t kILabelInvariantProperties
constexpr uint8_t kEncodeWeights
constexpr uint64_t kFstProperties
virtual const SymbolTable * InputSymbols() const =0
Triple(const Arc &arc, uint8_t flags)
static std::unique_ptr< Triple > Read(std::istream &strm)
bool operator()(const Triple *x, const Triple *y) const
const SymbolTable * InputSymbols() const
std::istream & ReadType(std::istream &strm, T *t)
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
size_t operator()(const Triple *triple) const
uint64_t Properties(uint64_t inprops)
typename Arc::Label Label
EncodeFst(const EncodeFst &fst, bool copy=false)
typename Arc::Weight Weight
void Decode(MutableFst< Arc > *fst, const EncodeMapper< Arc > &mapper)
void SetOutputSymbols(const SymbolTable *syms)
constexpr uint64_t kAddSuperFinalProperties
bool operator==(const Triple &other) const
EncodeFst * Copy(bool safe=false) const override
virtual const SymbolTable * OutputSymbols() const =0