FST  openfst-1.8.2.post1
OpenFst Library
paren.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Common classes for PDT parentheses.
19 
20 #ifndef FST_EXTENSIONS_PDT_PAREN_H_
21 #define FST_EXTENSIONS_PDT_PAREN_H_
22 
23 #include <algorithm>
24 #include <cstdint>
25 #include <set>
26 #include <vector>
27 
28 #include <fst/log.h>
30 #include <fst/extensions/pdt/pdt.h>
31 #include <fst/dfs-visit.h>
32 #include <fst/fst.h>
33 #include <unordered_map>
34 #include <unordered_set>
35 
36 namespace fst {
37 namespace internal {
38 
39 // ParenState: Pair of an open (close) parenthesis and its destination (source)
40 // state.
41 
42 template <class Arc>
43 struct ParenState {
44  using Label = typename Arc::Label;
45  using StateId = typename Arc::StateId;
46 
47  Label paren_id; // ID of open (close) paren.
48  StateId state_id; // Destination (source) state of open (close) paren.
49 
50  explicit ParenState(Label paren_id = kNoLabel, StateId state_id = kNoStateId)
51  : paren_id(paren_id), state_id(state_id) {}
52 
53  bool operator==(const ParenState<Arc> &other) const {
54  if (&other == this) return true;
55  return other.paren_id == paren_id && other.state_id == state_id;
56  }
57 
58  bool operator!=(const ParenState<Arc> &other) const {
59  return !(other == *this);
60  }
61 
62  struct Hash {
63  size_t operator()(const ParenState<Arc> &pstate) const {
64  static constexpr auto prime = 7853;
65  return pstate.paren_id + pstate.state_id * prime;
66  }
67  };
68 };
69 
// FST-style const iterator (Done/Value/Next/Reset) over a range of contiguous
// values in memory, given as a [begin, end) pointer pair.
template <class V>
class SpanIterator {
 public:
  using ValueType = const V;

  // A default-constructed iterator covers the empty range: Done() is true.
  SpanIterator() = default;

  // Iterates over [begin, end). Only the pointers are stored, not the values.
  explicit SpanIterator(ValueType *begin, ValueType *end)
      : first_(begin), last_(end), cursor_(begin) {}

  // Rewinds to the first element of the range.
  void Reset() { cursor_ = first_; }

  // Advances to the next element.
  void Next() { ++cursor_; }

  // Returns a copy of the current element; must not be called when Done().
  ValueType Value() const { return *cursor_; }

  // True once the iterator has moved past the last element.
  bool Done() const { return !(cursor_ != last_); }

 private:
  ValueType *const first_ = nullptr;
  ValueType *const last_ = nullptr;
  ValueType *cursor_ = nullptr;
};
91 
92 // PdtParenReachable: Provides various parenthesis reachability information.
93 
94 template <class Arc>
96  public:
97  using Label = typename Arc::Label;
98  using StateId = typename Arc::StateId;
99 
101  using StateHash = typename State::Hash;
102 
103  // Maps from state ID to reachable paren IDs from (to) that state.
104  using ParenMultimap = std::unordered_map<StateId, std::vector<Label>>;
105 
106  // Maps from paren ID and state ID to reachable state set ID.
107  using StateSetMap = std::unordered_map<State, ssize_t, StateHash>;
108 
109  // Maps from paren ID and state ID to arcs exiting that state with that
110  // Label.
111  using ParenArcMultimap =
112  std::unordered_map<State, std::vector<Arc>, StateHash>;
113 
114  using ParenIterator =
116 
117  using ParenArcIterator =
119 
121 
122  // Computes close (open) parenthesis reachability information for a PDT with
123  // bounded stack.
125  const std::vector<std::pair<Label, Label>> &parens,
126  bool close)
127  : fst_(fst), parens_(parens), close_(close), error_(false) {
128  paren_map_.reserve(2 * parens.size());
129  for (size_t i = 0; i < parens.size(); ++i) {
130  const auto &pair = parens[i];
131  paren_map_[pair.first] = i;
132  paren_map_[pair.second] = i;
133  }
134  if (close_) {
135  const auto start = fst.Start();
136  if (start == kNoStateId) return;
137  if (!DFSearch(start)) {
138  FSTERROR() << "PdtReachable: Underlying cyclicity not supported";
139  error_ = true;
140  }
141  } else {
142  FSTERROR() << "PdtParenReachable: Open paren info not implemented";
143  error_ = true;
144  }
145  }
146 
147  bool Error() const { return error_; }
148 
149  // Given a state ID, returns an iterator over paren IDs for close (open)
150  // parens reachable from that state along balanced paths.
152  const auto parens = paren_multimap_.find(s);
153  if (parens != paren_multimap_.end()) {
154  // Cannot dereference iterators if the vector is empty, but that never
155  // happens. ComputeStateSet always adds something to the vector,
156  // and never leaves an empty vector.
157  DCHECK(!parens->second.empty());
158  return ParenIterator(&*parens->second.begin(), &*parens->second.end());
159  } else {
160  return ParenIterator();
161  }
162  }
163 
164  // Given a paren ID and a state ID s, returns an iterator over states that can
165  // be reached along balanced paths from (to) s that have have close (open)
166  // parentheses matching the paren ID exiting (entering) those states.
168  const State paren_state(paren_id, s);
169  const auto it = set_map_.find(paren_state);
170  if (it == set_map_.end()) {
171  return state_sets_.FindSet(-1);
172  } else {
173  return state_sets_.FindSet(it->second);
174  }
175  }
176 
177  // Given a paren ID and a state ID s, return an iterator over arcs that exit
178  // (enter) s and are labeled with a close (open) parenthesis matching the
179  // paren ID.
181  const State paren_state(paren_id, s);
182  const auto paren_arcs = paren_arc_multimap_.find(paren_state);
183  if (paren_arcs != paren_arc_multimap_.end()) {
184  // Cannot dereference iterators if the vector is empty, but that never
185  // happens. ComputeStateSet always adds something to the vector,
186  // and never leaves an empty vector.
187  DCHECK(!paren_arcs->second.empty());
188  return ParenArcIterator(&*paren_arcs->second.begin(),
189  &*paren_arcs->second.end());
190  } else {
191  return ParenArcIterator();
192  }
193  }
194 
195  private:
196  // Returns false when cycle detected during DFS gathering paren and state set
197  // information.
198  bool DFSearch(StateId s);
199 
200  // Unions state sets together gathered by the DFS.
201  void ComputeStateSet(StateId s);
202 
203  // Gathers state set(s) from state.
204  void UpdateStateSet(StateId nextstate, std::set<Label> *paren_set,
205  std::vector<std::set<StateId>> *state_sets) const;
206 
207  const Fst<Arc> &fst_;
208  // Paren IDs to labels.
209  const std::vector<std::pair<Label, Label>> &parens_;
210  // Close/open paren info?
211  const bool close_;
212  // Labels to paren IDs.
213  std::unordered_map<Label, Label> paren_map_;
214  // Paren reachability.
215  ParenMultimap paren_multimap_;
216  // Paren arcs.
217  ParenArcMultimap paren_arc_multimap_;
218  // DFS states.
219  std::vector<uint8_t> state_color_;
220  // Reachable states to IDs.
221  mutable Collection<ssize_t, StateId> state_sets_;
222  // IDs to reachable states.
223  StateSetMap set_map_;
224  bool error_;
225 
226  PdtParenReachable(const PdtParenReachable &) = delete;
227  PdtParenReachable &operator=(const PdtParenReachable &) = delete;
228 };
229 
230 // Gathers paren and state set information.
231 template <class Arc>
233  static constexpr uint8_t kWhiteState = 0x01; // Undiscovered.
234  static constexpr uint8_t kGreyState = 0x02; // Discovered & unfinished.
235  static constexpr uint8_t kBlackState = 0x04; // Finished.
236  if (s >= state_color_.size()) state_color_.resize(s + 1, kWhiteState);
237  if (state_color_[s] == kBlackState) return true;
238  if (state_color_[s] == kGreyState) return false;
239  state_color_[s] = kGreyState;
240  for (ArcIterator<Fst<Arc>> aiter(fst_, s); !aiter.Done(); aiter.Next()) {
241  const auto &arc = aiter.Value();
242  const auto it = paren_map_.find(arc.ilabel);
243  if (it != paren_map_.end()) { // Paren?
244  const auto paren_id = it->second;
245  if (arc.ilabel == parens_[paren_id].first) { // Open paren?
246  if (!DFSearch(arc.nextstate)) return false;
247  for (auto set_iter = FindStates(paren_id, arc.nextstate);
248  !set_iter.Done(); set_iter.Next()) {
249  // Recursive DFSearch call may modify paren_arc_multimap_ via
250  // ComputeStateSet, so save the paren arcs to avoid issues
251  // with iterator invalidation.
252  std::vector<StateId> cp_nextstates;
253  for (auto paren_arc_iter =
254  FindParenArcs(paren_id, set_iter.Element());
255  !paren_arc_iter.Done(); paren_arc_iter.Next()) {
256  cp_nextstates.push_back(paren_arc_iter.Value().nextstate);
257  }
258  for (const StateId cp_nextstate : cp_nextstates) {
259  if (!DFSearch(cp_nextstate)) return false;
260  }
261  }
262  }
263  } else if (!DFSearch(arc.nextstate)) { // Non-paren.
264  return false;
265  }
266  }
267  ComputeStateSet(s);
268  state_color_[s] = kBlackState;
269  return true;
270 }
271 
272 // Unions state sets.
273 template <class Arc>
275  std::set<Label> paren_set;
276  std::vector<std::set<StateId>> state_sets(parens_.size());
277  for (ArcIterator<Fst<Arc>> aiter(fst_, s); !aiter.Done(); aiter.Next()) {
278  const auto &arc = aiter.Value();
279  const auto it = paren_map_.find(arc.ilabel);
280  if (it != paren_map_.end()) { // Paren?
281  const auto paren_id = it->second;
282  if (arc.ilabel == parens_[paren_id].first) { // Open paren?
283  for (auto set_iter = FindStates(paren_id, arc.nextstate);
284  !set_iter.Done(); set_iter.Next()) {
285  for (auto paren_arc_iter =
286  FindParenArcs(paren_id, set_iter.Element());
287  !paren_arc_iter.Done(); paren_arc_iter.Next()) {
288  const auto &cparc = paren_arc_iter.Value();
289  UpdateStateSet(cparc.nextstate, &paren_set, &state_sets);
290  }
291  }
292  } else { // Close paren.
293  paren_set.insert(paren_id);
294  state_sets[paren_id].insert(s);
295  const State paren_state(paren_id, s);
296  paren_arc_multimap_[paren_state].push_back(arc);
297  }
298  } else { // Non-paren.
299  UpdateStateSet(arc.nextstate, &paren_set, &state_sets);
300  }
301  }
302  std::vector<StateId> state_vec;
303  for (const Label paren_id : paren_set) {
304  paren_multimap_[s].push_back(paren_id);
305 
306  const std::set<StateId> &state_set = state_sets[paren_id];
307  state_vec.assign(state_set.begin(), state_set.end());
308 
309  const State paren_state(paren_id, s);
310  set_map_[paren_state] = state_sets_.FindId(state_vec);
311  }
312 }
313 
314 // Gathers state sets.
315 template <class Arc>
317  StateId nextstate, std::set<Label> *paren_set,
318  std::vector<std::set<StateId>> *state_sets) const {
319  for (auto paren_iter = FindParens(nextstate); !paren_iter.Done();
320  paren_iter.Next()) {
321  const auto paren_id = paren_iter.Value();
322  paren_set->insert(paren_id);
323  for (auto set_iter = FindStates(paren_id, nextstate); !set_iter.Done();
324  set_iter.Next()) {
325  (*state_sets)[paren_id].insert(set_iter.Element());
326  }
327  }
328 }
329 
330 // Stores balancing parenthesis data for a PDT. Unlike PdtParenReachable above
331 // this allows on-the-fly construction (e.g., in PdtShortestPath).
332 template <class Arc>
334  public:
335  using Label = typename Arc::Label;
336  using StateId = typename Arc::StateId;
337 
339  using StateHash = typename State::Hash;
340 
341  // Set for open parens.
342  using OpenParenSet = std::unordered_set<State, StateHash>;
343 
344  // Maps from open paren destination state to parenthesis ID.
345  using OpenParenMap = std::unordered_map<StateId, std::vector<Label>>;
346 
347  // Maps from open paren state to source states of matching close parens
348  using CloseParenMap =
349  std::unordered_map<State, std::vector<StateId>, StateHash>;
350 
351  // Maps from open paren state to close source set ID.
352  using CloseSourceMap = std::unordered_map<State, ssize_t, StateHash>;
353 
355 
357 
358  void Clear() {
359  open_paren_map_.clear();
360  close_paren_map_.clear();
361  }
362 
363  // Adds an open parenthesis with destination state open_dest.
364  void OpenInsert(Label paren_id, StateId open_dest) {
365  const State key(paren_id, open_dest);
366  if (open_paren_set_.insert(key).second) {
367  open_paren_map_[open_dest].push_back(paren_id);
368  }
369  }
370 
371  // Adds a matching closing parenthesis with source state close_source
372  // balancing an open_parenthesis with destination state open_dest if
373  // OpenInsert() previously called.
374  void CloseInsert(Label paren_id, StateId open_dest, StateId close_source) {
375  const State key(paren_id, open_dest);
376  if (open_paren_set_.count(key)) {
377  close_paren_map_[key].push_back(close_source);
378  }
379  }
380 
381  // Finds close paren source states matching an open parenthesis. The following
382  // methods are then used to iterate through those matching states. Should be
383  // called only after FinishInsert(open_dest).
385  const State key(paren_id, open_dest);
386  const auto it = close_source_map_.find(key);
387  if (it == close_source_map_.end()) {
388  return close_source_sets_.FindSet(-1);
389  } else {
390  return close_source_sets_.FindSet(it->second);
391  }
392  }
393 
394  // Called when all open and close parenthesis insertions (w.r.t. open
395  // parentheses entering state open_dest) are finished. Must be called before
396  // Find(open_dest).
397  void FinishInsert(StateId open_dest) {
398  const auto open_parens = open_paren_map_.find(open_dest);
399  if (open_parens != open_paren_map_.end()) {
400  for (const Label paren_id : open_parens->second) {
401  const State key(paren_id, open_dest);
402  open_paren_set_.erase(key);
403  const auto close_paren_it = close_paren_map_.find(key);
404  CHECK(close_paren_it != close_paren_map_.end());
405  std::vector<StateId> &close_sources = close_paren_it->second;
406  std::sort(close_sources.begin(), close_sources.end());
407  auto unique_end =
408  std::unique(close_sources.begin(), close_sources.end());
409  close_sources.resize(unique_end - close_sources.begin());
410  if (!close_sources.empty()) {
411  close_source_map_[key] = close_source_sets_.FindId(close_sources);
412  }
413  close_paren_map_.erase(close_paren_it);
414  }
415  open_paren_map_.erase(open_parens);
416  }
417  }
418 
419  // Returns a new balance data object representing the reversed balance
420  // information.
421  PdtBalanceData<Arc> *Reverse(StateId num_states, StateId num_split,
422  StateId state_id_shift) const;
423 
424  private:
425  // Open paren at destintation state?
426  OpenParenSet open_paren_set_;
427  // Open parens per state.
428  OpenParenMap open_paren_map_;
429  // Current open destination state.
430  State open_dest_;
431  // Current open paren/state.
432  typename OpenParenMap::const_iterator open_iter_;
433  // Close states to (open paren, state).
434  CloseParenMap close_paren_map_;
435  // (Paren, state) to set ID.
436  CloseSourceMap close_source_map_;
437  mutable Collection<ssize_t, StateId> close_source_sets_;
438 };
439 
440 // Return a new balance data object representing the reversed balance
441 // information.
442 template <class Arc>
444  StateId num_states, StateId num_split, StateId state_id_shift) const {
445  auto *bd = new PdtBalanceData<Arc>;
446  std::unordered_set<StateId> close_sources;
447  const auto split_size = num_states / num_split;
448  for (StateId i = 0; i < num_states; i += split_size) {
449  close_sources.clear();
450  for (auto it = close_source_map_.begin(); it != close_source_map_.end();
451  ++it) {
452  const auto &okey = it->first;
453  const auto open_dest = okey.state_id;
454  const auto paren_id = okey.paren_id;
455  for (auto set_iter = close_source_sets_.FindSet(it->second);
456  !set_iter.Done(); set_iter.Next()) {
457  const auto close_source = set_iter.Element();
458  if ((close_source < i) || (close_source >= i + split_size)) continue;
459  close_sources.insert(close_source + state_id_shift);
460  bd->OpenInsert(paren_id, close_source + state_id_shift);
461  bd->CloseInsert(paren_id, close_source + state_id_shift,
462  open_dest + state_id_shift);
463  }
464  }
465  for (auto it = close_sources.begin(); it != close_sources.end(); ++it) {
466  bd->FinishInsert(*it);
467  }
468  }
469  return bd;
470 }
471 
472 } // namespace internal
473 } // namespace fst
474 
475 #endif // FST_EXTENSIONS_PDT_PAREN_H_
std::unordered_map< State, std::vector< StateId >, StateHash > CloseParenMap
Definition: paren.h:349
size_t operator()(const ParenState< Arc > &pstate) const
Definition: paren.h:63
constexpr int kNoLabel
Definition: fst.h:201
bool Done() const
Definition: paren.h:81
void CloseInsert(Label paren_id, StateId open_dest, StateId close_source)
Definition: paren.h:374
SpanIterator(ValueType *begin, ValueType *end)
Definition: paren.h:78
constexpr int kNoStateId
Definition: fst.h:202
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
Definition: reverse.h:34
#define FSTERROR()
Definition: util.h:53
typename Collection< ssize_t, StateId >::SetIterator SetIterator
Definition: paren.h:354
typename Arc::Label Label
Definition: paren.h:335
std::unordered_map< StateId, std::vector< Label >> ParenMultimap
Definition: paren.h:104
std::unordered_set< State, StateHash > OpenParenSet
Definition: paren.h:342
std::unordered_map< StateId, std::vector< Label >> OpenParenMap
Definition: paren.h:345
ValueType Value() const
Definition: paren.h:82
typename State::Hash StateHash
Definition: paren.h:101
typename Arc::StateId StateId
Definition: paren.h:98
virtual StateId Start() const =0
typename Arc::StateId StateId
Definition: paren.h:45
std::unordered_map< State, std::vector< Arc >, StateHash > ParenArcMultimap
Definition: paren.h:112
void OpenInsert(Label paren_id, StateId open_dest)
Definition: paren.h:364
PdtParenReachable(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, bool close)
Definition: paren.h:124
bool operator!=(const ParenState< Arc > &other) const
Definition: paren.h:58
typename Collection< ssize_t, StateId >::SetIterator SetIterator
Definition: paren.h:120
ParenArcIterator FindParenArcs(Label paren_id, StateId s) const
Definition: paren.h:180
typename Arc::Label Label
Definition: paren.h:44
ParenState(Label paren_id=kNoLabel, StateId state_id=kNoStateId)
Definition: paren.h:50
PdtBalanceData< Arc > * Reverse(StateId num_states, StateId num_split, StateId state_id_shift) const
Definition: paren.h:443
bool operator==(const ParenState< Arc > &other) const
Definition: paren.h:53
std::unordered_map< State, ssize_t, StateHash > StateSetMap
Definition: paren.h:107
#define DCHECK(x)
Definition: log.h:70
#define CHECK(x)
Definition: log.h:61
typename Arc::Label Label
Definition: paren.h:97
typename State::Hash StateHash
Definition: paren.h:339
void FinishInsert(StateId open_dest)
Definition: paren.h:397
typename Arc::StateId StateId
Definition: paren.h:336
SetIterator Find(Label paren_id, StateId open_dest)
Definition: paren.h:384
ParenIterator FindParens(StateId s) const
Definition: paren.h:151
SetIterator FindStates(Label paren_id, StateId s) const
Definition: paren.h:167
std::unordered_map< State, ssize_t, StateHash > CloseSourceMap
Definition: paren.h:352