FST  openfst-1.7.1
OpenFst Library
paren.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Common classes for PDT parentheses.
5 
6 #ifndef FST_EXTENSIONS_PDT_PAREN_H_
7 #define FST_EXTENSIONS_PDT_PAREN_H_
8 
9 #include <algorithm>
10 #include <set>
11 #include <unordered_map>
12 #include <unordered_set>
13 
14 #include <fst/log.h>
15 
17 #include <fst/extensions/pdt/pdt.h>
18 #include <fst/dfs-visit.h>
19 #include <fst/fst.h>
20 
21 
22 namespace fst {
23 namespace internal {
24 
25 // ParenState: Pair of an open (close) parenthesis and its destination (source)
26 // state.
27 
28 template <class Arc>
29 struct ParenState {
30  using Label = typename Arc::Label;
31  using StateId = typename Arc::StateId;
32 
33  Label paren_id; // ID of open (close) paren.
34  StateId state_id; // Destination (source) state of open (close) paren.
35 
36  explicit ParenState(Label paren_id = kNoLabel, StateId state_id = kNoStateId)
37  : paren_id(paren_id), state_id(state_id) {}
38 
39  bool operator==(const ParenState<Arc> &other) const {
40  if (&other == this) return true;
41  return other.paren_id == paren_id && other.state_id == state_id;
42  }
43 
44  bool operator!=(const ParenState<Arc> &other) const {
45  return !(other == *this);
46  }
47 
48  struct Hash {
49  size_t operator()(const ParenState<Arc> &pstate) const {
50  static constexpr auto prime = 7853;
51  return pstate.paren_id + pstate.state_id * prime;
52  }
53  };
54 };
55 
// MapIterator: presents a run of equal keys in an STL associative container
// (e.g., a multimap) through the FST-style Done()/Value()/Next()/Reset()
// protocol. Iteration begins at the position handed to the constructor. It
// ends at the container's end() or at the first element whose key differs
// from the key at that starting position.
template <class Map>
class MapIterator {
 public:
  using StlIterator = typename Map::const_iterator;
  using ValueType = typename Map::mapped_type;

  // `it` is the starting position, typically the result of map.find(key).
  // Passing map.end() yields an iterator that is immediately Done().
  MapIterator(const Map &map, StlIterator it)
      : start_(it), stop_(map.end()), pos_(it) {}

  bool Done() const {
    if (pos_ == stop_) return true;           // Ran off the container.
    return pos_->first != start_->first;      // Left the run of equal keys.
  }

  // Returns (a copy of) the mapped value at the current position.
  ValueType Value() const { return pos_->second; }

  void Next() { ++pos_; }

  // Rewinds to the starting position.
  void Reset() { pos_ = start_; }

 private:
  const StlIterator start_;  // Starting position of the run.
  const StlIterator stop_;   // map.end(), the hard upper bound.
  StlIterator pos_;          // Current position.
};
79 
80 // PdtParenReachable: Provides various parenthesis reachability information.
81 
82 template <class Arc>
84  public:
85  using Label = typename Arc::Label;
86  using StateId = typename Arc::StateId;
87 
89  using StateHash = typename State::Hash;
90 
91  // Maps from state ID to reachable paren IDs from (to) that state.
92  using ParenMultimap = std::unordered_multimap<StateId, Label>;
93 
94  // Maps from paren ID and state ID to reachable state set ID.
95  using StateSetMap = std::unordered_map<State, ssize_t, StateHash>;
96 
97  // Maps from paren ID and state ID to arcs exiting that state with that
98  // Label.
99  using ParenArcMultimap = std::unordered_map<State, Arc, StateHash>;
100 
102 
104 
106 
107  // Computes close (open) parenthesis reachability information for a PDT with
108  // bounded stack.
110  const std::vector<std::pair<Label, Label>> &parens,
111  bool close)
112  : fst_(fst), parens_(parens), close_(close), error_(false) {
113  paren_map_.reserve(2 * parens.size());
114  for (size_t i = 0; i < parens.size(); ++i) {
115  const auto &pair = parens[i];
116  paren_map_[pair.first] = i;
117  paren_map_[pair.second] = i;
118  }
119  if (close_) {
120  const auto start = fst.Start();
121  if (start == kNoStateId) return;
122  if (!DFSearch(start)) {
123  FSTERROR() << "PdtReachable: Underlying cyclicity not supported";
124  error_ = true;
125  }
126  } else {
127  FSTERROR() << "PdtParenReachable: Open paren info not implemented";
128  error_ = true;
129  }
130  }
131 
// Returns true if construction failed. That happens when DFSearch() detected
// a cycle (reported as unsupported underlying cyclicity), or when open-paren
// information was requested (close == false), which is not implemented.
132  bool Error() const { return error_; }
133 
134  // Given a state ID, returns an iterator over paren IDs for close (open)
135  // parens reachable from that state along balanced paths.
137  return ParenIterator(paren_multimap_, paren_multimap_.find(s));
138  }
139 
140  // Given a paren ID and a state ID s, returns an iterator over states that can
141 // be reached along balanced paths from (to) s that have close (open)
142  // parentheses matching the paren ID exiting (entering) those states.
144  const State paren_state(paren_id, s);
145  const auto it = set_map_.find(paren_state);
146  if (it == set_map_.end()) {
147  return state_sets_.FindSet(-1);
148  } else {
149  return state_sets_.FindSet(it->second);
150  }
151  }
152 
153  // Given a paren ID and a state ID s, return an iterator over arcs that exit
154  // (enter) s and are labeled with a close (open) parenthesis matching the
155  // paren ID.
157  const State paren_state(paren_id, s);
158  return ParenArcIterator(paren_arc_multimap_,
159  paren_arc_multimap_.find(paren_state));
160  }
161 
162  private:
163  // Returns false when cycle detected during DFS gathering paren and state set
164  // information.
165  bool DFSearch(StateId s);
166 
167  // Unions state sets together gathered by the DFS.
168  void ComputeStateSet(StateId s);
169 
170  // Gathers state set(s) from state.
171  void UpdateStateSet(StateId nextstate, std::set<Label> *paren_set,
172  std::vector<std::set<StateId>> *state_sets) const;
173 
174  const Fst<Arc> &fst_;
175  // Paren IDs to labels.
176  const std::vector<std::pair<Label, Label>> &parens_;
177  // Close/open paren info?
178  const bool close_;
179  // Labels to paren IDs.
180  std::unordered_map<Label, Label> paren_map_;
181  // Paren reachability.
182  ParenMultimap paren_multimap_;
183  // Paren arcs.
184  ParenArcMultimap paren_arc_multimap_;
185  // DFS states.
186  std::vector<uint8> state_color_;
187  // Reachable states to IDs.
188  mutable Collection<ssize_t, StateId> state_sets_;
189 // (Paren ID, state) to reachable state set ID.
190  StateSetMap set_map_;
191  bool error_;
192 
193  PdtParenReachable(const PdtParenReachable &) = delete;
194  PdtParenReachable &operator=(const PdtParenReachable &) = delete;
195 };
196 
197 // Gathers paren and state set information.
198 template <class Arc>
200  static constexpr uint8 kWhiteState = 0x01; // Undiscovered.
201  static constexpr uint8 kGreyState = 0x02; // Discovered & unfinished.
202  static constexpr uint8 kBlackState = 0x04; // Finished.
203  if (s >= state_color_.size()) state_color_.resize(s + 1, kWhiteState);
204  if (state_color_[s] == kBlackState) return true;
205  if (state_color_[s] == kGreyState) return false;
206  state_color_[s] = kGreyState;
207  for (ArcIterator<Fst<Arc>> aiter(fst_, s); !aiter.Done(); aiter.Next()) {
208  const auto &arc = aiter.Value();
209  const auto it = paren_map_.find(arc.ilabel);
210  if (it != paren_map_.end()) { // Paren?
211  const auto paren_id = it->second;
212  if (arc.ilabel == parens_[paren_id].first) { // Open paren?
213  if (!DFSearch(arc.nextstate)) return false;
214  for (auto set_iter = FindStates(paren_id, arc.nextstate);
215  !set_iter.Done(); set_iter.Next()) {
216  for (auto paren_arc_iter =
217  FindParenArcs(paren_id, set_iter.Element());
218  !paren_arc_iter.Done(); paren_arc_iter.Next()) {
219  const auto &cparc = paren_arc_iter.Value();
220  if (!DFSearch(cparc.nextstate)) return false;
221  }
222  }
223  }
224  } else if (!DFSearch(arc.nextstate)) { // Non-paren.
225  return false;
226  }
227  }
228  ComputeStateSet(s);
229  state_color_[s] = kBlackState;
230  return true;
231 }
232 
233 // Unions state sets.
234 template <class Arc>
236  std::set<Label> paren_set;
237  std::vector<std::set<StateId>> state_sets(parens_.size());
238  for (ArcIterator<Fst<Arc>> aiter(fst_, s); !aiter.Done(); aiter.Next()) {
239  const auto &arc = aiter.Value();
240  const auto it = paren_map_.find(arc.ilabel);
241  if (it != paren_map_.end()) { // Paren?
242  const auto paren_id = it->second;
243  if (arc.ilabel == parens_[paren_id].first) { // Open paren?
244  for (auto set_iter = FindStates(paren_id, arc.nextstate);
245  !set_iter.Done(); set_iter.Next()) {
246  for (auto paren_arc_iter =
247  FindParenArcs(paren_id, set_iter.Element());
248  !paren_arc_iter.Done(); paren_arc_iter.Next()) {
249  const auto &cparc = paren_arc_iter.Value();
250  UpdateStateSet(cparc.nextstate, &paren_set, &state_sets);
251  }
252  }
253  } else { // Close paren.
254  paren_set.insert(paren_id);
255  state_sets[paren_id].insert(s);
256  const State paren_state(paren_id, s);
257  paren_arc_multimap_.insert(std::make_pair(paren_state, arc));
258  }
259  } else { // Non-paren.
260  UpdateStateSet(arc.nextstate, &paren_set, &state_sets);
261  }
262  }
263  std::vector<StateId> state_set;
264  for (auto paren_iter = paren_set.begin(); paren_iter != paren_set.end();
265  ++paren_iter) {
266  state_set.clear();
267  const auto paren_id = *paren_iter;
268  paren_multimap_.insert(std::make_pair(s, paren_id));
269  for (auto state_iter = state_sets[paren_id].begin();
270  state_iter != state_sets[paren_id].end(); ++state_iter) {
271  state_set.push_back(*state_iter);
272  }
273  const State paren_state(paren_id, s);
274  set_map_[paren_state] = state_sets_.FindId(state_set);
275  }
276 }
277 
278 // Gathers state sets.
279 template <class Arc>
281  StateId nextstate, std::set<Label> *paren_set,
282  std::vector<std::set<StateId>> *state_sets) const {
283  for (auto paren_iter = FindParens(nextstate); !paren_iter.Done();
284  paren_iter.Next()) {
285  const auto paren_id = paren_iter.Value();
286  paren_set->insert(paren_id);
287  for (auto set_iter = FindStates(paren_id, nextstate); !set_iter.Done();
288  set_iter.Next()) {
289  (*state_sets)[paren_id].insert(set_iter.Element());
290  }
291  }
292 }
293 
294 // Stores balancing parenthesis data for a PDT. Unlike PdtParenReachable above
295 // this allows on-the-fly construction (e.g., in PdtShortestPath).
296 template <class Arc>
298  public:
299  using Label = typename Arc::Label;
300  using StateId = typename Arc::StateId;
301 
303  using StateHash = typename State::Hash;
304 
305  // Set for open parens.
306  using OpenParenSet = std::unordered_set<State, StateHash>;
307 
308  // Maps from open paren destination state to parenthesis ID.
309  using OpenParenMap = std::unordered_multimap<StateId, Label>;
310 
311  // Maps from open paren state to source states of matching close parens
312  using CloseParenMap = std::unordered_multimap<State, StateId, StateHash>;
313 
314  // Maps from open paren state to close source set ID.
315  using CloseSourceMap = std::unordered_map<State, ssize_t, StateHash>;
316 
318 
320 
321  void Clear() {
322  open_paren_map_.clear();
323  close_paren_map_.clear();
324  }
325 
326  // Adds an open parenthesis with destination state open_dest.
327  void OpenInsert(Label paren_id, StateId open_dest) {
328  const State key(paren_id, open_dest);
329  if (!open_paren_set_.count(key)) {
330  open_paren_set_.insert(key);
331  open_paren_map_.emplace(open_dest, paren_id);
332  }
333  }
334 
335  // Adds a matching closing parenthesis with source state close_source
336  // balancing an open_parenthesis with destination state open_dest if
337  // OpenInsert() previously called.
338  void CloseInsert(Label paren_id, StateId open_dest, StateId close_source) {
339  const State key(paren_id, open_dest);
340  if (open_paren_set_.count(key)) {
341  close_paren_map_.emplace(key, close_source);
342  }
343  }
344 
345  // Finds close paren source states matching an open parenthesis. The following
346  // methods are then used to iterate through those matching states. Should be
347  // called only after FinishInsert(open_dest).
349  const State key(paren_id, open_dest);
350  const auto it = close_source_map_.find(key);
351  if (it == close_source_map_.end()) {
352  return close_source_sets_.FindSet(-1);
353  } else {
354  return close_source_sets_.FindSet(it->second);
355  }
356  }
357 
358  // Called when all open and close parenthesis insertions (w.r.t. open
359  // parentheses entering state open_dest) are finished. Must be called before
360  // Find(open_dest).
361  void FinishInsert(StateId open_dest) {
362  std::vector<StateId> close_sources;
363  for (auto oit = open_paren_map_.find(open_dest);
364  oit != open_paren_map_.end() && oit->first == open_dest;) {
365  const auto paren_id = oit->second;
366  close_sources.clear();
367  const State key(paren_id, open_dest);
368  open_paren_set_.erase(open_paren_set_.find(key));
369  for (auto cit = close_paren_map_.find(key);
370  cit != close_paren_map_.end() && cit->first == key;) {
371  close_sources.push_back(cit->second);
372  close_paren_map_.erase(cit++);
373  }
374  std::sort(close_sources.begin(), close_sources.end());
375  auto unique_end = std::unique(close_sources.begin(), close_sources.end());
376  close_sources.resize(unique_end - close_sources.begin());
377  if (!close_sources.empty()) {
378  close_source_map_[key] = close_source_sets_.FindId(close_sources);
379  }
380  open_paren_map_.erase(oit++);
381  }
382  }
383 
384  // Returns a new balance data object representing the reversed balance
385  // information.
// The result is allocated with new, and ownership passes to the caller. Open
// and close roles are swapped:
//   - each recorded close source state (plus state_id_shift) becomes an open
//     destination;
//   - its matching close source is the original open destination (plus
//     state_id_shift).
// Close sources are processed in chunks of num_states / num_split states.
// NOTE(review): num_split must be positive and at most num_states. Otherwise
// the chunk size is zero (infinite loop) or the division is by zero.
386  PdtBalanceData<Arc> *Reverse(StateId num_states, StateId num_split,
387  StateId state_id_shift) const;
388 
389  private:
390 // Open paren at destination state?
391  OpenParenSet open_paren_set_;
392  // Open parens per state.
393  OpenParenMap open_paren_map_;
394  // Current open destination state.
395  State open_dest_;
396  // Current open paren/state.
397  typename OpenParenMap::const_iterator open_iter_;
398 // (Open paren, state) to close source states.
399  CloseParenMap close_paren_map_;
400  // (Paren, state) to set ID.
401  CloseSourceMap close_source_map_;
402  mutable Collection<ssize_t, StateId> close_source_sets_;
403 };
404 
405 // Return a new balance data object representing the reversed balance
406 // information.
407 template <class Arc>
409  StateId num_states, StateId num_split, StateId state_id_shift) const {
410  auto *bd = new PdtBalanceData<Arc>;
411  std::unordered_set<StateId> close_sources;
412  const auto split_size = num_states / num_split;
413  for (StateId i = 0; i < num_states; i += split_size) {
414  close_sources.clear();
415  for (auto it = close_source_map_.begin(); it != close_source_map_.end();
416  ++it) {
417  const auto &okey = it->first;
418  const auto open_dest = okey.state_id;
419  const auto paren_id = okey.paren_id;
420  for (auto set_iter = close_source_sets_.FindSet(it->second);
421  !set_iter.Done(); set_iter.Next()) {
422  const auto close_source = set_iter.Element();
423  if ((close_source < i) || (close_source >= i + split_size)) continue;
424  close_sources.insert(close_source + state_id_shift);
425  bd->OpenInsert(paren_id, close_source + state_id_shift);
426  bd->CloseInsert(paren_id, close_source + state_id_shift,
427  open_dest + state_id_shift);
428  }
429  }
430  for (auto it = close_sources.begin(); it != close_sources.end(); ++it) {
431  bd->FinishInsert(*it);
432  }
433  }
434  return bd;
435 }
436 
437 } // namespace internal
438 } // namespace fst
439 
440 #endif // FST_EXTENSIONS_PDT_PAREN_H_
size_t operator()(const ParenState< Arc > &pstate) const
Definition: paren.h:49
constexpr int kNoLabel
Definition: fst.h:179
std::unordered_map< State, Arc, StateHash > ParenArcMultimap
Definition: paren.h:99
std::unordered_multimap< State, StateId, StateHash > CloseParenMap
Definition: paren.h:312
void CloseInsert(Label paren_id, StateId open_dest, StateId close_source)
Definition: paren.h:338
std::unordered_multimap< StateId, Label > ParenMultimap
Definition: paren.h:92
std::unordered_multimap< StateId, Label > OpenParenMap
Definition: paren.h:309
typename Map::const_iterator StlIterator
Definition: paren.h:60
constexpr int kNoStateId
Definition: fst.h:180
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
Definition: reverse.h:20
void Map(MutableFst< A > *fst, C *mapper)
Definition: map.h:17
#define FSTERROR()
Definition: util.h:35
typename Collection< ssize_t, StateId >::SetIterator SetIterator
Definition: paren.h:317
typename Arc::Label Label
Definition: paren.h:299
typename Map::mapped_type ValueType
Definition: paren.h:61
std::unordered_set< State, StateHash > OpenParenSet
Definition: paren.h:306
uint8_t uint8
Definition: types.h:29
typename State::Hash StateHash
Definition: paren.h:89
typename Arc::StateId StateId
Definition: paren.h:86
virtual StateId Start() const =0
typename Arc::StateId StateId
Definition: paren.h:31
void OpenInsert(Label paren_id, StateId open_dest)
Definition: paren.h:327
PdtParenReachable(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, bool close)
Definition: paren.h:109
MapIterator(const Map &map, StlIterator it)
Definition: paren.h:63
bool operator!=(const ParenState< Arc > &other) const
Definition: paren.h:44
typename Collection< ssize_t, StateId >::SetIterator SetIterator
Definition: paren.h:105
ParenArcIterator FindParenArcs(Label paren_id, StateId s) const
Definition: paren.h:156
typename Arc::Label Label
Definition: paren.h:30
ParenState(Label paren_id=kNoLabel, StateId state_id=kNoStateId)
Definition: paren.h:36
ValueType Value() const
Definition: paren.h:68
PdtBalanceData< Arc > * Reverse(StateId num_states, StateId num_split, StateId state_id_shift) const
Definition: paren.h:408
bool operator==(const ParenState< Arc > &other) const
Definition: paren.h:39
std::unordered_map< State, ssize_t, StateHash > StateSetMap
Definition: paren.h:95
bool Done() const
Definition: paren.h:66
typename Arc::Label Label
Definition: paren.h:85
typename State::Hash StateHash
Definition: paren.h:303
void FinishInsert(StateId open_dest)
Definition: paren.h:361
typename Arc::StateId StateId
Definition: paren.h:300
SetIterator Find(Label paren_id, StateId open_dest)
Definition: paren.h:348
ParenIterator FindParens(StateId s) const
Definition: paren.h:136
SetIterator FindStates(Label paren_id, StateId s) const
Definition: paren.h:143
std::unordered_map< State, ssize_t, StateHash > CloseSourceMap
Definition: paren.h:315