FST  openfst-1.8.2.post1
OpenFst Library
compress.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Compresses and decompresses unweighted FSTs.
19 
20 #ifndef FST_EXTENSIONS_COMPRESS_COMPRESS_H_
21 #define FST_EXTENSIONS_COMPRESS_COMPRESS_H_
22 
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <ios>
#include <iostream>
#include <istream>
#include <map>
#include <memory>
#include <ostream>
#include <queue>
#include <string>
#include <vector>

#include <fst/compat.h>
#include <fst/log.h>
#include <fst/extensions/compress/elias.h>
#include <fst/encode.h>
#include <fstream>
#include <fst/fst.h>
#include <fst/mutable-fst.h>
#include <fst/queue.h>
#include <fst/statesort.h>
#include <fst/visit.h>
46 
47 namespace fst {
48 
// Magic number written at the start of every stream produced by
// Compressor::Compress; Decompress rejects streams that do not begin with it.
inline constexpr int32_t kCompressMagicNumber{1858869554};
51 
52 namespace internal {
53 
54 // Expands a Lempel Ziv code and returns the set of code words where
55 // expanded_code[i] is the i^th Lempel Ziv codeword.
56 template <class Var, class Edge>
57 bool ExpandLZCode(const std::vector<std::pair<Var, Edge>> &code,
58  std::vector<std::vector<Edge>> *expanded_code) {
59  expanded_code->resize(code.size());
60  for (int i = 0; i < code.size(); ++i) {
61  if (code[i].first > i) {
62  LOG(ERROR) << "ExpandLZCode: Not a valid code";
63  return false;
64  }
65  auto &codeword = (*expanded_code)[i];
66  if (code[i].first == 0) {
67  codeword.resize(1, code[i].second);
68  } else {
69  const auto &other_codeword = (*expanded_code)[code[i].first - 1];
70  codeword.resize(other_codeword.size() + 1);
71  std::copy(other_codeword.cbegin(), other_codeword.cend(),
72  codeword.begin());
73  codeword[other_codeword.size()] = code[i].second;
74  }
75  }
76  return true;
77 }
78 
79 } // namespace internal
80 
81 // Lempel Ziv on data structure Edge, with a less-than operator EdgeLessThan and
82 // an equals operator EdgeEquals.
83 template <class Var, class Edge, class EdgeLessThan, class EdgeEquals>
84 class LempelZiv {
85  public:
86  LempelZiv() : dict_number_(0), default_edge_() {
87  root_.current_number = dict_number_++;
88  root_.current_edge = default_edge_;
89  decode_vector_.emplace_back(0, default_edge_);
90  }
91 
92  // Encodes a vector input into output.
93  void BatchEncode(const std::vector<Edge> &input,
94  std::vector<std::pair<Var, Edge>> *output);
95 
96  // Decodes codedvector to output, returning false if the index exceeds the
97  // size.
98  bool BatchDecode(const std::vector<std::pair<Var, Edge>> &input,
99  std::vector<Edge> *output);
100 
101  // Decodes a single dictionary element, returning false if the index exceeds
102  // the size.
103  bool SingleDecode(const Var &index, Edge *output) {
104  if (index >= decode_vector_.size()) {
105  LOG(ERROR) << "LempelZiv::SingleDecode: "
106  << "Index exceeded the dictionary size";
107  return false;
108  } else {
109  *output = decode_vector_[index].second;
110  return true;
111  }
112  }
113 
114  private:
115  struct Node {
116  Var current_number;
117  Edge current_edge;
118  std::map<Edge, std::unique_ptr<Node>, EdgeLessThan> next_number;
119  };
120 
121  Node root_;
122  Var dict_number_;
123  std::vector<std::pair<Var, Edge>> decode_vector_;
124  Edge default_edge_;
125 };
126 
127 template <class Var, class Edge, class EdgeLessThan, class EdgeEquals>
129  const std::vector<Edge> &input, std::vector<std::pair<Var, Edge>> *output) {
130  for (auto it = input.cbegin(); it != input.cend(); ++it) {
131  auto *temp_node = &root_;
132  while (it != input.cend()) {
133  auto next = temp_node->next_number.find(*it);
134  if (next != temp_node->next_number.cend()) {
135  temp_node = next->second.get();
136  ++it;
137  } else {
138  break;
139  }
140  }
141  if (it == input.cend() && temp_node->current_number != 0) {
142  output->emplace_back(temp_node->current_number, default_edge_);
143  } else if (it != input.cend()) {
144  output->emplace_back(temp_node->current_number, *it);
145  auto new_node = std::make_unique<Node>();
146  new_node->current_number = dict_number_++;
147  new_node->current_edge = *it;
148  temp_node->next_number[*it] = std::move(new_node);
149  }
150  if (it == input.cend()) break;
151  }
152 }
153 
154 template <class Var, class Edge, class EdgeLessThan, class EdgeEquals>
156  const std::vector<std::pair<Var, Edge>> &input, std::vector<Edge> *output) {
157  for (const auto &[var, edge] : input) {
158  std::vector<Edge> temp_output;
159  EdgeEquals InstEdgeEquals;
160  if (InstEdgeEquals(edge, default_edge_) != 1) {
161  decode_vector_.emplace_back(var, edge);
162  temp_output.push_back(edge);
163  }
164  auto temp_integer = var;
165  if (temp_integer >= decode_vector_.size()) {
166  LOG(ERROR) << "LempelZiv::BatchDecode: "
167  << "Index exceeded the dictionary size";
168  return false;
169  } else {
170  while (temp_integer != 0) {
171  temp_output.push_back(decode_vector_[temp_integer].second);
172  temp_integer = decode_vector_[temp_integer].first;
173  }
174  output->insert(output->cend(), temp_output.rbegin(), temp_output.rend());
175  }
176  }
177  return true;
178 }
179 
180 template <class Arc>
181 class Compressor {
182  public:
183  using Label = typename Arc::Label;
184  using StateId = typename Arc::StateId;
185  using Weight = typename Arc::Weight;
186 
187  Compressor() = default;
188 
189  // Compresses an FST into a boolean vector code, returning true on success.
190  bool Compress(const Fst<Arc> &fst, std::ostream &strm);
191 
192  // Decompresses the boolean vector into an FST, returning true on success.
193  bool Decompress(std::istream &strm, const std::string &source,
195 
196  // Computes the BFS order of a FST.
197  void BfsOrder(const ExpandedFst<Arc> &fst, std::vector<StateId> *order);
198 
199  // Preprocessing step to convert an FST to a isomorphic FST.
200  void Preprocess(const Fst<Arc> &fst, MutableFst<Arc> *preprocessedfst,
201  EncodeMapper<Arc> *encoder);
202 
203  // Performs Lempel Ziv and outputs a stream of integers.
204  void EncodeProcessedFst(const ExpandedFst<Arc> &fst, std::ostream &strm);
205 
206  // Decodes FST from the stream.
207  void DecodeProcessedFst(const std::vector<StateId> &input,
208  MutableFst<Arc> *fst, bool unweighted);
209 
210  // Writes the boolean file to the stream.
211  void WriteToStream(std::ostream &strm);
212 
213  // Writes the weights to the stream.
214  void WriteWeight(const std::vector<Weight> &input, std::ostream &strm);
215 
216  void ReadWeight(std::istream &strm, std::vector<Weight> *output);
217 
218  // Same as fst::Decode, but doesn't remove the final epsilons.
219  void DecodeForCompress(MutableFst<Arc> *fst, const EncodeMapper<Arc> &mapper);
220 
221  // Updates buffer_code_.
222  template <class CVar>
223  void WriteToBuffer(CVar input) {
224  std::vector<bool> current_code;
225  Elias<CVar>::DeltaEncode(input, &current_code);
226  buffer_code_.insert(buffer_code_.cend(), current_code.begin(),
227  current_code.end());
228  }
229 
230  private:
231  struct LZLabel {
232  LZLabel() : label(0) {}
233  Label label;
234  };
235 
236  struct LabelLessThan {
237  bool operator()(const LZLabel &labelone, const LZLabel &labeltwo) const {
238  return labelone.label < labeltwo.label;
239  }
240  };
241 
242  struct LabelEquals {
243  bool operator()(const LZLabel &labelone, const LZLabel &labeltwo) const {
244  return labelone.label == labeltwo.label;
245  }
246  };
247 
248  struct Transition {
249  Transition() : nextstate(0), label(0), weight(Weight::Zero()) {}
250 
251  StateId nextstate;
252  Label label;
253  Weight weight;
254  };
255 
256  struct TransitionLessThan {
257  bool operator()(const Transition &transition_one,
258  const Transition &transition_two) const {
259  if (transition_one.nextstate == transition_two.nextstate) {
260  return transition_one.label < transition_two.label;
261  } else {
262  return transition_one.nextstate < transition_two.nextstate;
263  }
264  }
265  };
266 
267  struct TransitionEquals {
268  bool operator()(const Transition &transition_one,
269  const Transition &transition_two) const {
270  return transition_one.nextstate == transition_two.nextstate &&
271  transition_one.label == transition_two.label;
272  }
273  };
274 
275  struct OldDictCompare {
276  bool operator()(const std::pair<StateId, Transition> &pair_one,
277  const std::pair<StateId, Transition> &pair_two) const {
278  if (pair_one.second.nextstate == pair_two.second.nextstate) {
279  return pair_one.second.label < pair_two.second.label;
280  } else {
281  return pair_one.second.nextstate < pair_two.second.nextstate;
282  }
283  }
284  };
285 
286  std::vector<bool> buffer_code_;
287  std::vector<Weight> arc_weight_;
288  std::vector<Weight> final_weight_;
289 };
290 
291 template <class Arc>
293  const EncodeMapper<Arc> &mapper) {
294  ArcMap(fst, EncodeMapper<Arc>(mapper, DECODE));
295  fst->SetInputSymbols(mapper.InputSymbols());
296  fst->SetOutputSymbols(mapper.OutputSymbols());
297 }
298 
299 template <class Arc>
301  std::vector<StateId> *order) {
302  class BfsVisitor {
303  public:
304  // Requires order->size() >= fst.NumStates().
305  explicit BfsVisitor(std::vector<StateId> *order) : order_(order) {}
306 
307  void InitVisit(const Fst<Arc> &fst) {}
308 
309  bool InitState(StateId s, StateId) {
310  order_->at(s) = num_bfs_states_++;
311  return true;
312  }
313 
314  bool WhiteArc(StateId s, const Arc &arc) { return true; }
315  bool GreyArc(StateId s, const Arc &arc) { return true; }
316  bool BlackArc(StateId s, const Arc &arc) { return true; }
317  void FinishState(StateId s) {}
318  void FinishVisit() {}
319 
320  private:
321  std::vector<StateId> *order_ = nullptr;
322  StateId num_bfs_states_ = 0;
323  };
324 
325  order->assign(fst.NumStates(), kNoStateId);
326  BfsVisitor visitor(order);
327  FifoQueue<StateId> queue;
328  Visit(fst, &visitor, &queue, AnyArcFilter<Arc>());
329 }
330 
331 template <class Arc>
333  MutableFst<Arc> *preprocessedfst,
334  EncodeMapper<Arc> *encoder) {
335  *preprocessedfst = fst;
336  if (!preprocessedfst->NumStates()) return;
337  // Relabels the edges and develops a dictionary.
338  Encode(preprocessedfst, encoder);
339  std::vector<StateId> order;
340  // Finds the BFS sorting order of the FST.
341  BfsOrder(*preprocessedfst, &order);
342  // Reorders the states according to the BFS order.
343  StateSort(preprocessedfst, order);
344 }
345 
346 template <class Arc>
348  std::ostream &strm) {
349  std::vector<StateId> output;
352  std::vector<LZLabel> current_new_input;
353  std::vector<Transition> current_old_input;
354  std::vector<std::pair<StateId, LZLabel>> current_new_output;
355  std::vector<std::pair<StateId, Transition>> current_old_output;
356  std::vector<StateId> final_states;
357  const auto number_of_states = fst.NumStates();
358  StateId seen_states = 0;
359  // Adds the number of states.
360  WriteToBuffer<StateId>(number_of_states);
361  for (StateId state = 0; state < number_of_states; ++state) {
362  current_new_input.clear();
363  current_old_input.clear();
364  current_new_output.clear();
365  current_old_output.clear();
366  if (state > seen_states) ++seen_states;
367  // Collects the final states.
368  if (fst.Final(state) != Weight::Zero()) {
369  final_states.push_back(state);
370  final_weight_.push_back(fst.Final(state));
371  }
372  // Reads the states.
373  for (ArcIterator<Fst<Arc>> aiter(fst, state); !aiter.Done(); aiter.Next()) {
374  const auto &arc = aiter.Value();
375  if (arc.nextstate > seen_states) { // RILEY: > or >= ?
376  ++seen_states;
377  LZLabel temp_label;
378  temp_label.label = arc.ilabel;
379  arc_weight_.push_back(arc.weight);
380  current_new_input.push_back(temp_label);
381  } else {
382  Transition temp_transition;
383  temp_transition.nextstate = arc.nextstate;
384  temp_transition.label = arc.ilabel;
385  temp_transition.weight = arc.weight;
386  current_old_input.push_back(temp_transition);
387  }
388  }
389  // Adds new states.
390  dict_new.BatchEncode(current_new_input, &current_new_output);
391  WriteToBuffer<StateId>(current_new_output.size());
392  for (auto it = current_new_output.cbegin(); it != current_new_output.cend();
393  ++it) {
394  WriteToBuffer<StateId>(it->first);
395  WriteToBuffer<Label>((it->second).label);
396  }
397  // Adds old states by sorting and using difference coding.
398  static const TransitionLessThan transition_less_than;
399  std::sort(current_old_input.begin(), current_old_input.end(),
400  transition_less_than);
401  for (auto it = current_old_input.begin(); it != current_old_input.end();
402  ++it) {
403  arc_weight_.push_back(it->weight);
404  }
405  dict_old.BatchEncode(current_old_input, &current_old_output);
406  std::vector<StateId> dict_old_temp;
407  std::vector<Transition> transition_old_temp;
408  for (auto it = current_old_output.begin(); it != current_old_output.end();
409  ++it) {
410  dict_old_temp.push_back(it->first);
411  transition_old_temp.push_back(it->second);
412  }
413  if (!transition_old_temp.empty()) {
414  if ((transition_old_temp.back()).nextstate == 0 &&
415  (transition_old_temp.back()).label == 0) {
416  transition_old_temp.pop_back();
417  }
418  }
419  std::sort(dict_old_temp.begin(), dict_old_temp.end());
420  std::sort(transition_old_temp.begin(), transition_old_temp.end(),
421  transition_less_than);
422  WriteToBuffer<StateId>(dict_old_temp.size());
423  if (dict_old_temp.size() == transition_old_temp.size()) {
424  WriteToBuffer<int>(0);
425  } else {
426  WriteToBuffer<int>(1);
427  }
428  StateId previous;
429  if (!dict_old_temp.empty()) {
430  WriteToBuffer<StateId>(dict_old_temp.front());
431  previous = dict_old_temp.front();
432  }
433  if (dict_old_temp.size() > 1) {
434  for (auto it = dict_old_temp.begin() + 1; it != dict_old_temp.end();
435  ++it) {
436  WriteToBuffer<StateId>(*it - previous);
437  previous = *it;
438  }
439  }
440  if (!transition_old_temp.empty()) {
441  WriteToBuffer<StateId>((transition_old_temp.front()).nextstate);
442  previous = transition_old_temp.front().nextstate;
443  WriteToBuffer<Label>(transition_old_temp.front().label);
444  }
445  if (transition_old_temp.size() > 1) {
446  for (auto it = transition_old_temp.begin() + 1;
447  it != transition_old_temp.end(); ++it) {
448  WriteToBuffer<StateId>(it->nextstate - previous);
449  previous = it->nextstate;
450  WriteToBuffer<StateId>(it->label);
451  }
452  }
453  }
454  // Adds final states.
455  WriteToBuffer<StateId>(final_states.size());
456  if (!final_states.empty()) {
457  for (auto it = final_states.begin(); it != final_states.end(); ++it) {
458  WriteToBuffer<StateId>(*it);
459  }
460  }
461  WriteToStream(strm);
462  const uint8_t unweighted = fst.Properties(kUnweighted, true) == kUnweighted;
463  WriteType(strm, unweighted);
464  if (unweighted == 0) {
465  WriteWeight(arc_weight_, strm);
466  WriteWeight(final_weight_, strm);
467  }
468 }
469 
470 template <class Arc>
471 void Compressor<Arc>::DecodeProcessedFst(const std::vector<StateId> &input,
473  bool unweighted) {
476  std::vector<std::pair<StateId, LZLabel>> current_new_input;
477  std::vector<std::pair<StateId, Transition>> current_old_input;
478  std::vector<LZLabel> current_new_output;
479  std::vector<Transition> current_old_output;
480  std::vector<std::pair<StateId, Transition>> actual_old_dict_numbers;
481  std::vector<Transition> actual_old_dict_transitions;
482  auto arc_weight_it = arc_weight_.begin();
483  Transition default_transition;
484  StateId seen_states = 1;
485  // Adds states..
486  const StateId num_states = input.front();
487  if (num_states > 0) {
488  const StateId start_state = fst->AddState();
489  fst->SetStart(start_state);
490  for (StateId state = 1; state < num_states; ++state) {
491  fst->AddState();
492  }
493  }
494  auto main_it = input.cbegin();
495  ++main_it;
496  for (StateId current_state = 0; current_state < num_states; ++current_state) {
497  if (current_state >= seen_states) ++seen_states;
498  current_new_input.clear();
499  current_new_output.clear();
500  current_old_input.clear();
501  current_old_output.clear();
502  // New states.
503  StateId current_number_new_elements = *main_it;
504  ++main_it;
505  for (StateId new_integer = 0; new_integer < current_number_new_elements;
506  ++new_integer) {
507  std::pair<StateId, LZLabel> temp_new_dict_element;
508  temp_new_dict_element.first = *main_it;
509  ++main_it;
510  LZLabel temp_label;
511  temp_label.label = *main_it;
512  ++main_it;
513  temp_new_dict_element.second = temp_label;
514  current_new_input.push_back(temp_new_dict_element);
515  }
516  dict_new.BatchDecode(current_new_input, &current_new_output);
517  for (const auto &label : current_new_output) {
518  if (!unweighted) {
519  fst->AddArc(current_state,
520  Arc(label.label, label.label, *arc_weight_it, seen_states));
521  ++arc_weight_it;
522  } else {
523  fst->AddArc(current_state,
524  Arc(label.label, label.label, Weight::One(), seen_states));
525  }
526  ++seen_states;
527  }
528  StateId current_number_old_elements = *main_it;
529  ++main_it;
530  StateId is_zero_removed = *main_it;
531  ++main_it;
532  StateId previous = 0;
533  actual_old_dict_numbers.clear();
534  for (StateId new_integer = 0; new_integer < current_number_old_elements;
535  ++new_integer) {
536  std::pair<StateId, Transition> pair_temp_transition;
537  if (new_integer == 0) {
538  pair_temp_transition.first = *main_it;
539  previous = *main_it;
540  } else {
541  pair_temp_transition.first = *main_it + previous;
542  previous = pair_temp_transition.first;
543  }
544  ++main_it;
545  Transition temp_test;
546  if (!dict_old.SingleDecode(pair_temp_transition.first, &temp_test)) {
547  FSTERROR() << "Compressor::Decode: failed";
548  fst->SetProperties(kError, kError);
549  return;
550  }
551  pair_temp_transition.second = temp_test;
552  actual_old_dict_numbers.push_back(pair_temp_transition);
553  }
554  // Reorders the dictionary elements.
555  static const OldDictCompare old_dict_compare;
556  std::sort(actual_old_dict_numbers.begin(), actual_old_dict_numbers.end(),
557  old_dict_compare);
558  // Transitions.
559  previous = 0;
560  actual_old_dict_transitions.clear();
561  for (StateId new_integer = 0;
562  new_integer < current_number_old_elements - is_zero_removed;
563  ++new_integer) {
564  Transition temp_transition;
565  if (new_integer == 0) {
566  temp_transition.nextstate = *main_it;
567  previous = *main_it;
568  } else {
569  temp_transition.nextstate = *main_it + previous;
570  previous = temp_transition.nextstate;
571  }
572  ++main_it;
573  temp_transition.label = *main_it;
574  ++main_it;
575  actual_old_dict_transitions.push_back(temp_transition);
576  }
577  if (is_zero_removed == 1) {
578  actual_old_dict_transitions.push_back(default_transition);
579  }
580  auto trans_it = actual_old_dict_transitions.cbegin();
581  auto dict_it = actual_old_dict_numbers.cbegin();
582  while (trans_it != actual_old_dict_transitions.cend() &&
583  dict_it != actual_old_dict_numbers.cend()) {
584  if (dict_it->first == 0) {
585  ++dict_it;
586  } else {
587  std::pair<StateId, Transition> temp_pair;
588  static const TransitionEquals transition_equals;
589  static const TransitionLessThan transition_less_than;
590  if (transition_equals(*trans_it, default_transition)) {
591  temp_pair.first = dict_it->first;
592  temp_pair.second = default_transition;
593  ++dict_it;
594  } else if (transition_less_than(dict_it->second, *trans_it)) {
595  temp_pair.first = dict_it->first;
596  temp_pair.second = *trans_it;
597  ++dict_it;
598  } else {
599  temp_pair.first = 0;
600  temp_pair.second = *trans_it;
601  }
602  ++trans_it;
603  current_old_input.push_back(temp_pair);
604  }
605  }
606  while (trans_it != actual_old_dict_transitions.cend()) {
607  std::pair<StateId, Transition> temp_pair;
608  temp_pair.first = 0;
609  temp_pair.second = *trans_it;
610  ++trans_it;
611  current_old_input.push_back(temp_pair);
612  }
613  // Adds old elements in the dictionary.
614  if (!dict_old.BatchDecode(current_old_input, &current_old_output)) {
615  FSTERROR() << "Compressor::Decode: Failed";
616  fst->SetProperties(kError, kError);
617  return;
618  }
619  for (auto it = current_old_output.cbegin(); it != current_old_output.cend();
620  ++it) {
621  if (!unweighted) {
622  fst->AddArc(current_state,
623  Arc(it->label, it->label, *arc_weight_it, it->nextstate));
624  ++arc_weight_it;
625  } else {
626  fst->AddArc(current_state,
627  Arc(it->label, it->label, Weight::One(), it->nextstate));
628  }
629  }
630  }
631  // Adds the final states.
632  StateId number_of_final_states = *main_it;
633  if (number_of_final_states > 0) {
634  ++main_it;
635  for (StateId temp_int = 0; temp_int < number_of_final_states; ++temp_int) {
636  if (unweighted) {
637  fst->SetFinal(*main_it, Weight(0));
638  } else {
639  fst->SetFinal(*main_it, final_weight_[temp_int]);
640  }
641  ++main_it;
642  }
643  }
644 }
645 
646 template <class Arc>
647 void Compressor<Arc>::ReadWeight(std::istream &strm,
648  std::vector<Weight> *output) {
649  int64_t size;
650  Weight weight;
651  ReadType(strm, &size);
652  for (int64_t i = 0; i < size; ++i) {
653  weight.Read(strm);
654  output->push_back(weight);
655  }
656 }
657 
658 template <class Arc>
659 bool Compressor<Arc>::Decompress(std::istream &strm, const std::string &source,
660  MutableFst<Arc> *fst) {
661  fst->DeleteStates();
662  int32_t magic_number = 0;
663  ReadType(strm, &magic_number);
664  if (magic_number != kCompressMagicNumber) {
665  LOG(ERROR) << "Decompress: Bad compressed Fst: " << source;
666  return false;
667  }
668  std::unique_ptr<EncodeMapper<Arc>> encoder(
669  EncodeMapper<Arc>::Read(strm, "Decoding", DECODE));
670  std::vector<bool> bool_code;
671  uint8_t block;
672  uint8_t msb = 128;
673  int64_t data_size;
674  ReadType(strm, &data_size);
675  for (int64_t i = 0; i < data_size; ++i) {
676  ReadType(strm, &block);
677  for (int j = 0; j < 8; ++j) {
678  uint8_t temp = msb & block;
679  bool_code.push_back(temp == 128);
680  block = block << 1;
681  }
682  }
683  std::vector<StateId> int_code;
684  Elias<StateId>::BatchDecode(bool_code, &int_code);
685  bool_code.clear();
686  uint8_t unweighted;
687  ReadType(strm, &unweighted);
688  if (unweighted == 0) {
689  ReadWeight(strm, &arc_weight_);
690  ReadWeight(strm, &final_weight_);
691  }
692  DecodeProcessedFst(int_code, fst, unweighted);
693  DecodeForCompress(fst, *encoder);
694  return !fst->Properties(kError, false);
695 }
696 
697 template <class Arc>
698 void Compressor<Arc>::WriteWeight(const std::vector<Weight> &input,
699  std::ostream &strm) {
700  int64_t size = input.size();
701  WriteType(strm, size);
702  for (auto it = input.begin(); it != input.end(); ++it) {
703  it->Write(strm);
704  }
705 }
706 
707 template <class Arc>
708 void Compressor<Arc>::WriteToStream(std::ostream &strm) {
709  while (buffer_code_.size() % 8 != 0) buffer_code_.push_back(true);
710  int64_t data_size = buffer_code_.size() / 8;
711  WriteType(strm, data_size);
712  int64_t i = 0;
713  uint8_t block;
714  for (auto it = buffer_code_.begin(); it != buffer_code_.end(); ++it) {
715  if (i % 8 == 0) {
716  if (i > 0) WriteType(strm, block);
717  block = 0;
718  } else {
719  block = block << 1;
720  }
721  block |= *it;
722  ++i;
723  }
724  WriteType(strm, block);
725 }
726 
727 template <class Arc>
728 bool Compressor<Arc>::Compress(const Fst<Arc> &fst, std::ostream &strm) {
729  VectorFst<Arc> processedfst;
731  Preprocess(fst, &processedfst, &encoder);
732  WriteType(strm, kCompressMagicNumber);
733  encoder.Write(strm, "ostream");
734  EncodeProcessedFst(processedfst, strm);
735  return true;
736 }
737 
738 // Convenience functions that call the compressor and decompressor.
739 
740 template <class Arc>
741 void Compress(const Fst<Arc> &fst, std::ostream &strm) {
742  Compressor<Arc> comp;
743  comp.Compress(fst, strm);
744 }
745 
746 template <class Arc>
747 bool Compress(const Fst<Arc> &fst, const std::string &source) {
748  std::ofstream fstrm;
749  if (!source.empty()) {
750  fstrm.open(source, std::ios_base::out | std::ios_base::binary);
751  if (!fstrm) {
752  LOG(ERROR) << "Compress: Can't open file: " << source;
753  return false;
754  }
755  }
756  std::ostream &ostrm = fstrm.is_open() ? fstrm : std::cout;
757  Compress(fst, ostrm);
758  return !!ostrm;
759 }
760 
761 template <class Arc>
762 bool Decompress(std::istream &strm, const std::string &source,
763  MutableFst<Arc> *fst) {
764  Compressor<Arc> comp;
765  comp.Decompress(strm, source, fst);
766  return true;
767 }
768 
769 // Returns true on success.
770 template <class Arc>
771 bool Decompress(const std::string &source, MutableFst<Arc> *fst) {
772  std::ifstream fstrm;
773  if (!source.empty()) {
774  fstrm.open(source, std::ios_base::in | std::ios_base::binary);
775  if (!fstrm) {
776  LOG(ERROR) << "Decompress: Can't open file: " << source;
777  return false;
778  }
779  }
780  std::istream &istrm = fstrm.is_open() ? fstrm : std::cin;
781  Decompress(istrm, source.empty() ? "standard input" : source, fst);
782  return !!istrm;
783 }
784 
785 } // namespace fst
786 
787 #endif // FST_EXTENSIONS_COMPRESS_COMPRESS_H_
constexpr int32_t kCompressMagicNumber
Definition: compress.h:50
bool Decompress(std::istream &strm, const std::string &source, MutableFst< Arc > *fst)
Definition: compress.h:762
bool ExpandLZCode(const std::vector< std::pair< Var, Edge >> &code, std::vector< std::vector< Edge >> *expanded_code)
Definition: compress.h:57
void ArcMap(MutableFst< A > *fst, C *mapper)
Definition: arc-map.h:110
void DecodeProcessedFst(const std::vector< StateId > &input, MutableFst< Arc > *fst, bool unweighted)
Definition: compress.h:471
void BatchEncode(const std::vector< Edge > &input, std::vector< std::pair< Var, Edge >> *output)
Definition: compress.h:128
void DecodeForCompress(MutableFst< Arc > *fst, const EncodeMapper< Arc > &mapper)
Definition: compress.h:292
virtual uint64_t Properties(uint64_t mask, bool test) const =0
const SymbolTable * OutputSymbols() const
Definition: encode.h:401
bool Decompress(std::istream &strm, const std::string &source, MutableFst< Arc > *fst)
Definition: compress.h:659
void Encode(MutableFst< Arc > *fst, EncodeMapper< Arc > *mapper)
Definition: encode.h:474
bool Compress(const Fst< Arc > &fst, std::ostream &strm)
Definition: compress.h:728
void Visit(const FST &fst, Visitor *visitor, Queue *queue, ArcFilter filter, bool access_only=false)
Definition: visit.h:74
bool Write(std::ostream &strm, const std::string &source) const
Definition: encode.h:385
constexpr uint64_t kError
Definition: properties.h:51
virtual void SetInputSymbols(const SymbolTable *isyms)=0
#define LOG(type)
Definition: log.h:49
virtual Weight Final(StateId) const =0
virtual void SetStart(StateId)=0
bool BatchDecode(const std::vector< std::pair< Var, Edge >> &input, std::vector< Edge > *output)
Definition: compress.h:155
void ReadWeight(std::istream &strm, std::vector< Weight > *output)
Definition: compress.h:647
constexpr int kNoStateId
Definition: fst.h:202
void WriteWeight(const std::vector< Weight > &input, std::ostream &strm)
Definition: compress.h:698
void BfsOrder(const ExpandedFst< Arc > &fst, std::vector< StateId > *order)
Definition: compress.h:300
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:214
#define FSTERROR()
Definition: util.h:53
typename Arc::Label Label
Definition: compress.h:183
void EncodeProcessedFst(const ExpandedFst< Arc > &fst, std::ostream &strm)
Definition: compress.h:347
static void BatchDecode(const std::vector< bool > &input, std::vector< Var > *output)
Definition: elias.h:77
typename Arc::StateId StateId
Definition: compress.h:184
constexpr uint8_t kEncodeLabels
Definition: encode.h:43
virtual void SetProperties(uint64_t props, uint64_t mask)=0
static void DeltaEncode(const Var &input, std::vector< bool > *code)
Definition: elias.h:59
const SymbolTable * InputSymbols() const
Definition: encode.h:399
virtual void AddArc(StateId, const Arc &)=0
constexpr uint64_t kUnweighted
Definition: properties.h:105
void WriteToStream(std::ostream &strm)
Definition: compress.h:708
virtual StateId AddState()=0
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
virtual void DeleteStates(const std::vector< StateId > &)=0
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:68
void WriteToBuffer(CVar input)
Definition: compress.h:223
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
void Compress(const Fst< Arc > &fst, std::ostream &strm)
Definition: compress.h:741
virtual StateId NumStates() const =0
void Preprocess(const Fst< Arc > &fst, MutableFst< Arc > *preprocessedfst, EncodeMapper< Arc > *encoder)
Definition: compress.h:332
bool SingleDecode(const Var &index, Edge *output)
Definition: compress.h:103
bool StateSort(std::vector< IntervalSet< Label >> *interval_sets, const std::vector< StateId > &order)
typename Arc::Weight Weight
Definition: compress.h:185