20 #ifndef FST_EXTENSIONS_COMPRESS_COMPRESS_H_ 21 #define FST_EXTENSIONS_COMPRESS_COMPRESS_H_ 56 template <
class Var,
class Edge>
58 std::vector<std::vector<Edge>> *expanded_code) {
59 expanded_code->resize(code.size());
60 for (
int i = 0; i < code.size(); ++i) {
61 if (code[i].first > i) {
62 LOG(ERROR) <<
"ExpandLZCode: Not a valid code";
65 auto &codeword = (*expanded_code)[i];
66 if (code[i].first == 0) {
67 codeword.resize(1, code[i].second);
69 const auto &other_codeword = (*expanded_code)[code[i].first - 1];
70 codeword.resize(other_codeword.size() + 1);
71 std::copy(other_codeword.cbegin(), other_codeword.cend(),
73 codeword[other_codeword.size()] = code[i].second;
83 template <
class Var,
class Edge,
class EdgeLessThan,
class EdgeEquals>
87 root_.current_number = dict_number_++;
88 root_.current_edge = default_edge_;
89 decode_vector_.emplace_back(0, default_edge_);
93 void BatchEncode(
const std::vector<Edge> &input,
94 std::vector<std::pair<Var, Edge>> *output);
98 bool BatchDecode(
const std::vector<std::pair<Var, Edge>> &input,
99 std::vector<Edge> *output);
104 if (index >= decode_vector_.size()) {
105 LOG(ERROR) <<
"LempelZiv::SingleDecode: " 106 <<
"Index exceeded the dictionary size";
109 *output = decode_vector_[index].second;
118 std::map<Edge, std::unique_ptr<Node>, EdgeLessThan> next_number;
123 std::vector<std::pair<Var, Edge>> decode_vector_;
127 template <
class Var,
class Edge,
class EdgeLessThan,
class EdgeEquals>
129 const std::vector<Edge> &input, std::vector<std::pair<Var, Edge>> *output) {
130 for (
auto it = input.cbegin(); it != input.cend(); ++it) {
131 auto *temp_node = &root_;
132 while (it != input.cend()) {
133 auto next = temp_node->next_number.find(*it);
134 if (next != temp_node->next_number.cend()) {
135 temp_node = next->second.get();
141 if (it == input.cend() && temp_node->current_number != 0) {
142 output->emplace_back(temp_node->current_number, default_edge_);
143 }
else if (it != input.cend()) {
144 output->emplace_back(temp_node->current_number, *it);
145 auto new_node = std::make_unique<Node>();
146 new_node->current_number = dict_number_++;
147 new_node->current_edge = *it;
148 temp_node->next_number[*it] = std::move(new_node);
150 if (it == input.cend())
break;
154 template <
class Var,
class Edge,
class EdgeLessThan,
class EdgeEquals>
156 const std::vector<std::pair<Var, Edge>> &input, std::vector<Edge> *output) {
157 for (
const auto &[var, edge] : input) {
158 std::vector<Edge> temp_output;
159 EdgeEquals InstEdgeEquals;
160 if (InstEdgeEquals(edge, default_edge_) != 1) {
161 decode_vector_.emplace_back(var, edge);
162 temp_output.push_back(edge);
164 auto temp_integer = var;
165 if (temp_integer >= decode_vector_.size()) {
166 LOG(ERROR) <<
"LempelZiv::BatchDecode: " 167 <<
"Index exceeded the dictionary size";
170 while (temp_integer != 0) {
171 temp_output.push_back(decode_vector_[temp_integer].second);
172 temp_integer = decode_vector_[temp_integer].first;
174 output->insert(output->cend(), temp_output.rbegin(), temp_output.rend());
193 bool Decompress(std::istream &strm,
const std::string &source,
207 void DecodeProcessedFst(
const std::vector<StateId> &input,
211 void WriteToStream(std::ostream &strm);
214 void WriteWeight(
const std::vector<Weight> &input, std::ostream &strm);
216 void ReadWeight(std::istream &strm, std::vector<Weight> *output);
222 template <
class CVar>
224 std::vector<bool> current_code;
226 buffer_code_.insert(buffer_code_.cend(), current_code.begin(),
232 LZLabel() : label(0) {}
236 struct LabelLessThan {
237 bool operator()(
const LZLabel &labelone,
const LZLabel &labeltwo)
const {
238 return labelone.label < labeltwo.label;
243 bool operator()(
const LZLabel &labelone,
const LZLabel &labeltwo)
const {
244 return labelone.label == labeltwo.label;
249 Transition() : nextstate(0), label(0), weight(Weight::Zero()) {}
256 struct TransitionLessThan {
257 bool operator()(
const Transition &transition_one,
258 const Transition &transition_two)
const {
259 if (transition_one.nextstate == transition_two.nextstate) {
260 return transition_one.label < transition_two.label;
262 return transition_one.nextstate < transition_two.nextstate;
267 struct TransitionEquals {
268 bool operator()(
const Transition &transition_one,
269 const Transition &transition_two)
const {
270 return transition_one.nextstate == transition_two.nextstate &&
271 transition_one.label == transition_two.label;
275 struct OldDictCompare {
276 bool operator()(
const std::pair<StateId, Transition> &pair_one,
277 const std::pair<StateId, Transition> &pair_two)
const {
278 if (pair_one.second.nextstate == pair_two.second.nextstate) {
279 return pair_one.second.label < pair_two.second.label;
281 return pair_one.second.nextstate < pair_two.second.nextstate;
286 std::vector<bool> buffer_code_;
287 std::vector<Weight> arc_weight_;
288 std::vector<Weight> final_weight_;
301 std::vector<StateId> *order) {
305 explicit BfsVisitor(std::vector<StateId> *order) : order_(order) {}
307 void InitVisit(
const Fst<Arc> &fst) {}
310 order_->at(s) = num_bfs_states_++;
314 bool WhiteArc(
StateId s,
const Arc &arc) {
return true; }
315 bool GreyArc(
StateId s,
const Arc &arc) {
return true; }
316 bool BlackArc(
StateId s,
const Arc &arc) {
return true; }
318 void FinishVisit() {}
321 std::vector<StateId> *order_ =
nullptr;
326 BfsVisitor visitor(order);
335 *preprocessedfst = fst;
336 if (!preprocessedfst->
NumStates())
return;
338 Encode(preprocessedfst, encoder);
339 std::vector<StateId> order;
341 BfsOrder(*preprocessedfst, &order);
348 std::ostream &strm) {
349 std::vector<StateId> output;
352 std::vector<LZLabel> current_new_input;
353 std::vector<Transition> current_old_input;
354 std::vector<std::pair<StateId, LZLabel>> current_new_output;
355 std::vector<std::pair<StateId, Transition>> current_old_output;
356 std::vector<StateId> final_states;
357 const auto number_of_states = fst.
NumStates();
360 WriteToBuffer<StateId>(number_of_states);
361 for (
StateId state = 0; state < number_of_states; ++state) {
362 current_new_input.clear();
363 current_old_input.clear();
364 current_new_output.clear();
365 current_old_output.clear();
366 if (state > seen_states) ++seen_states;
368 if (fst.
Final(state) != Weight::Zero()) {
369 final_states.push_back(state);
370 final_weight_.push_back(fst.
Final(state));
374 const auto &arc = aiter.Value();
375 if (arc.nextstate > seen_states) {
378 temp_label.label = arc.ilabel;
379 arc_weight_.push_back(arc.weight);
380 current_new_input.push_back(temp_label);
382 Transition temp_transition;
383 temp_transition.nextstate = arc.nextstate;
384 temp_transition.label = arc.ilabel;
385 temp_transition.weight = arc.weight;
386 current_old_input.push_back(temp_transition);
390 dict_new.
BatchEncode(current_new_input, ¤t_new_output);
391 WriteToBuffer<StateId>(current_new_output.size());
392 for (
auto it = current_new_output.cbegin(); it != current_new_output.cend();
394 WriteToBuffer<StateId>(it->first);
395 WriteToBuffer<Label>((it->second).label);
398 static const TransitionLessThan transition_less_than;
399 std::sort(current_old_input.begin(), current_old_input.end(),
400 transition_less_than);
401 for (
auto it = current_old_input.begin(); it != current_old_input.end();
403 arc_weight_.push_back(it->weight);
405 dict_old.
BatchEncode(current_old_input, ¤t_old_output);
406 std::vector<StateId> dict_old_temp;
407 std::vector<Transition> transition_old_temp;
408 for (
auto it = current_old_output.begin(); it != current_old_output.end();
410 dict_old_temp.push_back(it->first);
411 transition_old_temp.push_back(it->second);
413 if (!transition_old_temp.empty()) {
414 if ((transition_old_temp.back()).nextstate == 0 &&
415 (transition_old_temp.back()).label == 0) {
416 transition_old_temp.pop_back();
419 std::sort(dict_old_temp.begin(), dict_old_temp.end());
420 std::sort(transition_old_temp.begin(), transition_old_temp.end(),
421 transition_less_than);
422 WriteToBuffer<StateId>(dict_old_temp.size());
423 if (dict_old_temp.size() == transition_old_temp.size()) {
424 WriteToBuffer<int>(0);
426 WriteToBuffer<int>(1);
429 if (!dict_old_temp.empty()) {
430 WriteToBuffer<StateId>(dict_old_temp.front());
431 previous = dict_old_temp.front();
433 if (dict_old_temp.size() > 1) {
434 for (
auto it = dict_old_temp.begin() + 1; it != dict_old_temp.end();
436 WriteToBuffer<StateId>(*it - previous);
440 if (!transition_old_temp.empty()) {
441 WriteToBuffer<StateId>((transition_old_temp.front()).nextstate);
442 previous = transition_old_temp.front().nextstate;
443 WriteToBuffer<Label>(transition_old_temp.front().label);
445 if (transition_old_temp.size() > 1) {
446 for (
auto it = transition_old_temp.begin() + 1;
447 it != transition_old_temp.end(); ++it) {
448 WriteToBuffer<StateId>(it->nextstate - previous);
449 previous = it->nextstate;
450 WriteToBuffer<StateId>(it->label);
455 WriteToBuffer<StateId>(final_states.size());
456 if (!final_states.empty()) {
457 for (
auto it = final_states.begin(); it != final_states.end(); ++it) {
458 WriteToBuffer<StateId>(*it);
464 if (unweighted == 0) {
465 WriteWeight(arc_weight_, strm);
466 WriteWeight(final_weight_, strm);
476 std::vector<std::pair<StateId, LZLabel>> current_new_input;
477 std::vector<std::pair<StateId, Transition>> current_old_input;
478 std::vector<LZLabel> current_new_output;
479 std::vector<Transition> current_old_output;
480 std::vector<std::pair<StateId, Transition>> actual_old_dict_numbers;
481 std::vector<Transition> actual_old_dict_transitions;
482 auto arc_weight_it = arc_weight_.begin();
483 Transition default_transition;
486 const StateId num_states = input.front();
487 if (num_states > 0) {
490 for (
StateId state = 1; state < num_states; ++state) {
494 auto main_it = input.cbegin();
496 for (
StateId current_state = 0; current_state < num_states; ++current_state) {
497 if (current_state >= seen_states) ++seen_states;
498 current_new_input.clear();
499 current_new_output.clear();
500 current_old_input.clear();
501 current_old_output.clear();
503 StateId current_number_new_elements = *main_it;
505 for (
StateId new_integer = 0; new_integer < current_number_new_elements;
507 std::pair<StateId, LZLabel> temp_new_dict_element;
508 temp_new_dict_element.first = *main_it;
511 temp_label.label = *main_it;
513 temp_new_dict_element.second = temp_label;
514 current_new_input.push_back(temp_new_dict_element);
516 dict_new.
BatchDecode(current_new_input, ¤t_new_output);
517 for (
const auto &label : current_new_output) {
519 fst->
AddArc(current_state,
520 Arc(label.label, label.label, *arc_weight_it, seen_states));
523 fst->
AddArc(current_state,
524 Arc(label.label, label.label, Weight::One(), seen_states));
528 StateId current_number_old_elements = *main_it;
530 StateId is_zero_removed = *main_it;
533 actual_old_dict_numbers.clear();
534 for (
StateId new_integer = 0; new_integer < current_number_old_elements;
536 std::pair<StateId, Transition> pair_temp_transition;
537 if (new_integer == 0) {
538 pair_temp_transition.first = *main_it;
541 pair_temp_transition.first = *main_it + previous;
542 previous = pair_temp_transition.first;
545 Transition temp_test;
546 if (!dict_old.
SingleDecode(pair_temp_transition.first, &temp_test)) {
547 FSTERROR() <<
"Compressor::Decode: failed";
551 pair_temp_transition.second = temp_test;
552 actual_old_dict_numbers.push_back(pair_temp_transition);
555 static const OldDictCompare old_dict_compare;
556 std::sort(actual_old_dict_numbers.begin(), actual_old_dict_numbers.end(),
560 actual_old_dict_transitions.clear();
562 new_integer < current_number_old_elements - is_zero_removed;
564 Transition temp_transition;
565 if (new_integer == 0) {
566 temp_transition.nextstate = *main_it;
569 temp_transition.nextstate = *main_it + previous;
570 previous = temp_transition.nextstate;
573 temp_transition.label = *main_it;
575 actual_old_dict_transitions.push_back(temp_transition);
577 if (is_zero_removed == 1) {
578 actual_old_dict_transitions.push_back(default_transition);
580 auto trans_it = actual_old_dict_transitions.cbegin();
581 auto dict_it = actual_old_dict_numbers.cbegin();
582 while (trans_it != actual_old_dict_transitions.cend() &&
583 dict_it != actual_old_dict_numbers.cend()) {
584 if (dict_it->first == 0) {
587 std::pair<StateId, Transition> temp_pair;
588 static const TransitionEquals transition_equals;
589 static const TransitionLessThan transition_less_than;
590 if (transition_equals(*trans_it, default_transition)) {
591 temp_pair.first = dict_it->first;
592 temp_pair.second = default_transition;
594 }
else if (transition_less_than(dict_it->second, *trans_it)) {
595 temp_pair.first = dict_it->first;
596 temp_pair.second = *trans_it;
600 temp_pair.second = *trans_it;
603 current_old_input.push_back(temp_pair);
606 while (trans_it != actual_old_dict_transitions.cend()) {
607 std::pair<StateId, Transition> temp_pair;
609 temp_pair.second = *trans_it;
611 current_old_input.push_back(temp_pair);
614 if (!dict_old.
BatchDecode(current_old_input, ¤t_old_output)) {
615 FSTERROR() <<
"Compressor::Decode: Failed";
619 for (
auto it = current_old_output.cbegin(); it != current_old_output.cend();
622 fst->
AddArc(current_state,
623 Arc(it->label, it->label, *arc_weight_it, it->nextstate));
626 fst->
AddArc(current_state,
627 Arc(it->label, it->label, Weight::One(), it->nextstate));
632 StateId number_of_final_states = *main_it;
633 if (number_of_final_states > 0) {
635 for (
StateId temp_int = 0; temp_int < number_of_final_states; ++temp_int) {
639 fst->
SetFinal(*main_it, final_weight_[temp_int]);
648 std::vector<Weight> *output) {
652 for (int64_t i = 0; i < size; ++i) {
654 output->push_back(weight);
662 int32_t magic_number = 0;
664 if (magic_number != kCompressMagicNumber) {
665 LOG(ERROR) <<
"Decompress: Bad compressed Fst: " << source;
668 std::unique_ptr<EncodeMapper<Arc>> encoder(
670 std::vector<bool> bool_code;
675 for (int64_t i = 0; i < data_size; ++i) {
677 for (
int j = 0; j < 8; ++j) {
678 uint8_t temp = msb & block;
679 bool_code.push_back(temp == 128);
683 std::vector<StateId> int_code;
688 if (unweighted == 0) {
689 ReadWeight(strm, &arc_weight_);
690 ReadWeight(strm, &final_weight_);
692 DecodeProcessedFst(int_code, fst, unweighted);
693 DecodeForCompress(fst, *encoder);
699 std::ostream &strm) {
700 int64_t size = input.size();
702 for (
auto it = input.begin(); it != input.end(); ++it) {
709 while (buffer_code_.size() % 8 != 0) buffer_code_.push_back(
true);
710 int64_t data_size = buffer_code_.size() / 8;
714 for (
auto it = buffer_code_.begin(); it != buffer_code_.end(); ++it) {
731 Preprocess(fst, &processedfst, &encoder);
733 encoder.
Write(strm,
"ostream");
734 EncodeProcessedFst(processedfst, strm);
749 if (!source.empty()) {
750 fstrm.open(source, std::ios_base::out | std::ios_base::binary);
752 LOG(ERROR) <<
"Compress: Can't open file: " << source;
756 std::ostream &ostrm = fstrm.is_open() ? fstrm : std::cout;
762 bool Decompress(std::istream &strm,
const std::string &source,
773 if (!source.empty()) {
774 fstrm.open(source, std::ios_base::in | std::ios_base::binary);
776 LOG(ERROR) <<
"Decompress: Can't open file: " << source;
780 std::istream &istrm = fstrm.is_open() ? fstrm : std::cin;
781 Decompress(istrm, source.empty() ?
"standard input" : source, fst);
787 #endif // FST_EXTENSIONS_COMPRESS_COMPRESS_H_ constexpr int32_t kCompressMagicNumber
bool Decompress(std::istream &strm, const std::string &source, MutableFst< Arc > *fst)
bool ExpandLZCode(const std::vector< std::pair< Var, Edge >> &code, std::vector< std::vector< Edge >> *expanded_code)
void ArcMap(MutableFst< A > *fst, C *mapper)
void DecodeProcessedFst(const std::vector< StateId > &input, MutableFst< Arc > *fst, bool unweighted)
void BatchEncode(const std::vector< Edge > &input, std::vector< std::pair< Var, Edge >> *output)
void DecodeForCompress(MutableFst< Arc > *fst, const EncodeMapper< Arc > &mapper)
virtual uint64_t Properties(uint64_t mask, bool test) const =0
const SymbolTable * OutputSymbols() const
bool Decompress(std::istream &strm, const std::string &source, MutableFst< Arc > *fst)
void Encode(MutableFst< Arc > *fst, EncodeMapper< Arc > *mapper)
bool Compress(const Fst< Arc > &fst, std::ostream &strm)
void Visit(const FST &fst, Visitor *visitor, Queue *queue, ArcFilter filter, bool access_only=false)
bool Write(std::ostream &strm, const std::string &source) const
constexpr uint64_t kError
virtual void SetInputSymbols(const SymbolTable *isyms)=0
virtual Weight Final(StateId) const =0
virtual void SetStart(StateId)=0
bool BatchDecode(const std::vector< std::pair< Var, Edge >> &input, std::vector< Edge > *output)
void ReadWeight(std::istream &strm, std::vector< Weight > *output)
void WriteWeight(const std::vector< Weight > &input, std::ostream &strm)
void BfsOrder(const ExpandedFst< Arc > &fst, std::vector< StateId > *order)
std::ostream & WriteType(std::ostream &strm, const T t)
typename Arc::Label Label
void EncodeProcessedFst(const ExpandedFst< Arc > &fst, std::ostream &strm)
static void BatchDecode(const std::vector< bool > &input, std::vector< Var > *output)
typename Arc::StateId StateId
constexpr uint8_t kEncodeLabels
virtual void SetProperties(uint64_t props, uint64_t mask)=0
static void DeltaEncode(const Var &input, std::vector< bool > *code)
const SymbolTable * InputSymbols() const
virtual void AddArc(StateId, const Arc &)=0
constexpr uint64_t kUnweighted
void WriteToStream(std::ostream &strm)
virtual StateId AddState()=0
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
virtual void DeleteStates(const std::vector< StateId > &)=0
std::istream & ReadType(std::istream &strm, T *t)
void WriteToBuffer(CVar input)
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
void Compress(const Fst< Arc > &fst, std::ostream &strm)
virtual StateId NumStates() const =0
void Preprocess(const Fst< Arc > &fst, MutableFst< Arc > *preprocessedfst, EncodeMapper< Arc > *encoder)
bool SingleDecode(const Var &index, Edge *output)
bool StateSort(std::vector< IntervalSet< Label >> *interval_sets, const std::vector< StateId > &order)
typename Arc::Weight Weight