FST  openfst-1.8.3
OpenFst Library
expander-cache.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Cache implementations for ExpanderFst.
19 //
20 // Expander caches must expose a State type and a FindOrExpand template method:
21 //
22 // class ExpanderCache {
23 // public:
24 // class State;
25 //
26 // template <class Expander>
27 // State* FindOrExpand(Expander& expander, StateId id) {
28 // if (id is found in cache) return cached_state;
29 //
30 // // Use the provided expander to create a new cached state and cache it.
31 // expander.Expand(id, &new_state);
32 // insert new_state into cache;
33 // return new_state;
34 // }
35 // };
36 //
37 // Cache implementations must be copyable and assignable. It is up to the
38 // implementation whether this means it will discard the contents of the cache,
39 // copy all of the cache, share some of the cache etc. It is *REQUIRED* that the
40 // copy be "safe": the copy and the original must be usable from concurrent
41 // threads without accessing any internally shared state.
42 
43 #ifndef FST_EXPANDER_CACHE_H_
44 #define FST_EXPANDER_CACHE_H_
45 
46 #include <cstddef>
47 #include <deque>
48 #include <memory>
49 #include <utility>
50 #include <vector>
51 
52 #include <fst/cache.h>
53 #include <fst/fst.h>
54 #include <unordered_map>
55 #include <unordered_map>
56 
57 namespace fst {
58 
59 // Stateful allocators can't be used without careful handling in threaded
60 // contexts, so arbitrary stl allocators aren't supported here.
// Plain cache state: a final weight, input/output epsilon counters and a
// vector of arcs. It carries no reference count; MutableRefCount() always
// returns nullptr.
//
// NOTE(review): the class-head line was absent from this listing. The name is
// taken from the base-class reference `SimpleVectorCacheState<Arc>` used by
// the reference-counted state later in this file.
template <class A>
class SimpleVectorCacheState {
 public:
  using Arc = A;
  using Weight = typename Arc::Weight;
  using StateId = typename Arc::StateId;

  // Returns the state to its default-constructed contents: Zero final weight,
  // zeroed epsilon counters and no arcs.
  void Reset() {
    final_weight_ = Weight::Zero();
    niepsilons_ = 0;
    noepsilons_ = 0;
    arcs_.clear();
  }

  // Final weight of the state (Weight::Zero() until SetFinal() is called).
  Weight Final() const { return final_weight_; }

  // Number of stored arcs whose ilabel is 0.
  size_t NumInputEpsilons() const { return niepsilons_; }

  // Number of stored arcs whose olabel is 0.
  size_t NumOutputEpsilons() const { return noepsilons_; }

  // Number of stored arcs.
  size_t NumArcs() const { return arcs_.size(); }

  // Returns the n-th stored arc; n is not range-checked.
  const Arc &GetArc(size_t n) const { return arcs_[n]; }

  // Pointer to the first stored arc, or nullptr when there are no arcs.
  const Arc *Arcs() const { return arcs_.empty() ? nullptr : &arcs_[0]; }

  // The weight is taken by value and moved into place, so a weight with
  // non-trivial storage is not copied a second time.
  void SetFinal(Weight weight) { final_weight_ = std::move(weight); }

  // Reserves storage for at least n arcs.
  void ReserveArcs(size_t n) { arcs_.reserve(n); }

  // Appends a copy of arc and updates the epsilon counters.
  void AddArc(const Arc &arc) {
    if (arc.ilabel == 0) ++niepsilons_;
    if (arc.olabel == 0) ++noepsilons_;
    arcs_.push_back(arc);
  }

  // Appends arc by move. The labels are inspected before the arc is moved.
  void AddArc(Arc &&arc) {
    if (arc.ilabel == 0) ++niepsilons_;
    if (arc.olabel == 0) ++noepsilons_;
    arcs_.push_back(std::move(arc));
  }

  // This state type has no reference count.
  int *MutableRefCount() const { return nullptr; }

 private:
  Weight final_weight_ = Weight::Zero();
  size_t niepsilons_ = 0;  // Number of input epsilons.
  size_t noepsilons_ = 0;  // Number of output epsilons.
  std::vector<Arc> arcs_;
};
111 
112 template <class A>
114  public:
115  using Arc = A;
116  using StateId = typename Arc::StateId;
117 
118  // Reference-counted state.
119  class State : public SimpleVectorCacheState<Arc> {
120  public:
121  int *MutableRefCount() { return &ref_count_; }
122 
123  void Reset() {
125  ref_count_ = 0;
126  }
127 
128  private:
129  int ref_count_ = 0;
130 
132  };
133 
134  NoGcKeepOneExpanderCache() : state_(new State) {}
135 
137  : state_(new State(*copy.state_)) {}
138 
139  template <class Expander>
140  State *FindOrExpand(Expander &expander, StateId state_id) {
141  if (state_id == state_id_) return state_.get();
142  if (state_->ref_count_ > 0) cache_[state_id_] = std::move(state_);
143  state_id_ = state_id;
144  if (cache_.empty()) {
145  state_->Reset();
146  expander.Expand(state_id_, state_.get());
147  return state_.get();
148  }
149  if (auto i = cache_.find(state_id_); i != cache_.end()) {
150  state_ = std::move(i->second);
151  }
152  if (state_ == nullptr) {
153  state_ = std::make_unique<State>();
154  expander.Expand(state_id_, state_.get());
155  }
156  return state_.get();
157  }
158 
159  StateId state_id_ = kNoStateId;
160  std::unique_ptr<State> state_;
161  std::unordered_map<StateId, std::unique_ptr<State>> cache_;
162 };
163 
164 template <class A>
166  public:
167  using Arc = A;
168  using StateId = typename Arc::StateId;
169 
171 
172  HashExpanderCache(const HashExpanderCache &copy) { *this = copy; }
173 
175  for (const auto &[id, state] : copy.cache_) {
176  cache_[id] = std::make_unique<State>(*state);
177  }
178  return *this;
179  }
180 
181  ~HashExpanderCache() = default;
182 
183  template <class Expander>
184  State *FindOrExpand(Expander &expander, StateId state_id) {
185  auto [it, inserted] = cache_.emplace(state_id, nullptr);
186  if (inserted) {
187  it->second = std::make_unique<State>();
188  expander.Expand(state_id, it->second.get());
189  }
190  return it->second.get();
191  }
192 
193  private:
194  std::unordered_map<StateId, std::unique_ptr<State>> cache_;
195 };
196 
197 template <class A>
199  public:
200  using Arc = A;
201  using StateId = typename Arc::StateId;
202 
204 
205  VectorExpanderCache() : vec_(0, nullptr) {}
206 
207  VectorExpanderCache(const VectorExpanderCache &copy) { *this = copy; }
208 
210  vec_.resize(copy.vec_.size());
211  for (StateId i = 0; i < copy.vec_.size(); ++i) {
212  const auto *state = copy.vec_[i];
213  if (state != nullptr) {
214  states_.emplace_back(*state);
215  vec_[i] = &states_.back();
216  }
217  }
218  return *this;
219  }
220 
221  template <class Expander>
222  State *FindOrExpand(Expander &expander, StateId state_id) {
223  if (state_id >= vec_.size()) vec_.resize(state_id + 1);
224  auto **slot = &vec_[state_id];
225  if (*slot == nullptr) {
226  states_.emplace_back();
227  *slot = &states_.back();
228  expander.Expand(state_id, *slot);
229  }
230  return *slot;
231  }
232 
233  private:
234  std::deque<State> states_;
235  std::vector<State *> vec_;
236 };
237 
238 template <class Expander>
240 
241 } // namespace fst
242 #endif // FST_EXPANDER_CACHE_H_
const Arc & GetArc(size_t n) const
const Arc * Arcs() const
void SetFinal(Weight weight)
VectorExpanderCache(const VectorExpanderCache &copy)
typename Arc::StateId StateId
typename Arc::StateId StateId
HashExpanderCache & operator=(const HashExpanderCache &copy)
State * FindOrExpand(Expander &expander, StateId state_id)
constexpr int kNoStateId
Definition: fst.h:196
VectorExpanderCache & operator=(const VectorExpanderCache &copy)
NoGcKeepOneExpanderCache(const NoGcKeepOneExpanderCache &copy)
typename Arc::StateId StateId
State * FindOrExpand(Expander &expander, StateId state_id)
std::unordered_map< StateId, std::unique_ptr< State > > cache_
size_t NumOutputEpsilons() const
size_t NumInputEpsilons() const
void AddArc(const Arc &arc)
State * FindOrExpand(Expander &expander, StateId state_id)
std::unique_ptr< State > state_
HashExpanderCache(const HashExpanderCache &copy)