20 #ifndef FST_EXTENSIONS_COMPRESS_COMPRESS_H_ 21 #define FST_EXTENSIONS_COMPRESS_COMPRESS_H_ 53 #include <string_view> 64 template <
class Var,
class Edge>
65 [[nodiscard]]
bool ExpandLZCode(
const std::vector<std::pair<Var, Edge>> &code,
66 std::vector<std::vector<Edge>> *expanded_code) {
67 expanded_code->resize(code.size());
68 for (
int i = 0; i < code.size(); ++i) {
69 if (code[i].first > i) {
70 LOG(ERROR) <<
"ExpandLZCode: Not a valid code";
73 auto &codeword = (*expanded_code)[i];
74 if (code[i].first == 0) {
75 codeword.resize(1, code[i].second);
77 const auto &other_codeword = (*expanded_code)[code[i].first - 1];
78 codeword.resize(other_codeword.size() + 1);
79 std::copy(other_codeword.cbegin(), other_codeword.cend(),
81 codeword[other_codeword.size()] = code[i].second;
91 template <
class Var,
class Edge,
class EdgeLessThan,
class EdgeEquals>
95 root_.current_number = dict_number_++;
96 root_.current_edge = default_edge_;
97 decode_vector_.emplace_back(0, default_edge_);
// Declaration only; the body is defined out of line.
// Takes the edge sequence `input` by const reference and reports its result
// as (Var, Edge) pairs through the `output` pointer. Returns nothing, so
// there is no failure signal from this call.
101 void BatchEncode(
const std::vector<Edge> &input,
102 std::vector<std::pair<Var, Edge>> *output);
// Declaration only; the body is defined out of line.
// Takes a sequence of (Var, Edge) pairs in `input` and reports edges through
// the `output` pointer. The bool result is [[nodiscard]]; callers in this
// file test it with `!` and report an error when it is false.
106 [[nodiscard]]
bool BatchDecode(
const std::vector<std::pair<Var, Edge>> &input,
107 std::vector<Edge> *output);
112 if (index >= decode_vector_.size()) {
113 LOG(ERROR) <<
"LempelZiv::SingleDecode: " 114 <<
"Index exceeded the dictionary size";
117 *output = decode_vector_[index].second;
// Child nodes keyed by edge and ordered by EdgeLessThan. Each child is
// uniquely owned through std::unique_ptr, so it is destroyed with this map.
126 std::map<Edge, std::unique_ptr<Node>, EdgeLessThan> next_number;
// Decoding table indexed by dictionary number. For an entry, `.second` is
// the edge handed back to the caller and `.first` is another index into this
// same vector; decoding follows `.first` repeatedly until it reaches 0.
131 std::vector<std::pair<Var, Edge>> decode_vector_;
135 template <
class Var,
class Edge,
class EdgeLessThan,
class EdgeEquals>
137 const std::vector<Edge> &input, std::vector<std::pair<Var, Edge>> *output) {
138 for (
auto it = input.cbegin(); it != input.cend(); ++it) {
139 auto *temp_node = &root_;
140 while (it != input.cend()) {
141 auto next = temp_node->next_number.find(*it);
142 if (next != temp_node->next_number.cend()) {
143 temp_node = next->second.get();
149 if (it == input.cend() && temp_node->current_number != 0) {
150 output->emplace_back(temp_node->current_number, default_edge_);
151 }
else if (it != input.cend()) {
152 output->emplace_back(temp_node->current_number, *it);
153 auto new_node = std::make_unique<Node>();
154 new_node->current_number = dict_number_++;
155 new_node->current_edge = *it;
156 temp_node->next_number[*it] = std::move(new_node);
158 if (it == input.cend())
break;
162 template <
class Var,
class Edge,
class EdgeLessThan,
class EdgeEquals>
164 const std::vector<std::pair<Var, Edge>> &input, std::vector<Edge> *output) {
165 for (
const auto &[var, edge] : input) {
166 std::vector<Edge> temp_output;
167 EdgeEquals InstEdgeEquals;
168 if (InstEdgeEquals(edge, default_edge_) != 1) {
169 decode_vector_.emplace_back(var, edge);
170 temp_output.push_back(edge);
172 auto temp_integer = var;
173 if (temp_integer >= decode_vector_.size()) {
174 LOG(ERROR) <<
"LempelZiv::BatchDecode: " 175 <<
"Index exceeded the dictionary size";
178 while (temp_integer != 0) {
179 temp_output.push_back(decode_vector_[temp_integer].second);
180 temp_integer = decode_vector_[temp_integer].first;
182 output->insert(output->cend(), temp_output.rbegin(), temp_output.rend());
201 [[nodiscard]]
bool Decompress(std::istream &strm, std::string_view source,
215 void DecodeProcessedFst(
const std::vector<StateId> &input,
// Declaration only; the body is defined out of line.
// NOTE(review): presumably emits the buffered bit code (buffer_code_) to
// `strm` -- confirm against the definition.
219 void WriteToStream(std::ostream &strm);
// Declaration only; the body is defined out of line.
// Takes the weights in `input` by const reference together with the output
// stream `strm`. NOTE(review): presumably serializes every weight to `strm`
// -- confirm against the definition.
222 void WriteWeight(
const std::vector<Weight> &input, std::ostream &strm);
// Declaration only; the body is defined out of line.
// Takes the input stream `strm` and appends to the vector behind `output`.
// NOTE(review): presumably the inverse of WriteWeight -- confirm against the
// definition.
224 void ReadWeight(std::istream &strm, std::vector<Weight> *output);
230 template <
class CVar>
232 std::vector<bool> current_code;
234 buffer_code_.insert(buffer_code_.cend(), current_code.begin(),
// Default constructor: starts with `label` set to 0.
240 LZLabel() : label(0) {}
244 struct LabelLessThan {
245 bool operator()(
const LZLabel &labelone,
const LZLabel &labeltwo)
const {
246 return labelone.label < labeltwo.label;
251 bool operator()(
const LZLabel &labelone,
const LZLabel &labeltwo)
const {
252 return labelone.label == labeltwo.label;
// Default constructor: `nextstate` and `label` start at 0 and `weight`
// starts at Weight::Zero().
257 Transition() : nextstate(0), label(0), weight(Weight::Zero()) {}
264 struct TransitionLessThan {
265 bool operator()(
const Transition &transition_one,
266 const Transition &transition_two)
const {
267 if (transition_one.nextstate == transition_two.nextstate) {
268 return transition_one.label < transition_two.label;
270 return transition_one.nextstate < transition_two.nextstate;
275 struct TransitionEquals {
276 bool operator()(
const Transition &transition_one,
277 const Transition &transition_two)
const {
278 return transition_one.nextstate == transition_two.nextstate &&
279 transition_one.label == transition_two.label;
283 struct OldDictCompare {
284 bool operator()(
const std::pair<StateId, Transition> &pair_one,
285 const std::pair<StateId, Transition> &pair_two)
const {
286 if (pair_one.second.nextstate == pair_two.second.nextstate) {
287 return pair_one.second.label < pair_two.second.label;
289 return pair_one.second.nextstate < pair_two.second.nextstate;
// Bit buffer holding the encoded output, one bool per bit. It is padded
// with `true` bits up to a multiple of 8 before being split into bytes.
294 std::vector<bool> buffer_code_;
// Arc weights, appended while encoding and refilled via ReadWeight when
// decompressing a weighted FST.
295 std::vector<Weight> arc_weight_;
// Final-state weights, handled the same way as arc_weight_; indexed by
// final-state position when finals are restored with SetFinal.
296 std::vector<Weight> final_weight_;
309 std::vector<StateId> *order) {
// Stores the caller-supplied `order` pointer in order_. `explicit` prevents
// implicit conversion from a bare vector pointer.
313 explicit BfsVisitor(std::vector<StateId> *order) : order_(order) {}
// Intentionally empty: `fst` is ignored and no state is set up here.
315 void InitVisit(
const Fst<Arc> &fst) {}
318 order_->at(s) = num_bfs_states_++;
// Always returns true; both `s` and `arc` are ignored.
// NOTE(review): presumably true means "keep visiting" under the visitor
// interface -- confirm against the Visit() contract.
322 bool WhiteArc(
StateId s,
const Arc &arc) {
return true; }
// Always returns true; both `s` and `arc` are ignored.
// NOTE(review): presumably true means "keep visiting" under the visitor
// interface -- confirm against the Visit() contract.
323 bool GreyArc(
StateId s,
const Arc &arc) {
return true; }
// Always returns true; both `s` and `arc` are ignored.
// NOTE(review): presumably true means "keep visiting" under the visitor
// interface -- confirm against the Visit() contract.
324 bool BlackArc(
StateId s,
const Arc &arc) {
return true; }
// Intentionally empty: nothing to finalize when the visit ends.
326 void FinishVisit() {}
// Destination vector for the computed ordering. Null by default, set by the
// constructor, and written with bounds-checked at(), so the vector must
// already be large enough for every state index that gets stored.
329 std::vector<StateId> *order_ =
nullptr;
334 BfsVisitor visitor(order);
343 *preprocessedfst = fst;
344 if (!preprocessedfst->
NumStates())
return;
346 Encode(preprocessedfst, encoder);
347 std::vector<StateId> order;
349 BfsOrder(*preprocessedfst, &order);
356 std::ostream &strm) {
357 std::vector<StateId> output;
360 std::vector<LZLabel> current_new_input;
361 std::vector<Transition> current_old_input;
362 std::vector<std::pair<StateId, LZLabel>> current_new_output;
363 std::vector<std::pair<StateId, Transition>> current_old_output;
364 std::vector<StateId> final_states;
365 const auto number_of_states = fst.
NumStates();
368 WriteToBuffer<StateId>(number_of_states);
369 for (
StateId state = 0; state < number_of_states; ++state) {
370 current_new_input.clear();
371 current_old_input.clear();
372 current_new_output.clear();
373 current_old_output.clear();
374 if (state > seen_states) ++seen_states;
376 if (fst.
Final(state) != Weight::Zero()) {
377 final_states.push_back(state);
378 final_weight_.push_back(fst.
Final(state));
382 const auto &arc = aiter.Value();
383 if (arc.nextstate > seen_states) {
386 temp_label.label = arc.ilabel;
387 arc_weight_.push_back(arc.weight);
388 current_new_input.push_back(temp_label);
390 Transition temp_transition;
391 temp_transition.nextstate = arc.nextstate;
392 temp_transition.label = arc.ilabel;
393 temp_transition.weight = arc.weight;
394 current_old_input.push_back(temp_transition);
398 dict_new.
BatchEncode(current_new_input, ¤t_new_output);
399 WriteToBuffer<StateId>(current_new_output.size());
400 for (
auto it = current_new_output.cbegin(); it != current_new_output.cend();
402 WriteToBuffer<StateId>(it->first);
403 WriteToBuffer<Label>((it->second).label);
406 static const TransitionLessThan transition_less_than;
407 std::sort(current_old_input.begin(), current_old_input.end(),
408 transition_less_than);
409 for (
auto it = current_old_input.begin(); it != current_old_input.end();
411 arc_weight_.push_back(it->weight);
413 dict_old.
BatchEncode(current_old_input, ¤t_old_output);
414 std::vector<StateId> dict_old_temp;
415 std::vector<Transition> transition_old_temp;
416 for (
auto it = current_old_output.begin(); it != current_old_output.end();
418 dict_old_temp.push_back(it->first);
419 transition_old_temp.push_back(it->second);
421 if (!transition_old_temp.empty()) {
422 if ((transition_old_temp.back()).nextstate == 0 &&
423 (transition_old_temp.back()).label == 0) {
424 transition_old_temp.pop_back();
427 std::sort(dict_old_temp.begin(), dict_old_temp.end());
428 std::sort(transition_old_temp.begin(), transition_old_temp.end(),
429 transition_less_than);
430 WriteToBuffer<StateId>(dict_old_temp.size());
431 if (dict_old_temp.size() == transition_old_temp.size()) {
432 WriteToBuffer<int>(0);
434 WriteToBuffer<int>(1);
437 if (!dict_old_temp.empty()) {
438 WriteToBuffer<StateId>(dict_old_temp.front());
439 previous = dict_old_temp.front();
441 if (dict_old_temp.size() > 1) {
442 for (
auto it = dict_old_temp.begin() + 1; it != dict_old_temp.end();
444 WriteToBuffer<StateId>(*it - previous);
448 if (!transition_old_temp.empty()) {
449 WriteToBuffer<StateId>((transition_old_temp.front()).nextstate);
450 previous = transition_old_temp.front().nextstate;
451 WriteToBuffer<Label>(transition_old_temp.front().label);
453 if (transition_old_temp.size() > 1) {
454 for (
auto it = transition_old_temp.begin() + 1;
455 it != transition_old_temp.end(); ++it) {
456 WriteToBuffer<StateId>(it->nextstate - previous);
457 previous = it->nextstate;
458 WriteToBuffer<StateId>(it->label);
463 WriteToBuffer<StateId>(final_states.size());
464 if (!final_states.empty()) {
465 for (
auto it = final_states.begin(); it != final_states.end(); ++it) {
466 WriteToBuffer<StateId>(*it);
472 if (unweighted == 0) {
473 WriteWeight(arc_weight_, strm);
474 WriteWeight(final_weight_, strm);
484 std::vector<std::pair<StateId, LZLabel>> current_new_input;
485 std::vector<std::pair<StateId, Transition>> current_old_input;
486 std::vector<LZLabel> current_new_output;
487 std::vector<Transition> current_old_output;
488 std::vector<std::pair<StateId, Transition>> actual_old_dict_numbers;
489 std::vector<Transition> actual_old_dict_transitions;
490 auto arc_weight_it = arc_weight_.begin();
491 Transition default_transition;
494 const StateId num_states = input.front();
495 if (num_states > 0) {
498 for (
StateId state = 1; state < num_states; ++state) {
502 auto main_it = input.cbegin();
504 for (
StateId current_state = 0; current_state < num_states; ++current_state) {
505 if (current_state >= seen_states) ++seen_states;
506 current_new_input.clear();
507 current_new_output.clear();
508 current_old_input.clear();
509 current_old_output.clear();
511 StateId current_number_new_elements = *main_it;
513 for (
StateId new_integer = 0; new_integer < current_number_new_elements;
515 std::pair<StateId, LZLabel> temp_new_dict_element;
516 temp_new_dict_element.first = *main_it;
519 temp_label.label = *main_it;
521 temp_new_dict_element.second = temp_label;
522 current_new_input.push_back(temp_new_dict_element);
524 if (!dict_new.
BatchDecode(current_new_input, ¤t_new_output)) {
525 FSTERROR() <<
"Compressor::Decode: failed";
529 for (
const auto &label : current_new_output) {
531 fst->
AddArc(current_state,
532 Arc(label.label, label.label, *arc_weight_it, seen_states));
535 fst->
AddArc(current_state,
536 Arc(label.label, label.label, Weight::One(), seen_states));
540 StateId current_number_old_elements = *main_it;
542 StateId is_zero_removed = *main_it;
545 actual_old_dict_numbers.clear();
546 for (
StateId new_integer = 0; new_integer < current_number_old_elements;
548 std::pair<StateId, Transition> pair_temp_transition;
549 if (new_integer == 0) {
550 pair_temp_transition.first = *main_it;
553 pair_temp_transition.first = *main_it + previous;
554 previous = pair_temp_transition.first;
557 Transition temp_test;
558 if (!dict_old.
SingleDecode(pair_temp_transition.first, &temp_test)) {
559 FSTERROR() <<
"Compressor::Decode: failed";
563 pair_temp_transition.second = temp_test;
564 actual_old_dict_numbers.push_back(pair_temp_transition);
567 static const OldDictCompare old_dict_compare;
568 std::sort(actual_old_dict_numbers.begin(), actual_old_dict_numbers.end(),
572 actual_old_dict_transitions.clear();
574 new_integer < current_number_old_elements - is_zero_removed;
576 Transition temp_transition;
577 if (new_integer == 0) {
578 temp_transition.nextstate = *main_it;
581 temp_transition.nextstate = *main_it + previous;
582 previous = temp_transition.nextstate;
585 temp_transition.label = *main_it;
587 actual_old_dict_transitions.push_back(temp_transition);
589 if (is_zero_removed == 1) {
590 actual_old_dict_transitions.push_back(default_transition);
592 auto trans_it = actual_old_dict_transitions.cbegin();
593 auto dict_it = actual_old_dict_numbers.cbegin();
594 while (trans_it != actual_old_dict_transitions.cend() &&
595 dict_it != actual_old_dict_numbers.cend()) {
596 if (dict_it->first == 0) {
599 std::pair<StateId, Transition> temp_pair;
600 static const TransitionEquals transition_equals;
601 static const TransitionLessThan transition_less_than;
602 if (transition_equals(*trans_it, default_transition)) {
603 temp_pair.first = dict_it->first;
604 temp_pair.second = default_transition;
606 }
else if (transition_less_than(dict_it->second, *trans_it)) {
607 temp_pair.first = dict_it->first;
608 temp_pair.second = *trans_it;
612 temp_pair.second = *trans_it;
615 current_old_input.push_back(temp_pair);
618 while (trans_it != actual_old_dict_transitions.cend()) {
619 std::pair<StateId, Transition> temp_pair;
621 temp_pair.second = *trans_it;
623 current_old_input.push_back(temp_pair);
626 if (!dict_old.
BatchDecode(current_old_input, ¤t_old_output)) {
627 FSTERROR() <<
"Compressor::Decode: Failed";
631 for (
auto it = current_old_output.cbegin(); it != current_old_output.cend();
634 fst->
AddArc(current_state,
635 Arc(it->label, it->label, *arc_weight_it, it->nextstate));
638 fst->
AddArc(current_state,
639 Arc(it->label, it->label, Weight::One(), it->nextstate));
644 StateId number_of_final_states = *main_it;
645 if (number_of_final_states > 0) {
647 for (
StateId temp_int = 0; temp_int < number_of_final_states; ++temp_int) {
651 fst->
SetFinal(*main_it, final_weight_[temp_int]);
660 std::vector<Weight> *output) {
664 for (int64_t i = 0; i < size; ++i) {
666 output->push_back(weight);
672 std::string_view source,
675 int32_t magic_number = 0;
677 if (magic_number != kCompressMagicNumber) {
678 LOG(ERROR) <<
"Decompress: Bad compressed Fst: " << source;
681 std::unique_ptr<EncodeMapper<Arc>> encoder(
683 if (encoder ==
nullptr)
return false;
684 std::vector<bool> bool_code;
689 for (int64_t i = 0; i < data_size; ++i) {
691 for (
int j = 0; j < 8; ++j) {
692 uint8_t temp = msb & block;
693 bool_code.push_back(temp == 128);
697 std::vector<StateId> int_code;
702 if (unweighted == 0) {
703 ReadWeight(strm, &arc_weight_);
704 ReadWeight(strm, &final_weight_);
706 DecodeProcessedFst(int_code, fst, unweighted);
707 DecodeForCompress(fst, *encoder);
713 std::ostream &strm) {
714 int64_t size = input.size();
716 for (
auto it = input.begin(); it != input.end(); ++it) {
723 while (buffer_code_.size() % 8 != 0) buffer_code_.push_back(
true);
724 int64_t data_size = buffer_code_.size() / 8;
728 for (
auto it = buffer_code_.begin(); it != buffer_code_.end(); ++it) {
743 std::ostream &strm) {
746 Preprocess(fst, &processedfst, &encoder);
748 encoder.
Write(strm,
"ostream");
749 EncodeProcessedFst(processedfst, strm);
764 if (!source.empty()) {
765 fstrm.open(source, std::ios_base::out | std::ios_base::binary);
767 LOG(ERROR) <<
"Compress: Can't open file: " << source;
771 std::ostream &ostrm = fstrm.is_open() ? fstrm : std::cout;
772 if (!
Compress(fst, ostrm))
return false;
777 [[nodiscard]]
bool Decompress(std::istream &strm, std::string_view source,
787 if (!source.empty()) {
788 fstrm.open(std::string(source), std::ios_base::in | std::ios_base::binary);
790 LOG(ERROR) <<
"Decompress: Can't open file: " << source;
794 std::istream &istrm = fstrm.is_open() ? fstrm : std::cin;
795 if (!
Decompress(istrm, source.empty() ?
"standard input" : source, fst)) {
803 #endif // FST_EXTENSIONS_COMPRESS_COMPRESS_H_ constexpr int32_t kCompressMagicNumber
bool ExpandLZCode(const std::vector< std::pair< Var, Edge >> &code, std::vector< std::vector< Edge >> *expanded_code)
void ArcMap(MutableFst< A > *fst, C *mapper)
void DecodeProcessedFst(const std::vector< StateId > &input, MutableFst< Arc > *fst, bool unweighted)
void BatchEncode(const std::vector< Edge > &input, std::vector< std::pair< Var, Edge >> *output)
void DecodeForCompress(MutableFst< Arc > *fst, const EncodeMapper< Arc > &mapper)
virtual uint64_t Properties(uint64_t mask, bool test) const =0
const SymbolTable * OutputSymbols() const
void Encode(MutableFst< Arc > *fst, EncodeMapper< Arc > *mapper)
bool Compress(const Fst< Arc > &fst, std::ostream &strm)
void Visit(const FST &fst, Visitor *visitor, Queue *queue, ArcFilter filter, bool access_only=false)
constexpr uint64_t kError
virtual void SetInputSymbols(const SymbolTable *isyms)=0
virtual Weight Final(StateId) const =0
virtual void SetStart(StateId)=0
bool BatchDecode(const std::vector< std::pair< Var, Edge >> &input, std::vector< Edge > *output)
void ReadWeight(std::istream &strm, std::vector< Weight > *output)
void WriteWeight(const std::vector< Weight > &input, std::ostream &strm)
void BfsOrder(const ExpandedFst< Arc > &fst, std::vector< StateId > *order)
std::ostream & WriteType(std::ostream &strm, const T t)
typename Arc::Label Label
void EncodeProcessedFst(const ExpandedFst< Arc > &fst, std::ostream &strm)
bool Compress(const Fst< Arc > &fst, std::ostream &strm)
static void BatchDecode(const std::vector< bool > &input, std::vector< Var > *output)
typename Arc::StateId StateId
constexpr uint8_t kEncodeLabels
virtual void SetProperties(uint64_t props, uint64_t mask)=0
static void DeltaEncode(const Var &input, std::vector< bool > *code)
bool Decompress(std::istream &strm, std::string_view source, MutableFst< Arc > *fst)
const SymbolTable * InputSymbols() const
bool Decompress(std::istream &strm, std::string_view source, MutableFst< Arc > *fst)
virtual void AddArc(StateId, const Arc &)=0
constexpr uint64_t kUnweighted
void WriteToStream(std::ostream &strm)
virtual StateId AddState()=0
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
virtual void DeleteStates(const std::vector< StateId > &)=0
bool Write(std::ostream &strm, std::string_view source) const
std::istream & ReadType(std::istream &strm, T *t)
void WriteToBuffer(CVar input)
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
virtual StateId NumStates() const =0
void Preprocess(const Fst< Arc > &fst, MutableFst< Arc > *preprocessedfst, EncodeMapper< Arc > *encoder)
bool SingleDecode(const Var &index, Edge *output)
bool StateSort(std::vector< IntervalSet< Label >> *interval_sets, const std::vector< StateId > &order)
typename Arc::Weight Weight