FST  openfst-1.8.3
OpenFst Library
compress.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Compresses and decompresses unweighted FSTs.
19 
20 #ifndef FST_EXTENSIONS_COMPRESS_COMPRESS_H_
21 #define FST_EXTENSIONS_COMPRESS_COMPRESS_H_
22 
23 #include <algorithm>
24 #include <cstdint>
25 #include <cstdio>
26 #include <ios>
27 #include <iostream>
28 #include <istream>
29 #include <map>
30 #include <memory>
31 #include <ostream>
32 #include <queue>
33 #include <string>
34 #include <utility>
35 #include <vector>
36 
37 #include <fst/compat.h>
38 #include <fst/log.h>
40 #include <fst/arc-map.h>
41 #include <fst/arcfilter.h>
42 #include <fst/encode.h>
43 #include <fst/expanded-fst.h>
44 #include <fstream>
45 #include <fst/fst.h>
46 #include <fst/mutable-fst.h>
47 #include <fst/properties.h>
48 #include <fst/queue.h>
49 #include <fst/statesort.h>
50 #include <fst/util.h>
51 #include <fst/vector-fst.h>
52 #include <fst/visit.h>
53 #include <string_view>
54 
55 namespace fst {
56 
// Magic number that Compressor::Compress() writes at the head of its output
// and that Compressor::Decompress() requires before decoding anything else;
// it identifies the stream as a vanilla compressed FST.
inline constexpr int32_t kCompressMagicNumber{1858869554};
59 
60 namespace internal {
61 
62 // Expands a Lempel Ziv code and returns the set of code words where
63 // expanded_code[i] is the i^th Lempel Ziv codeword.
64 template <class Var, class Edge>
65 [[nodiscard]] bool ExpandLZCode(const std::vector<std::pair<Var, Edge>> &code,
66  std::vector<std::vector<Edge>> *expanded_code) {
67  expanded_code->resize(code.size());
68  for (int i = 0; i < code.size(); ++i) {
69  if (code[i].first > i) {
70  LOG(ERROR) << "ExpandLZCode: Not a valid code";
71  return false;
72  }
73  auto &codeword = (*expanded_code)[i];
74  if (code[i].first == 0) {
75  codeword.resize(1, code[i].second);
76  } else {
77  const auto &other_codeword = (*expanded_code)[code[i].first - 1];
78  codeword.resize(other_codeword.size() + 1);
79  std::copy(other_codeword.cbegin(), other_codeword.cend(),
80  codeword.begin());
81  codeword[other_codeword.size()] = code[i].second;
82  }
83  }
84  return true;
85 }
86 
87 } // namespace internal
88 
89 // Lempel Ziv on data structure Edge, with a less-than operator EdgeLessThan and
90 // an equals operator EdgeEquals.
91 template <class Var, class Edge, class EdgeLessThan, class EdgeEquals>
92 class LempelZiv {
93  public:
94  LempelZiv() : dict_number_(0), default_edge_() {
95  root_.current_number = dict_number_++;
96  root_.current_edge = default_edge_;
97  decode_vector_.emplace_back(0, default_edge_);
98  }
99 
100  // Encodes a vector input into output.
101  void BatchEncode(const std::vector<Edge> &input,
102  std::vector<std::pair<Var, Edge>> *output);
103 
104  // Decodes codedvector to output, returning false if the index exceeds the
105  // size.
106  [[nodiscard]] bool BatchDecode(const std::vector<std::pair<Var, Edge>> &input,
107  std::vector<Edge> *output);
108 
109  // Decodes a single dictionary element, returning false if the index exceeds
110  // the size.
111  [[nodiscard]] bool SingleDecode(const Var &index, Edge *output) {
112  if (index >= decode_vector_.size()) {
113  LOG(ERROR) << "LempelZiv::SingleDecode: "
114  << "Index exceeded the dictionary size";
115  return false;
116  } else {
117  *output = decode_vector_[index].second;
118  return true;
119  }
120  }
121 
122  private:
123  struct Node {
124  Var current_number;
125  Edge current_edge;
126  std::map<Edge, std::unique_ptr<Node>, EdgeLessThan> next_number;
127  };
128 
129  Node root_;
130  Var dict_number_;
131  std::vector<std::pair<Var, Edge>> decode_vector_;
132  Edge default_edge_;
133 };
134 
135 template <class Var, class Edge, class EdgeLessThan, class EdgeEquals>
137  const std::vector<Edge> &input, std::vector<std::pair<Var, Edge>> *output) {
138  for (auto it = input.cbegin(); it != input.cend(); ++it) {
139  auto *temp_node = &root_;
140  while (it != input.cend()) {
141  auto next = temp_node->next_number.find(*it);
142  if (next != temp_node->next_number.cend()) {
143  temp_node = next->second.get();
144  ++it;
145  } else {
146  break;
147  }
148  }
149  if (it == input.cend() && temp_node->current_number != 0) {
150  output->emplace_back(temp_node->current_number, default_edge_);
151  } else if (it != input.cend()) {
152  output->emplace_back(temp_node->current_number, *it);
153  auto new_node = std::make_unique<Node>();
154  new_node->current_number = dict_number_++;
155  new_node->current_edge = *it;
156  temp_node->next_number[*it] = std::move(new_node);
157  }
158  if (it == input.cend()) break;
159  }
160 }
161 
162 template <class Var, class Edge, class EdgeLessThan, class EdgeEquals>
164  const std::vector<std::pair<Var, Edge>> &input, std::vector<Edge> *output) {
165  for (const auto &[var, edge] : input) {
166  std::vector<Edge> temp_output;
167  EdgeEquals InstEdgeEquals;
168  if (InstEdgeEquals(edge, default_edge_) != 1) {
169  decode_vector_.emplace_back(var, edge);
170  temp_output.push_back(edge);
171  }
172  auto temp_integer = var;
173  if (temp_integer >= decode_vector_.size()) {
174  LOG(ERROR) << "LempelZiv::BatchDecode: "
175  << "Index exceeded the dictionary size";
176  return false;
177  } else {
178  while (temp_integer != 0) {
179  temp_output.push_back(decode_vector_[temp_integer].second);
180  temp_integer = decode_vector_[temp_integer].first;
181  }
182  output->insert(output->cend(), temp_output.rbegin(), temp_output.rend());
183  }
184  }
185  return true;
186 }
187 
188 template <class Arc>
189 class Compressor {
190  public:
191  using Label = typename Arc::Label;
192  using StateId = typename Arc::StateId;
193  using Weight = typename Arc::Weight;
194 
195  Compressor() = default;
196 
197  // Compresses an FST into a boolean vector code, returning true on success.
198  [[nodiscard]] bool Compress(const Fst<Arc> &fst, std::ostream &strm);
199 
200  // Decompresses the boolean vector into an FST, returning true on success.
201  [[nodiscard]] bool Decompress(std::istream &strm, std::string_view source,
203 
204  // Computes the BFS order of a FST.
205  void BfsOrder(const ExpandedFst<Arc> &fst, std::vector<StateId> *order);
206 
207  // Preprocessing step to convert an FST to a isomorphic FST.
208  void Preprocess(const Fst<Arc> &fst, MutableFst<Arc> *preprocessedfst,
209  EncodeMapper<Arc> *encoder);
210 
211  // Performs Lempel Ziv and outputs a stream of integers.
212  void EncodeProcessedFst(const ExpandedFst<Arc> &fst, std::ostream &strm);
213 
214  // Decodes FST from the stream.
215  void DecodeProcessedFst(const std::vector<StateId> &input,
216  MutableFst<Arc> *fst, bool unweighted);
217 
218  // Writes the boolean file to the stream.
219  void WriteToStream(std::ostream &strm);
220 
221  // Writes the weights to the stream.
222  void WriteWeight(const std::vector<Weight> &input, std::ostream &strm);
223 
224  void ReadWeight(std::istream &strm, std::vector<Weight> *output);
225 
226  // Same as fst::Decode, but doesn't remove the final epsilons.
227  void DecodeForCompress(MutableFst<Arc> *fst, const EncodeMapper<Arc> &mapper);
228 
229  // Updates buffer_code_.
230  template <class CVar>
231  void WriteToBuffer(CVar input) {
232  std::vector<bool> current_code;
233  Elias<CVar>::DeltaEncode(input, &current_code);
234  buffer_code_.insert(buffer_code_.cend(), current_code.begin(),
235  current_code.end());
236  }
237 
238  private:
239  struct LZLabel {
240  LZLabel() : label(0) {}
241  Label label;
242  };
243 
244  struct LabelLessThan {
245  bool operator()(const LZLabel &labelone, const LZLabel &labeltwo) const {
246  return labelone.label < labeltwo.label;
247  }
248  };
249 
250  struct LabelEquals {
251  bool operator()(const LZLabel &labelone, const LZLabel &labeltwo) const {
252  return labelone.label == labeltwo.label;
253  }
254  };
255 
256  struct Transition {
257  Transition() : nextstate(0), label(0), weight(Weight::Zero()) {}
258 
259  StateId nextstate;
260  Label label;
261  Weight weight;
262  };
263 
264  struct TransitionLessThan {
265  bool operator()(const Transition &transition_one,
266  const Transition &transition_two) const {
267  if (transition_one.nextstate == transition_two.nextstate) {
268  return transition_one.label < transition_two.label;
269  } else {
270  return transition_one.nextstate < transition_two.nextstate;
271  }
272  }
273  };
274 
275  struct TransitionEquals {
276  bool operator()(const Transition &transition_one,
277  const Transition &transition_two) const {
278  return transition_one.nextstate == transition_two.nextstate &&
279  transition_one.label == transition_two.label;
280  }
281  };
282 
283  struct OldDictCompare {
284  bool operator()(const std::pair<StateId, Transition> &pair_one,
285  const std::pair<StateId, Transition> &pair_two) const {
286  if (pair_one.second.nextstate == pair_two.second.nextstate) {
287  return pair_one.second.label < pair_two.second.label;
288  } else {
289  return pair_one.second.nextstate < pair_two.second.nextstate;
290  }
291  }
292  };
293 
294  std::vector<bool> buffer_code_;
295  std::vector<Weight> arc_weight_;
296  std::vector<Weight> final_weight_;
297 };
298 
299 template <class Arc>
301  const EncodeMapper<Arc> &mapper) {
302  ArcMap(fst, EncodeMapper<Arc>(mapper, DECODE));
303  fst->SetInputSymbols(mapper.InputSymbols());
304  fst->SetOutputSymbols(mapper.OutputSymbols());
305 }
306 
307 template <class Arc>
309  std::vector<StateId> *order) {
310  class BfsVisitor {
311  public:
312  // Requires order->size() >= fst.NumStates().
313  explicit BfsVisitor(std::vector<StateId> *order) : order_(order) {}
314 
315  void InitVisit(const Fst<Arc> &fst) {}
316 
317  bool InitState(StateId s, StateId) {
318  order_->at(s) = num_bfs_states_++;
319  return true;
320  }
321 
322  bool WhiteArc(StateId s, const Arc &arc) { return true; }
323  bool GreyArc(StateId s, const Arc &arc) { return true; }
324  bool BlackArc(StateId s, const Arc &arc) { return true; }
325  void FinishState(StateId s) {}
326  void FinishVisit() {}
327 
328  private:
329  std::vector<StateId> *order_ = nullptr;
330  StateId num_bfs_states_ = 0;
331  };
332 
333  order->assign(fst.NumStates(), kNoStateId);
334  BfsVisitor visitor(order);
335  FifoQueue<StateId> queue;
336  Visit(fst, &visitor, &queue, AnyArcFilter<Arc>());
337 }
338 
339 template <class Arc>
341  MutableFst<Arc> *preprocessedfst,
342  EncodeMapper<Arc> *encoder) {
343  *preprocessedfst = fst;
344  if (!preprocessedfst->NumStates()) return;
345  // Relabels the edges and develops a dictionary.
346  Encode(preprocessedfst, encoder);
347  std::vector<StateId> order;
348  // Finds the BFS sorting order of the FST.
349  BfsOrder(*preprocessedfst, &order);
350  // Reorders the states according to the BFS order.
351  StateSort(preprocessedfst, order);
352 }
353 
354 template <class Arc>
356  std::ostream &strm) {
357  std::vector<StateId> output;
360  std::vector<LZLabel> current_new_input;
361  std::vector<Transition> current_old_input;
362  std::vector<std::pair<StateId, LZLabel>> current_new_output;
363  std::vector<std::pair<StateId, Transition>> current_old_output;
364  std::vector<StateId> final_states;
365  const auto number_of_states = fst.NumStates();
366  StateId seen_states = 0;
367  // Adds the number of states.
368  WriteToBuffer<StateId>(number_of_states);
369  for (StateId state = 0; state < number_of_states; ++state) {
370  current_new_input.clear();
371  current_old_input.clear();
372  current_new_output.clear();
373  current_old_output.clear();
374  if (state > seen_states) ++seen_states;
375  // Collects the final states.
376  if (fst.Final(state) != Weight::Zero()) {
377  final_states.push_back(state);
378  final_weight_.push_back(fst.Final(state));
379  }
380  // Reads the states.
381  for (ArcIterator<Fst<Arc>> aiter(fst, state); !aiter.Done(); aiter.Next()) {
382  const auto &arc = aiter.Value();
383  if (arc.nextstate > seen_states) { // RILEY: > or >= ?
384  ++seen_states;
385  LZLabel temp_label;
386  temp_label.label = arc.ilabel;
387  arc_weight_.push_back(arc.weight);
388  current_new_input.push_back(temp_label);
389  } else {
390  Transition temp_transition;
391  temp_transition.nextstate = arc.nextstate;
392  temp_transition.label = arc.ilabel;
393  temp_transition.weight = arc.weight;
394  current_old_input.push_back(temp_transition);
395  }
396  }
397  // Adds new states.
398  dict_new.BatchEncode(current_new_input, &current_new_output);
399  WriteToBuffer<StateId>(current_new_output.size());
400  for (auto it = current_new_output.cbegin(); it != current_new_output.cend();
401  ++it) {
402  WriteToBuffer<StateId>(it->first);
403  WriteToBuffer<Label>((it->second).label);
404  }
405  // Adds old states by sorting and using difference coding.
406  static const TransitionLessThan transition_less_than;
407  std::sort(current_old_input.begin(), current_old_input.end(),
408  transition_less_than);
409  for (auto it = current_old_input.begin(); it != current_old_input.end();
410  ++it) {
411  arc_weight_.push_back(it->weight);
412  }
413  dict_old.BatchEncode(current_old_input, &current_old_output);
414  std::vector<StateId> dict_old_temp;
415  std::vector<Transition> transition_old_temp;
416  for (auto it = current_old_output.begin(); it != current_old_output.end();
417  ++it) {
418  dict_old_temp.push_back(it->first);
419  transition_old_temp.push_back(it->second);
420  }
421  if (!transition_old_temp.empty()) {
422  if ((transition_old_temp.back()).nextstate == 0 &&
423  (transition_old_temp.back()).label == 0) {
424  transition_old_temp.pop_back();
425  }
426  }
427  std::sort(dict_old_temp.begin(), dict_old_temp.end());
428  std::sort(transition_old_temp.begin(), transition_old_temp.end(),
429  transition_less_than);
430  WriteToBuffer<StateId>(dict_old_temp.size());
431  if (dict_old_temp.size() == transition_old_temp.size()) {
432  WriteToBuffer<int>(0);
433  } else {
434  WriteToBuffer<int>(1);
435  }
436  StateId previous;
437  if (!dict_old_temp.empty()) {
438  WriteToBuffer<StateId>(dict_old_temp.front());
439  previous = dict_old_temp.front();
440  }
441  if (dict_old_temp.size() > 1) {
442  for (auto it = dict_old_temp.begin() + 1; it != dict_old_temp.end();
443  ++it) {
444  WriteToBuffer<StateId>(*it - previous);
445  previous = *it;
446  }
447  }
448  if (!transition_old_temp.empty()) {
449  WriteToBuffer<StateId>((transition_old_temp.front()).nextstate);
450  previous = transition_old_temp.front().nextstate;
451  WriteToBuffer<Label>(transition_old_temp.front().label);
452  }
453  if (transition_old_temp.size() > 1) {
454  for (auto it = transition_old_temp.begin() + 1;
455  it != transition_old_temp.end(); ++it) {
456  WriteToBuffer<StateId>(it->nextstate - previous);
457  previous = it->nextstate;
458  WriteToBuffer<StateId>(it->label);
459  }
460  }
461  }
462  // Adds final states.
463  WriteToBuffer<StateId>(final_states.size());
464  if (!final_states.empty()) {
465  for (auto it = final_states.begin(); it != final_states.end(); ++it) {
466  WriteToBuffer<StateId>(*it);
467  }
468  }
469  WriteToStream(strm);
470  const uint8_t unweighted = fst.Properties(kUnweighted, true) == kUnweighted;
471  WriteType(strm, unweighted);
472  if (unweighted == 0) {
473  WriteWeight(arc_weight_, strm);
474  WriteWeight(final_weight_, strm);
475  }
476 }
477 
478 template <class Arc>
479 void Compressor<Arc>::DecodeProcessedFst(const std::vector<StateId> &input,
481  bool unweighted) {
484  std::vector<std::pair<StateId, LZLabel>> current_new_input;
485  std::vector<std::pair<StateId, Transition>> current_old_input;
486  std::vector<LZLabel> current_new_output;
487  std::vector<Transition> current_old_output;
488  std::vector<std::pair<StateId, Transition>> actual_old_dict_numbers;
489  std::vector<Transition> actual_old_dict_transitions;
490  auto arc_weight_it = arc_weight_.begin();
491  Transition default_transition;
492  StateId seen_states = 1;
493  // Adds states..
494  const StateId num_states = input.front();
495  if (num_states > 0) {
496  const StateId start_state = fst->AddState();
497  fst->SetStart(start_state);
498  for (StateId state = 1; state < num_states; ++state) {
499  fst->AddState();
500  }
501  }
502  auto main_it = input.cbegin();
503  ++main_it;
504  for (StateId current_state = 0; current_state < num_states; ++current_state) {
505  if (current_state >= seen_states) ++seen_states;
506  current_new_input.clear();
507  current_new_output.clear();
508  current_old_input.clear();
509  current_old_output.clear();
510  // New states.
511  StateId current_number_new_elements = *main_it;
512  ++main_it;
513  for (StateId new_integer = 0; new_integer < current_number_new_elements;
514  ++new_integer) {
515  std::pair<StateId, LZLabel> temp_new_dict_element;
516  temp_new_dict_element.first = *main_it;
517  ++main_it;
518  LZLabel temp_label;
519  temp_label.label = *main_it;
520  ++main_it;
521  temp_new_dict_element.second = temp_label;
522  current_new_input.push_back(temp_new_dict_element);
523  }
524  if (!dict_new.BatchDecode(current_new_input, &current_new_output)) {
525  FSTERROR() << "Compressor::Decode: failed";
526  fst->SetProperties(kError, kError);
527  return;
528  }
529  for (const auto &label : current_new_output) {
530  if (!unweighted) {
531  fst->AddArc(current_state,
532  Arc(label.label, label.label, *arc_weight_it, seen_states));
533  ++arc_weight_it;
534  } else {
535  fst->AddArc(current_state,
536  Arc(label.label, label.label, Weight::One(), seen_states));
537  }
538  ++seen_states;
539  }
540  StateId current_number_old_elements = *main_it;
541  ++main_it;
542  StateId is_zero_removed = *main_it;
543  ++main_it;
544  StateId previous = 0;
545  actual_old_dict_numbers.clear();
546  for (StateId new_integer = 0; new_integer < current_number_old_elements;
547  ++new_integer) {
548  std::pair<StateId, Transition> pair_temp_transition;
549  if (new_integer == 0) {
550  pair_temp_transition.first = *main_it;
551  previous = *main_it;
552  } else {
553  pair_temp_transition.first = *main_it + previous;
554  previous = pair_temp_transition.first;
555  }
556  ++main_it;
557  Transition temp_test;
558  if (!dict_old.SingleDecode(pair_temp_transition.first, &temp_test)) {
559  FSTERROR() << "Compressor::Decode: failed";
560  fst->SetProperties(kError, kError);
561  return;
562  }
563  pair_temp_transition.second = temp_test;
564  actual_old_dict_numbers.push_back(pair_temp_transition);
565  }
566  // Reorders the dictionary elements.
567  static const OldDictCompare old_dict_compare;
568  std::sort(actual_old_dict_numbers.begin(), actual_old_dict_numbers.end(),
569  old_dict_compare);
570  // Transitions.
571  previous = 0;
572  actual_old_dict_transitions.clear();
573  for (StateId new_integer = 0;
574  new_integer < current_number_old_elements - is_zero_removed;
575  ++new_integer) {
576  Transition temp_transition;
577  if (new_integer == 0) {
578  temp_transition.nextstate = *main_it;
579  previous = *main_it;
580  } else {
581  temp_transition.nextstate = *main_it + previous;
582  previous = temp_transition.nextstate;
583  }
584  ++main_it;
585  temp_transition.label = *main_it;
586  ++main_it;
587  actual_old_dict_transitions.push_back(temp_transition);
588  }
589  if (is_zero_removed == 1) {
590  actual_old_dict_transitions.push_back(default_transition);
591  }
592  auto trans_it = actual_old_dict_transitions.cbegin();
593  auto dict_it = actual_old_dict_numbers.cbegin();
594  while (trans_it != actual_old_dict_transitions.cend() &&
595  dict_it != actual_old_dict_numbers.cend()) {
596  if (dict_it->first == 0) {
597  ++dict_it;
598  } else {
599  std::pair<StateId, Transition> temp_pair;
600  static const TransitionEquals transition_equals;
601  static const TransitionLessThan transition_less_than;
602  if (transition_equals(*trans_it, default_transition)) {
603  temp_pair.first = dict_it->first;
604  temp_pair.second = default_transition;
605  ++dict_it;
606  } else if (transition_less_than(dict_it->second, *trans_it)) {
607  temp_pair.first = dict_it->first;
608  temp_pair.second = *trans_it;
609  ++dict_it;
610  } else {
611  temp_pair.first = 0;
612  temp_pair.second = *trans_it;
613  }
614  ++trans_it;
615  current_old_input.push_back(temp_pair);
616  }
617  }
618  while (trans_it != actual_old_dict_transitions.cend()) {
619  std::pair<StateId, Transition> temp_pair;
620  temp_pair.first = 0;
621  temp_pair.second = *trans_it;
622  ++trans_it;
623  current_old_input.push_back(temp_pair);
624  }
625  // Adds old elements in the dictionary.
626  if (!dict_old.BatchDecode(current_old_input, &current_old_output)) {
627  FSTERROR() << "Compressor::Decode: Failed";
628  fst->SetProperties(kError, kError);
629  return;
630  }
631  for (auto it = current_old_output.cbegin(); it != current_old_output.cend();
632  ++it) {
633  if (!unweighted) {
634  fst->AddArc(current_state,
635  Arc(it->label, it->label, *arc_weight_it, it->nextstate));
636  ++arc_weight_it;
637  } else {
638  fst->AddArc(current_state,
639  Arc(it->label, it->label, Weight::One(), it->nextstate));
640  }
641  }
642  }
643  // Adds the final states.
644  StateId number_of_final_states = *main_it;
645  if (number_of_final_states > 0) {
646  ++main_it;
647  for (StateId temp_int = 0; temp_int < number_of_final_states; ++temp_int) {
648  if (unweighted) {
649  fst->SetFinal(*main_it, Weight(0));
650  } else {
651  fst->SetFinal(*main_it, final_weight_[temp_int]);
652  }
653  ++main_it;
654  }
655  }
656 }
657 
658 template <class Arc>
659 void Compressor<Arc>::ReadWeight(std::istream &strm,
660  std::vector<Weight> *output) {
661  int64_t size;
662  Weight weight;
663  ReadType(strm, &size);
664  for (int64_t i = 0; i < size; ++i) {
665  weight.Read(strm);
666  output->push_back(weight);
667  }
668 }
669 
670 template <class Arc>
671 [[nodiscard]] bool Compressor<Arc>::Decompress(std::istream &strm,
672  std::string_view source,
673  MutableFst<Arc> *fst) {
674  fst->DeleteStates();
675  int32_t magic_number = 0;
676  ReadType(strm, &magic_number);
677  if (magic_number != kCompressMagicNumber) {
678  LOG(ERROR) << "Decompress: Bad compressed Fst: " << source;
679  return false;
680  }
681  std::unique_ptr<EncodeMapper<Arc>> encoder(
682  EncodeMapper<Arc>::Read(strm, "Decoding", DECODE));
683  if (encoder == nullptr) return false;
684  std::vector<bool> bool_code;
685  uint8_t block;
686  uint8_t msb = 128;
687  int64_t data_size;
688  ReadType(strm, &data_size);
689  for (int64_t i = 0; i < data_size; ++i) {
690  ReadType(strm, &block);
691  for (int j = 0; j < 8; ++j) {
692  uint8_t temp = msb & block;
693  bool_code.push_back(temp == 128);
694  block = block << 1;
695  }
696  }
697  std::vector<StateId> int_code;
698  Elias<StateId>::BatchDecode(bool_code, &int_code);
699  bool_code.clear();
700  uint8_t unweighted;
701  ReadType(strm, &unweighted);
702  if (unweighted == 0) {
703  ReadWeight(strm, &arc_weight_);
704  ReadWeight(strm, &final_weight_);
705  }
706  DecodeProcessedFst(int_code, fst, unweighted);
707  DecodeForCompress(fst, *encoder);
708  return !fst->Properties(kError, false);
709 }
710 
711 template <class Arc>
712 void Compressor<Arc>::WriteWeight(const std::vector<Weight> &input,
713  std::ostream &strm) {
714  int64_t size = input.size();
715  WriteType(strm, size);
716  for (auto it = input.begin(); it != input.end(); ++it) {
717  it->Write(strm);
718  }
719 }
720 
721 template <class Arc>
722 void Compressor<Arc>::WriteToStream(std::ostream &strm) {
723  while (buffer_code_.size() % 8 != 0) buffer_code_.push_back(true);
724  int64_t data_size = buffer_code_.size() / 8;
725  WriteType(strm, data_size);
726  int64_t i = 0;
727  uint8_t block;
728  for (auto it = buffer_code_.begin(); it != buffer_code_.end(); ++it) {
729  if (i % 8 == 0) {
730  if (i > 0) WriteType(strm, block);
731  block = 0;
732  } else {
733  block = block << 1;
734  }
735  block |= *it;
736  ++i;
737  }
738  WriteType(strm, block);
739 }
740 
741 template <class Arc>
742 [[nodiscard]] bool Compressor<Arc>::Compress(const Fst<Arc> &fst,
743  std::ostream &strm) {
744  VectorFst<Arc> processedfst;
746  Preprocess(fst, &processedfst, &encoder);
747  WriteType(strm, kCompressMagicNumber);
748  encoder.Write(strm, "ostream");
749  EncodeProcessedFst(processedfst, strm);
750  return true;
751 }
752 
753 // Convenience functions that call the compressor and decompressor.
754 
755 template <class Arc>
756 [[nodiscard]] bool Compress(const Fst<Arc> &fst, std::ostream &strm) {
757  Compressor<Arc> comp;
758  return comp.Compress(fst, strm);
759 }
760 
761 template <class Arc>
762 [[nodiscard]] bool Compress(const Fst<Arc> &fst, const std::string &source) {
763  std::ofstream fstrm;
764  if (!source.empty()) {
765  fstrm.open(source, std::ios_base::out | std::ios_base::binary);
766  if (!fstrm) {
767  LOG(ERROR) << "Compress: Can't open file: " << source;
768  return false;
769  }
770  }
771  std::ostream &ostrm = fstrm.is_open() ? fstrm : std::cout;
772  if (!Compress(fst, ostrm)) return false;
773  return !!ostrm;
774 }
775 
776 template <class Arc>
777 [[nodiscard]] bool Decompress(std::istream &strm, std::string_view source,
778  MutableFst<Arc> *fst) {
779  Compressor<Arc> comp;
780  return comp.Decompress(strm, source, fst);
781 }
782 
783 // Returns true on success.
784 template <class Arc>
785 [[nodiscard]] bool Decompress(std::string_view source, MutableFst<Arc> *fst) {
786  std::ifstream fstrm;
787  if (!source.empty()) {
788  fstrm.open(std::string(source), std::ios_base::in | std::ios_base::binary);
789  if (!fstrm) {
790  LOG(ERROR) << "Decompress: Can't open file: " << source;
791  return false;
792  }
793  }
794  std::istream &istrm = fstrm.is_open() ? fstrm : std::cin;
795  if (!Decompress(istrm, source.empty() ? "standard input" : source, fst)) {
796  return false;
797  }
798  return !!istrm;
799 }
800 
801 } // namespace fst
802 
803 #endif // FST_EXTENSIONS_COMPRESS_COMPRESS_H_
constexpr int32_t kCompressMagicNumber
Definition: compress.h:58
bool ExpandLZCode(const std::vector< std::pair< Var, Edge >> &code, std::vector< std::vector< Edge >> *expanded_code)
Definition: compress.h:65
void ArcMap(MutableFst< A > *fst, C *mapper)
Definition: arc-map.h:120
void DecodeProcessedFst(const std::vector< StateId > &input, MutableFst< Arc > *fst, bool unweighted)
Definition: compress.h:479
void BatchEncode(const std::vector< Edge > &input, std::vector< std::pair< Var, Edge >> *output)
Definition: compress.h:136
void DecodeForCompress(MutableFst< Arc > *fst, const EncodeMapper< Arc > &mapper)
Definition: compress.h:300
virtual uint64_t Properties(uint64_t mask, bool test) const =0
const SymbolTable * OutputSymbols() const
Definition: encode.h:427
void Encode(MutableFst< Arc > *fst, EncodeMapper< Arc > *mapper)
Definition: encode.h:500
bool Compress(const Fst< Arc > &fst, std::ostream &strm)
Definition: compress.h:742
void Visit(const FST &fst, Visitor *visitor, Queue *queue, ArcFilter filter, bool access_only=false)
Definition: visit.h:77
constexpr uint64_t kError
Definition: properties.h:52
virtual void SetInputSymbols(const SymbolTable *isyms)=0
#define LOG(type)
Definition: log.h:53
virtual Weight Final(StateId) const =0
virtual void SetStart(StateId)=0
bool BatchDecode(const std::vector< std::pair< Var, Edge >> &input, std::vector< Edge > *output)
Definition: compress.h:163
void ReadWeight(std::istream &strm, std::vector< Weight > *output)
Definition: compress.h:659
constexpr int kNoStateId
Definition: fst.h:196
void WriteWeight(const std::vector< Weight > &input, std::ostream &strm)
Definition: compress.h:712
void BfsOrder(const ExpandedFst< Arc > &fst, std::vector< StateId > *order)
Definition: compress.h:308
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:228
#define FSTERROR()
Definition: util.h:56
typename Arc::Label Label
Definition: compress.h:191
void EncodeProcessedFst(const ExpandedFst< Arc > &fst, std::ostream &strm)
Definition: compress.h:355
bool Compress(const Fst< Arc > &fst, std::ostream &strm)
Definition: compress.h:756
static void BatchDecode(const std::vector< bool > &input, std::vector< Var > *output)
Definition: elias.h:77
typename Arc::StateId StateId
Definition: compress.h:192
constexpr uint8_t kEncodeLabels
Definition: encode.h:55
virtual void SetProperties(uint64_t props, uint64_t mask)=0
static void DeltaEncode(const Var &input, std::vector< bool > *code)
Definition: elias.h:59
bool Decompress(std::istream &strm, std::string_view source, MutableFst< Arc > *fst)
Definition: compress.h:671
const SymbolTable * InputSymbols() const
Definition: encode.h:425
bool Decompress(std::istream &strm, std::string_view source, MutableFst< Arc > *fst)
Definition: compress.h:777
virtual void AddArc(StateId, const Arc &)=0
constexpr uint64_t kUnweighted
Definition: properties.h:106
void WriteToStream(std::ostream &strm)
Definition: compress.h:722
virtual StateId AddState()=0
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
virtual void DeleteStates(const std::vector< StateId > &)=0
bool Write(std::ostream &strm, std::string_view source) const
Definition: encode.h:411
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:80
void WriteToBuffer(CVar input)
Definition: compress.h:231
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
virtual StateId NumStates() const =0
void Preprocess(const Fst< Arc > &fst, MutableFst< Arc > *preprocessedfst, EncodeMapper< Arc > *encoder)
Definition: compress.h:340
bool SingleDecode(const Var &index, Edge *output)
Definition: compress.h:111
bool StateSort(std::vector< IntervalSet< Label >> *interval_sets, const std::vector< StateId > &order)
typename Arc::Weight Weight
Definition: compress.h:193