FST  openfst-1.7.2
OpenFst Library
compress.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Compresses and decompresses unweighted FSTs.
5 
6 #ifndef FST_EXTENSIONS_COMPRESS_COMPRESS_H_
7 #define FST_EXTENSIONS_COMPRESS_COMPRESS_H_
8 
9 #include <algorithm>
10 #include <cstdio>
11 #include <iostream>
12 #include <map>
13 #include <memory>
14 #include <queue>
15 #include <string>
16 #include <utility>
17 #include <vector>
18 
19 #include <fst/compat.h>
20 #include <fst/log.h>
23 #include <fst/encode.h>
24 #include <fst/fst.h>
25 #include <fst/mutable-fst.h>
26 #include <fst/statesort.h>
27 
28 namespace fst {
29 
30 // Identifies stream data as a vanilla compressed FST.
31 static const int32 kCompressMagicNumber = 1858869554;
32 // Identifies stream data as (probably) a Gzip file accidentally read from
33 // a vanilla stream, without gzip support.
34 static const int32 kGzipMagicNumber = 0x8b1f;
35 // Selects the two least significant bytes (the low 16 bits).
36 constexpr uint32 kGzipMask = 0xffffffff >> 16;
37 
38 namespace internal {
39 
40 // Expands a Lempel Ziv code and returns the set of code words. expanded_code[i]
41 // is the i^th Lempel Ziv codeword.
42 template <class Var, class Edge>
43 bool ExpandLZCode(const std::vector<std::pair<Var, Edge>> &code,
44  std::vector<std::vector<Edge>> *expanded_code) {
45  expanded_code->resize(code.size());
46  for (int i = 0; i < code.size(); ++i) {
47  if (code[i].first > i) {
48  LOG(ERROR) << "ExpandLZCode: Not a valid code";
49  return false;
50  }
51  if (code[i].first == 0) {
52  (*expanded_code)[i].resize(1, code[i].second);
53  } else {
54  (*expanded_code)[i].resize((*expanded_code)[code[i].first - 1].size() +
55  1);
56  std::copy((*expanded_code)[code[i].first - 1].begin(),
57  (*expanded_code)[code[i].first - 1].end(),
58  (*expanded_code)[i].begin());
59  (*expanded_code)[i][(*expanded_code)[code[i].first - 1].size()] =
60  code[i].second;
61  }
62  }
63  return true;
64 }
65 
66 } // namespace internal
67 
68 // Lempel Ziv on data structure Edge, with a less than operator
69 // EdgeLessThan and an equals operator EdgeEquals.
70 // Edge has a value defaultedge which it never takes and
71 // Edge is defined, it is initialized to defaultedge
72 template <class Var, class Edge, class EdgeLessThan, class EdgeEquals>
73 class LempelZiv {
74  public:
75  LempelZiv() : dict_number_(0), default_edge_() {
76  root_.current_number = dict_number_++;
77  root_.current_edge = default_edge_;
78  decode_vector_.push_back(std::make_pair(0, default_edge_));
79  }
80  // Encodes a vector input into output
81  void BatchEncode(const std::vector<Edge> &input,
82  std::vector<std::pair<Var, Edge>> *output);
83 
84  // Decodes codedvector to output. Returns false if
85  // the index exceeds the size.
86  bool BatchDecode(const std::vector<std::pair<Var, Edge>> &input,
87  std::vector<Edge> *output);
88 
89  // Decodes a single dictionary element. Returns false
90  // if the index exceeds the size.
91  bool SingleDecode(const Var &index, Edge *output) {
92  if (index >= decode_vector_.size()) {
93  LOG(ERROR) << "LempelZiv::SingleDecode: "
94  << "Index exceeded the dictionary size";
95  return false;
96  } else {
97  *output = decode_vector_[index].second;
98  return true;
99  }
100  }
101 
103  for (auto it = (root_.next_number).begin(); it != (root_.next_number).end();
104  ++it) {
105  CleanUp(it->second);
106  }
107  }
108  // Adds a single dictionary element while decoding
109  // void AddDictElement(const std::pair<Var, Edge> &newdict) {
110  // EdgeEquals InstEdgeEquals;
111  // if (InstEdgeEquals(newdict.second, default_edge_) != 1)
112  // decode_vector_.push_back(newdict);
113  // }
114 
115  private:
116  // Node datastructure is used for encoding
117 
118  struct Node {
119  Var current_number;
120  Edge current_edge;
121  std::map<Edge, Node *, EdgeLessThan> next_number;
122  };
123 
124  void CleanUp(Node *temp) {
125  for (auto it = (temp->next_number).begin(); it != (temp->next_number).end();
126  ++it) {
127  CleanUp(it->second);
128  }
129  delete temp;
130  }
131  Node root_;
132  Var dict_number_;
133  // decode_vector_ is used for decoding
134  std::vector<std::pair<Var, Edge>> decode_vector_;
135  Edge default_edge_;
136 };
137 
138 template <class Var, class Edge, class EdgeLessThan, class EdgeEquals>
140  const std::vector<Edge> &input, std::vector<std::pair<Var, Edge>> *output) {
141  for (typename std::vector<Edge>::const_iterator it = input.begin();
142  it != input.end(); ++it) {
143  Node *temp_node = &root_;
144  while (it != input.end()) {
145  auto next = (temp_node->next_number).find(*it);
146  if (next != (temp_node->next_number).end()) {
147  temp_node = next->second;
148  ++it;
149  } else {
150  break;
151  }
152  }
153  if (it == input.end() && temp_node->current_number != 0) {
154  output->push_back(
155  std::make_pair(temp_node->current_number, default_edge_));
156  } else if (it != input.end()) {
157  output->push_back(std::make_pair(temp_node->current_number, *it));
158  Node *new_node = new (Node);
159  new_node->current_number = dict_number_++;
160  new_node->current_edge = *it;
161  (temp_node->next_number)[*it] = new_node;
162  }
163  if (it == input.end()) break;
164  }
165 }
166 
167 template <class Var, class Edge, class EdgeLessThan, class EdgeEquals>
169  const std::vector<std::pair<Var, Edge>> &input, std::vector<Edge> *output) {
170  for (typename std::vector<std::pair<Var, Edge>>::const_iterator it =
171  input.begin();
172  it != input.end(); ++it) {
173  std::vector<Edge> temp_output;
174  EdgeEquals InstEdgeEquals;
175  if (InstEdgeEquals(it->second, default_edge_) != 1) {
176  decode_vector_.push_back(*it);
177  temp_output.push_back(it->second);
178  }
179  Var temp_integer = it->first;
180  if (temp_integer >= decode_vector_.size()) {
181  LOG(ERROR) << "LempelZiv::BatchDecode: "
182  << "Index exceeded the dictionary size";
183  return false;
184  } else {
185  while (temp_integer != 0) {
186  temp_output.push_back(decode_vector_[temp_integer].second);
187  temp_integer = decode_vector_[temp_integer].first;
188  }
189  std::reverse(temp_output.begin(), temp_output.end());
190  output->insert(output->end(), temp_output.begin(), temp_output.end());
191  }
192  }
193  return true;
194 }
195 
196 // The main Compressor class
197 template <class Arc>
198 class Compressor {
199  public:
200  typedef typename Arc::StateId StateId;
201  typedef typename Arc::Label Label;
202  typedef typename Arc::Weight Weight;
203 
205 
206  // Compresses fst into a boolean vector code. Returns true on sucesss.
207  bool Compress(const Fst<Arc> &fst, std::ostream &strm);
208 
209  // Decompresses the boolean vector into Fst. Returns true on sucesss.
210  bool Decompress(std::istream &strm, const string &source,
211  MutableFst<Arc> *fst);
212 
213  // Finds the BFS order of a fst
214  void BfsOrder(const ExpandedFst<Arc> &fst, std::vector<StateId> *order);
215 
216  // Preprocessing step to convert fst to a isomorphic fst
217  // Returns a preproccess fst and a dictionary
218  void Preprocess(const Fst<Arc> &fst, MutableFst<Arc> *preprocessedfst,
219  EncodeMapper<Arc> *encoder);
220 
221  // Performs Lempel Ziv and outputs a stream of integers
222  // and sends it to a stream
223  void EncodeProcessedFst(const ExpandedFst<Arc> &fst, std::ostream &strm);
224 
225  // Decodes fst from the stream
226  void DecodeProcessedFst(const std::vector<StateId> &input,
227  MutableFst<Arc> *fst, bool unweighted);
228 
229  // Converts buffer_code_ to uint8 and writes to a stream.
230 
231  // Writes the boolean file to the stream
232  void WriteToStream(std::ostream &strm);
233 
234  // Writes the weights to the stream
235  void WriteWeight(const std::vector<Weight> &input, std::ostream &strm);
236 
237  void ReadWeight(std::istream &strm, std::vector<Weight> *output);
238 
239  // Same as fst::Decode without the line RmFinalEpsilon(fst)
240  void DecodeForCompress(MutableFst<Arc> *fst, const EncodeMapper<Arc> &mapper);
241 
242  // Updates the buffer_code_
243  template <class CVar>
244  void WriteToBuffer(CVar input) {
245  std::vector<bool> current_code;
246  Elias<CVar>::DeltaEncode(input, &current_code);
247  if (!buffer_code_.empty()) {
248  buffer_code_.insert(buffer_code_.end(), current_code.begin(),
249  current_code.end());
250  } else {
251  buffer_code_.assign(current_code.begin(), current_code.end());
252  }
253  }
254 
255  private:
256  struct LZLabel {
257  LZLabel() : label(0) {}
258  Label label;
259  };
260 
261  struct LabelLessThan {
262  bool operator()(const LZLabel &labelone, const LZLabel &labeltwo) const {
263  return labelone.label < labeltwo.label;
264  }
265  };
266 
267  struct LabelEquals {
268  bool operator()(const LZLabel &labelone, const LZLabel &labeltwo) const {
269  return labelone.label == labeltwo.label;
270  }
271  };
272 
273  struct Transition {
274  Transition() : nextstate(0), label(0), weight(Weight::Zero()) {}
275 
276  StateId nextstate;
277  Label label;
278  Weight weight;
279  };
280 
281  struct TransitionLessThan {
282  bool operator()(const Transition &transition_one,
283  const Transition &transition_two) const {
284  if (transition_one.nextstate == transition_two.nextstate)
285  return transition_one.label < transition_two.label;
286  else
287  return transition_one.nextstate < transition_two.nextstate;
288  }
289  } transition_less_than;
290 
291  struct TransitionEquals {
292  bool operator()(const Transition &transition_one,
293  const Transition &transition_two) const {
294  return transition_one.nextstate == transition_two.nextstate &&
295  transition_one.label == transition_two.label;
296  }
297  } transition_equals;
298 
299  struct OldDictCompare {
300  bool operator()(const std::pair<StateId, Transition> &pair_one,
301  const std::pair<StateId, Transition> &pair_two) const {
302  if ((pair_one.second).nextstate == (pair_two.second).nextstate)
303  return (pair_one.second).label < (pair_two.second).label;
304  else
305  return (pair_one.second).nextstate < (pair_two.second).nextstate;
306  }
307  } old_dict_compare;
308 
309  std::vector<bool> buffer_code_;
310  std::vector<Weight> arc_weight_;
311  std::vector<Weight> final_weight_;
312 };
313 
314 template <class Arc>
316  MutableFst<Arc> *fst, const EncodeMapper<Arc> &mapper) {
317  ArcMap(fst, EncodeMapper<Arc>(mapper, DECODE));
318  fst->SetInputSymbols(mapper.InputSymbols());
319  fst->SetOutputSymbols(mapper.OutputSymbols());
320 }
321 
322 // Compressor::BfsOrder
323 template <class Arc>
325  std::vector<StateId> *order) {
326  Arc arc;
327  StateId bfs_visit_number = 0;
328  std::queue<StateId> states_queue;
329  order->assign(fst.NumStates(), kNoStateId);
330  states_queue.push(fst.Start());
331  (*order)[fst.Start()] = bfs_visit_number++;
332  while (!states_queue.empty()) {
333  for (ArcIterator<Fst<Arc>> aiter(fst, states_queue.front()); !aiter.Done();
334  aiter.Next()) {
335  arc = aiter.Value();
336  StateId nextstate = arc.nextstate;
337  if ((*order)[nextstate] == kNoStateId) {
338  (*order)[nextstate] = bfs_visit_number++;
339  states_queue.push(nextstate);
340  }
341  }
342  states_queue.pop();
343  }
344 
345  // If the FST is unconnected, then the following
346  // code finds them
347  while (bfs_visit_number < fst.NumStates()) {
348  int unseen_state = 0;
349  for (unseen_state = 0; unseen_state < fst.NumStates(); ++unseen_state) {
350  if ((*order)[unseen_state] == kNoStateId) break;
351  }
352  states_queue.push(unseen_state);
353  (*order)[unseen_state] = bfs_visit_number++;
354  while (!states_queue.empty()) {
355  for (ArcIterator<Fst<Arc>> aiter(fst, states_queue.front());
356  !aiter.Done(); aiter.Next()) {
357  arc = aiter.Value();
358  StateId nextstate = arc.nextstate;
359  if ((*order)[nextstate] == kNoStateId) {
360  (*order)[nextstate] = bfs_visit_number++;
361  states_queue.push(nextstate);
362  }
363  }
364  states_queue.pop();
365  }
366  }
367 }
368 
369 template <class Arc>
371  MutableFst<Arc> *preprocessedfst,
372  EncodeMapper<Arc> *encoder) {
373  *preprocessedfst = fst;
374  if (!preprocessedfst->NumStates()) {
375  return;
376  }
377  // Relabels the edges and develops a dictionary
378  Encode(preprocessedfst, encoder);
379  std::vector<StateId> order;
380  // Finds the BFS sorting order of the fst
381  BfsOrder(*preprocessedfst, &order);
382  // Reorders the states according to the BFS order
383  StateSort(preprocessedfst, order);
384 }
385 
386 template <class Arc>
388  std::ostream &strm) {
389  std::vector<StateId> output;
392  std::vector<LZLabel> current_new_input;
393  std::vector<Transition> current_old_input;
394  std::vector<std::pair<StateId, LZLabel>> current_new_output;
395  std::vector<std::pair<StateId, Transition>> current_old_output;
396  std::vector<StateId> final_states;
397 
398  StateId number_of_states = fst.NumStates();
399 
400  StateId seen_states = 0;
401  // Adding the number of states
402  WriteToBuffer<StateId>(number_of_states);
403 
404  for (StateId state = 0; state < number_of_states; ++state) {
405  current_new_input.clear();
406  current_old_input.clear();
407  current_new_output.clear();
408  current_old_output.clear();
409  if (state > seen_states) ++seen_states;
410 
411  // Collecting the final states
412  if (fst.Final(state) != Weight::Zero()) {
413  final_states.push_back(state);
414  final_weight_.push_back(fst.Final(state));
415  }
416 
417  // Reading the states
418  for (ArcIterator<Fst<Arc>> aiter(fst, state); !aiter.Done(); aiter.Next()) {
419  Arc arc = aiter.Value();
420  if (arc.nextstate > seen_states) { // RILEY: > or >= ?
421  ++seen_states;
422  LZLabel temp_label;
423  temp_label.label = arc.ilabel;
424  arc_weight_.push_back(arc.weight);
425  current_new_input.push_back(temp_label);
426  } else {
427  Transition temp_transition;
428  temp_transition.nextstate = arc.nextstate;
429  temp_transition.label = arc.ilabel;
430  temp_transition.weight = arc.weight;
431  current_old_input.push_back(temp_transition);
432  }
433  }
434  // Adding new states
435  dict_new.BatchEncode(current_new_input, &current_new_output);
436  WriteToBuffer<StateId>(current_new_output.size());
437 
438  for (auto it = current_new_output.begin(); it != current_new_output.end();
439  ++it) {
440  WriteToBuffer<StateId>(it->first);
441  WriteToBuffer<Label>((it->second).label);
442  }
443  // Adding old states by sorting and using difference coding
444  std::sort(current_old_input.begin(), current_old_input.end(),
445  transition_less_than);
446  for (auto it = current_old_input.begin(); it != current_old_input.end();
447  ++it) {
448  arc_weight_.push_back(it->weight);
449  }
450  dict_old.BatchEncode(current_old_input, &current_old_output);
451  std::vector<StateId> dict_old_temp;
452  std::vector<Transition> transition_old_temp;
453  for (auto it = current_old_output.begin(); it != current_old_output.end();
454  ++it) {
455  dict_old_temp.push_back(it->first);
456  transition_old_temp.push_back(it->second);
457  }
458  if (!transition_old_temp.empty()) {
459  if ((transition_old_temp.back()).nextstate == 0 &&
460  (transition_old_temp.back()).label == 0) {
461  transition_old_temp.pop_back();
462  }
463  }
464  std::sort(dict_old_temp.begin(), dict_old_temp.end());
465  std::sort(transition_old_temp.begin(), transition_old_temp.end(),
466  transition_less_than);
467 
468  WriteToBuffer<StateId>(dict_old_temp.size());
469  if (dict_old_temp.size() != transition_old_temp.size())
470  WriteToBuffer<int>(1);
471  else
472  WriteToBuffer<int>(0);
473 
474  StateId previous;
475  if (!dict_old_temp.empty()) {
476  WriteToBuffer<StateId>(dict_old_temp.front());
477  previous = dict_old_temp.front();
478  }
479  if (dict_old_temp.size() > 1) {
480  for (auto it = dict_old_temp.begin() + 1; it != dict_old_temp.end();
481  ++it) {
482  WriteToBuffer<StateId>(*it - previous);
483  previous = *it;
484  }
485  }
486  if (!transition_old_temp.empty()) {
487  WriteToBuffer<StateId>((transition_old_temp.front()).nextstate);
488  previous = (transition_old_temp.front()).nextstate;
489  WriteToBuffer<Label>((transition_old_temp.front()).label);
490  }
491  if (transition_old_temp.size() > 1) {
492  for (auto it = transition_old_temp.begin() + 1;
493  it != transition_old_temp.end(); ++it) {
494  WriteToBuffer<StateId>(it->nextstate - previous);
495  previous = it->nextstate;
496  WriteToBuffer<StateId>(it->label);
497  }
498  }
499  }
500  // Adding final states
501  WriteToBuffer<StateId>(final_states.size());
502  if (!final_states.empty()) {
503  for (auto it = final_states.begin(); it != final_states.end(); ++it) {
504  WriteToBuffer<StateId>(*it);
505  }
506  }
507  WriteToStream(strm);
508  uint8 unweighted = (fst.Properties(kUnweighted, true) == kUnweighted);
509  WriteType(strm, unweighted);
510  if (unweighted == 0) {
511  WriteWeight(arc_weight_, strm);
512  WriteWeight(final_weight_, strm);
513  }
514 }
515 
516 template <class Arc>
517 void Compressor<Arc>::DecodeProcessedFst(const std::vector<StateId> &input,
519  bool unweighted) {
522  std::vector<std::pair<StateId, LZLabel>> current_new_input;
523  std::vector<std::pair<StateId, Transition>> current_old_input;
524  std::vector<LZLabel> current_new_output;
525  std::vector<Transition> current_old_output;
526  std::vector<std::pair<StateId, Transition>> actual_old_dict_numbers;
527  std::vector<Transition> actual_old_dict_transitions;
528  auto arc_weight_it = arc_weight_.begin();
529  Transition default_transition;
530  StateId seen_states = 1;
531 
532  // Adding states.
533  const StateId num_states = input.front();
534  if (num_states > 0) {
535  const StateId start_state = fst->AddState();
536  fst->SetStart(start_state);
537  for (StateId state = 1; state < num_states; ++state) {
538  fst->AddState();
539  }
540  }
541 
542  typename std::vector<StateId>::const_iterator main_it = input.begin();
543  ++main_it;
544 
545  for (StateId current_state = 0; current_state < num_states; ++current_state) {
546  if (current_state >= seen_states) ++seen_states;
547  current_new_input.clear();
548  current_new_output.clear();
549  current_old_input.clear();
550  current_old_output.clear();
551  // New states
552  StateId current_number_new_elements = *main_it;
553  ++main_it;
554  for (StateId new_integer = 0; new_integer < current_number_new_elements;
555  ++new_integer) {
556  std::pair<StateId, LZLabel> temp_new_dict_element;
557  temp_new_dict_element.first = *main_it;
558  ++main_it;
559  LZLabel temp_label;
560  temp_label.label = *main_it;
561  ++main_it;
562  temp_new_dict_element.second = temp_label;
563  current_new_input.push_back(temp_new_dict_element);
564  }
565  dict_new.BatchDecode(current_new_input, &current_new_output);
566  for (auto it = current_new_output.begin(); it != current_new_output.end();
567  ++it) {
568  if (!unweighted) {
569  fst->AddArc(current_state,
570  Arc(it->label, it->label, *arc_weight_it, seen_states++));
571  ++arc_weight_it;
572  } else {
573  fst->AddArc(current_state,
574  Arc(it->label, it->label, Weight::One(), seen_states++));
575  }
576  }
577 
578  // Old states dictionary
579  StateId current_number_old_elements = *main_it;
580  ++main_it;
581  StateId is_zero_removed = *main_it;
582  ++main_it;
583  StateId previous = 0;
584  actual_old_dict_numbers.clear();
585  for (StateId new_integer = 0; new_integer < current_number_old_elements;
586  ++new_integer) {
587  std::pair<StateId, Transition> pair_temp_transition;
588  if (new_integer == 0) {
589  pair_temp_transition.first = *main_it;
590  previous = *main_it;
591  } else {
592  pair_temp_transition.first = *main_it + previous;
593  previous = pair_temp_transition.first;
594  }
595  ++main_it;
596  Transition temp_test;
597  if (!dict_old.SingleDecode(pair_temp_transition.first, &temp_test)) {
598  FSTERROR() << "Compressor::Decode: failed";
599  fst->DeleteStates();
600  fst->SetProperties(kError, kError);
601  return;
602  }
603  pair_temp_transition.second = temp_test;
604  actual_old_dict_numbers.push_back(pair_temp_transition);
605  }
606 
607  // Reordering the dictionary elements
608  std::sort(actual_old_dict_numbers.begin(), actual_old_dict_numbers.end(),
609  old_dict_compare);
610 
611  // Transitions
612  previous = 0;
613  actual_old_dict_transitions.clear();
614 
615  for (StateId new_integer = 0;
616  new_integer < current_number_old_elements - is_zero_removed;
617  ++new_integer) {
618  Transition temp_transition;
619  if (new_integer == 0) {
620  temp_transition.nextstate = *main_it;
621  previous = *main_it;
622  } else {
623  temp_transition.nextstate = *main_it + previous;
624  previous = temp_transition.nextstate;
625  }
626  ++main_it;
627  temp_transition.label = *main_it;
628  ++main_it;
629  actual_old_dict_transitions.push_back(temp_transition);
630  }
631 
632  if (is_zero_removed == 1) {
633  actual_old_dict_transitions.push_back(default_transition);
634  }
635 
636  auto trans_it = actual_old_dict_transitions.begin();
637  auto dict_it = actual_old_dict_numbers.begin();
638 
639  while (trans_it != actual_old_dict_transitions.end() &&
640  dict_it != actual_old_dict_numbers.end()) {
641  if (dict_it->first == 0) {
642  ++dict_it;
643  } else {
644  std::pair<StateId, Transition> temp_pair;
645  if (transition_equals(*trans_it, default_transition) == 1) {
646  temp_pair.first = dict_it->first;
647  temp_pair.second = default_transition;
648  ++dict_it;
649  } else if (transition_less_than(dict_it->second, *trans_it) == 1) {
650  temp_pair.first = dict_it->first;
651  temp_pair.second = *trans_it;
652  ++dict_it;
653  } else {
654  temp_pair.first = 0;
655  temp_pair.second = *trans_it;
656  }
657  ++trans_it;
658  current_old_input.push_back(temp_pair);
659  }
660  }
661  while (trans_it != actual_old_dict_transitions.end()) {
662  std::pair<StateId, Transition> temp_pair;
663  temp_pair.first = 0;
664  temp_pair.second = *trans_it;
665  ++trans_it;
666  current_old_input.push_back(temp_pair);
667  }
668 
669  // Adding old elements in the dictionary
670  if (!dict_old.BatchDecode(current_old_input, &current_old_output)) {
671  FSTERROR() << "Compressor::Decode: Failed";
672  fst->DeleteStates();
673  fst->SetProperties(kError, kError);
674  return;
675  }
676 
677  for (auto it = current_old_output.begin(); it != current_old_output.end();
678  ++it) {
679  if (!unweighted) {
680  fst->AddArc(current_state,
681  Arc(it->label, it->label, *arc_weight_it, it->nextstate));
682  ++arc_weight_it;
683  } else {
684  fst->AddArc(current_state,
685  Arc(it->label, it->label, Weight::One(), it->nextstate));
686  }
687  }
688  }
689  // Adding the final states
690  StateId number_of_final_states = *main_it;
691  if (number_of_final_states > 0) {
692  ++main_it;
693  for (StateId temp_int = 0; temp_int < number_of_final_states; ++temp_int) {
694  if (!unweighted) {
695  fst->SetFinal(*main_it, final_weight_[temp_int]);
696  } else {
697  fst->SetFinal(*main_it, Weight(0));
698  }
699  ++main_it;
700  }
701  }
702 }
703 
704 template <class Arc>
705 void Compressor<Arc>::ReadWeight(std::istream &strm,
706  std::vector<Weight> *output) {
707  int64 size;
708  Weight weight;
709  ReadType(strm, &size);
710  for (int64 i = 0; i < size; ++i) {
711  weight.Read(strm);
712  output->push_back(weight);
713  }
714 }
715 
716 template <class Arc>
717 bool Compressor<Arc>::Decompress(std::istream &strm, const string &source,
718  MutableFst<Arc> *fst) {
719  fst->DeleteStates();
720  int32 magic_number = 0;
721  ReadType(strm, &magic_number);
722  if (magic_number != kCompressMagicNumber) {
723  LOG(ERROR) << "Decompress: Bad compressed Fst: " << source;
724  // If the most significant two bytes of the magic number match the
725  // gzip magic number, then we are probably reading a gzip file as an
726  // ordinary stream.
727  if ((magic_number & kGzipMask) == kGzipMagicNumber) {
728  LOG(ERROR) << "Decompress: Fst appears to be compressed with Gzip, but "
729  "gzip decompression was not requested. Try with "
730  "the --gzip flag"
731  ".";
732  }
733  return false;
734  }
735  std::unique_ptr<EncodeMapper<Arc>> encoder(
736  EncodeMapper<Arc>::Read(strm, "Decoding", DECODE));
737  std::vector<bool> bool_code;
738  uint8 block;
739  uint8 msb = 128;
740  int64 data_size;
741  ReadType(strm, &data_size);
742  for (int64 i = 0; i < data_size; ++i) {
743  ReadType(strm, &block);
744  for (int j = 0; j < 8; ++j) {
745  uint8 temp = msb & block;
746  if (temp == 128)
747  bool_code.push_back(1);
748  else
749  bool_code.push_back(0);
750  block = block << 1;
751  }
752  }
753  std::vector<StateId> int_code;
754  Elias<StateId>::BatchDecode(bool_code, &int_code);
755  bool_code.clear();
756  uint8 unweighted;
757  ReadType(strm, &unweighted);
758  if (unweighted == 0) {
759  ReadWeight(strm, &arc_weight_);
760  ReadWeight(strm, &final_weight_);
761  }
762  DecodeProcessedFst(int_code, fst, unweighted);
763  DecodeForCompress(fst, *encoder);
764  return !fst->Properties(kError, false);
765 }
766 
767 template <class Arc>
768 void Compressor<Arc>::WriteWeight(const std::vector<Weight> &input,
769  std::ostream &strm) {
770  int64 size = input.size();
771  WriteType(strm, size);
772  for (typename std::vector<Weight>::const_iterator it = input.begin();
773  it != input.end(); ++it) {
774  it->Write(strm);
775  }
776 }
777 
778 template <class Arc>
779 void Compressor<Arc>::WriteToStream(std::ostream &strm) {
780  while (buffer_code_.size() % 8 != 0) buffer_code_.push_back(1);
781  int64 data_size = buffer_code_.size() / 8;
782  WriteType(strm, data_size);
783  std::vector<bool>::const_iterator it;
784  int64 i;
785  uint8 block;
786  for (it = buffer_code_.begin(), i = 0; it != buffer_code_.end(); ++it, ++i) {
787  if (i % 8 == 0) {
788  if (i > 0) WriteType(strm, block);
789  block = 0;
790  } else {
791  block = block << 1;
792  }
793  block |= *it;
794  }
795  WriteType(strm, block);
796 }
797 
798 template <class Arc>
799 bool Compressor<Arc>::Compress(const Fst<Arc> &fst, std::ostream &strm) {
800  VectorFst<Arc> processedfst;
801  EncodeMapper<Arc> encoder(kEncodeLabels, ENCODE);
802  Preprocess(fst, &processedfst, &encoder);
803  WriteType(strm, kCompressMagicNumber);
804  encoder.Write(strm, "encoder stream");
805  EncodeProcessedFst(processedfst, strm);
806  return true;
807 }
808 
809 // Convenience functions that call the compressor and decompressor.
810 
811 template <class Arc>
812 void Compress(const Fst<Arc> &fst, std::ostream &strm) {
813  Compressor<Arc> comp;
814  comp.Compress(fst, strm);
815 }
816 
817 // Returns true on success.
818 template <class Arc>
819 bool Compress(const Fst<Arc> &fst, const string &file_name,
820  const bool gzip = false) {
821  if (gzip) {
822  if (file_name.empty()) {
823  std::stringstream strm;
824  Compress(fst, strm);
825  OGzFile gzfile(fileno(stdout));
826  gzfile.write(strm);
827  if (!gzfile) {
828  LOG(ERROR) << "Compress: Can't write to file: stdout";
829  return false;
830  }
831  } else {
832  std::stringstream strm;
833  Compress(fst, strm);
834  OGzFile gzfile(file_name);
835  if (!gzfile) {
836  LOG(ERROR) << "Compress: Can't open file: " << file_name;
837  return false;
838  }
839  gzfile.write(strm);
840  if (!gzfile) {
841  LOG(ERROR) << "Compress: Can't write to file: " << file_name;
842  return false;
843  }
844  }
845  } else if (file_name.empty()) {
846  Compress(fst, std::cout);
847  } else {
848  std::ofstream strm(file_name,
849  std::ios_base::out | std::ios_base::binary);
850  if (!strm) {
851  LOG(ERROR) << "Compress: Can't open file: " << file_name;
852  return false;
853  }
854  Compress(fst, strm);
855  }
856  return true;
857 }
858 
859 template <class Arc>
860 void Decompress(std::istream &strm, const string &source,
861  MutableFst<Arc> *fst) {
862  Compressor<Arc> comp;
863  comp.Decompress(strm, source, fst);
864 }
865 
866 // Returns true on success.
867 template <class Arc>
868 bool Decompress(const string &file_name, MutableFst<Arc> *fst,
869  const bool gzip = false) {
870  if (gzip) {
871  if (file_name.empty()) {
872  IGzFile gzfile(fileno(stdin));
873  Decompress(*gzfile.read(), "stdin", fst);
874  if (!gzfile) {
875  LOG(ERROR) << "Decompress: Can't read from file: stdin";
876  return false;
877  }
878  } else {
879  IGzFile gzfile(file_name);
880  if (!gzfile) {
881  LOG(ERROR) << "Decompress: Can't open file: " << file_name;
882  return false;
883  }
884  Decompress(*gzfile.read(), file_name, fst);
885  if (!gzfile) {
886  LOG(ERROR) << "Decompress: Can't read from file: " << file_name;
887  return false;
888  }
889  }
890  } else if (file_name.empty()) {
891  Decompress(std::cin, "stdin", fst);
892  } else {
893  std::ifstream strm(file_name,
894  std::ios_base::in | std::ios_base::binary);
895  if (!strm) {
896  LOG(ERROR) << "Decompress: Can't open file: " << file_name;
897  return false;
898  }
899  Decompress(strm, file_name, fst);
900  }
901  return true;
902 }
903 
904 } // namespace fst
905 
906 #endif // FST_EXTENSIONS_COMPRESS_COMPRESS_H_
bool ExpandLZCode(const std::vector< std::pair< Var, Edge >> &code, std::vector< std::vector< Edge >> *expanded_code)
Definition: compress.h:43
void ArcMap(MutableFst< A > *fst, C *mapper)
Definition: arc-map.h:94
void DecodeProcessedFst(const std::vector< StateId > &input, MutableFst< Arc > *fst, bool unweighted)
Definition: compress.h:517
void BatchEncode(const std::vector< Edge > &input, std::vector< std::pair< Var, Edge >> *output)
Definition: compress.h:139
void DecodeForCompress(MutableFst< Arc > *fst, const EncodeMapper< Arc > &mapper)
Definition: compress.h:315
const SymbolTable * OutputSymbols() const
Definition: encode.h:352
bool Write(std::ostream &strm, const string &source) const
Definition: encode.h:319
void Encode(MutableFst< Arc > *fst, EncodeMapper< Arc > *mapper)
Definition: encode.h:419
virtual void AddArc(StateId, const Arc &arc)=0
void Decompress(std::istream &strm, const string &source, MutableFst< Arc > *fst)
Definition: compress.h:860
bool Compress(const Fst< Arc > &fst, std::ostream &strm)
Definition: compress.h:799
void StateSort(MutableFst< Arc > *fst, const std::vector< typename Arc::StateId > &order)
Definition: statesort.h:23
virtual void SetInputSymbols(const SymbolTable *isyms)=0
#define LOG(type)
Definition: log.h:48
virtual Weight Final(StateId) const =0
virtual void SetStart(StateId)=0
bool BatchDecode(const std::vector< std::pair< Var, Edge >> &input, std::vector< Edge > *output)
Definition: compress.h:168
void ReadWeight(std::istream &strm, std::vector< Weight > *output)
Definition: compress.h:705
constexpr int kNoStateId
Definition: fst.h:180
void WriteWeight(const std::vector< Weight > &input, std::ostream &strm)
Definition: compress.h:768
virtual uint64 Properties(uint64 mask, bool test) const =0
void BfsOrder(const ExpandedFst< Arc > &fst, std::vector< StateId > *order)
Definition: compress.h:324
int64_t int64
Definition: types.h:27
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:155
#define FSTERROR()
Definition: util.h:35
constexpr uint64 kUnweighted
Definition: properties.h:87
void EncodeProcessedFst(const ExpandedFst< Arc > &fst, std::ostream &strm)
Definition: compress.h:387
virtual void SetFinal(StateId, Weight)=0
static void BatchDecode(const std::vector< bool > &input, std::vector< Var > *output)
Definition: elias.h:63
uint8_t uint8
Definition: types.h:29
static void DeltaEncode(const Var &input, std::vector< bool > *code)
Definition: elias.h:45
virtual StateId Start() const =0
void write(const std::stringstream &ssbuf)
Definition: gzfile.h:82
const SymbolTable * InputSymbols() const
Definition: encode.h:350
uint32_t uint32
Definition: types.h:31
constexpr uint32 kGzipMask
Definition: compress.h:36
std::unique_ptr< std::stringstream > read()
Definition: gzfile.h:105
int32_t int32
Definition: types.h:26
Arc::StateId StateId
Definition: compress.h:200
void WriteToStream(std::ostream &strm)
Definition: compress.h:779
constexpr uint64 kError
Definition: properties.h:33
bool Decompress(std::istream &strm, const string &source, MutableFst< Arc > *fst)
Definition: compress.h:717
Arc::Weight Weight
Definition: compress.h:202
Arc::Label Label
Definition: compress.h:201
virtual StateId AddState()=0
virtual void DeleteStates(const std::vector< StateId > &)=0
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:47
void WriteToBuffer(CVar input)
Definition: compress.h:244
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
void Compress(const Fst< Arc > &fst, std::ostream &strm)
Definition: compress.h:812
virtual StateId NumStates() const =0
void Preprocess(const Fst< Arc > &fst, MutableFst< Arc > *preprocessedfst, EncodeMapper< Arc > *encoder)
Definition: compress.h:370
bool SingleDecode(const Var &index, Edge *output)
Definition: compress.h:91
virtual void SetProperties(uint64 props, uint64 mask)=0