FST  openfst-1.8.3
OpenFst Library
paren.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Common classes for PDT parentheses.
19 
20 #ifndef FST_EXTENSIONS_PDT_PAREN_H_
21 #define FST_EXTENSIONS_PDT_PAREN_H_
22 
23 #include <sys/types.h>
24 
25 #include <algorithm>
26 #include <cstddef>
27 #include <cstdint>
28 #include <set>
29 #include <utility>
30 #include <vector>
31 
32 #include <fst/log.h>
34 #include <fst/extensions/pdt/pdt.h>
35 #include <fst/dfs-visit.h>
36 #include <fst/fst.h>
37 #include <fst/util.h>
38 #include <unordered_map>
39 #include <unordered_set>
40 
41 namespace fst {
42 namespace internal {
43 
44 // ParenState: Pair of an open (close) parenthesis and its destination (source)
45 // state.
46 
47 template <class Arc>
48 struct ParenState {
49  using Label = typename Arc::Label;
50  using StateId = typename Arc::StateId;
51 
52  Label paren_id; // ID of open (close) paren.
53  StateId state_id; // Destination (source) state of open (close) paren.
54 
55  explicit ParenState(Label paren_id = kNoLabel, StateId state_id = kNoStateId)
56  : paren_id(paren_id), state_id(state_id) {}
57 
58  bool operator==(const ParenState<Arc> &other) const {
59  if (&other == this) return true;
60  return other.paren_id == paren_id && other.state_id == state_id;
61  }
62 
63  bool operator!=(const ParenState<Arc> &other) const {
64  return !(other == *this);
65  }
66 
67  struct Hash {
68  size_t operator()(const ParenState<Arc> &pstate) const {
69  static constexpr auto prime = 7853;
70  return pstate.paren_id + pstate.state_id * prime;
71  }
72  };
73 };
74 
75 // Creates an FST-style const iterator from range of contiguous values
76 // in memory.
// FST-style const iterator (Done/Value/Next/Reset) over the half-open range
// [begin, end) of contiguous values. The iterator only stores pointers; it
// does not own or copy the underlying storage. A default-constructed iterator
// is immediately Done().
template <class V>
class SpanIterator {
 public:
  using ValueType = const V;

  SpanIterator() = default;
  explicit SpanIterator(ValueType *begin, ValueType *end)
      : begin_(begin), end_(end), it_(begin) {}

  // True once the cursor has reached the end of the range.
  bool Done() const { return it_ == end_; }
  // Returns a copy of the current element; must not be called when Done().
  ValueType Value() const { return *it_; }
  // Advances the cursor by one element.
  void Next() { ++it_; }
  // Moves the cursor back to the first element.
  void Reset() { it_ = begin_; }

 private:
  // The range bounds are deliberately not declared const: const data members
  // would delete copy assignment, so an iterator variable could never be
  // re-seated onto another range. Nothing in the public interface can modify
  // them after construction anyway.
  ValueType *begin_ = nullptr;
  ValueType *end_ = nullptr;
  ValueType *it_ = nullptr;
};
96 
97 // PdtParenReachable: Provides various parenthesis reachability information.
98 
99 template <class Arc>
101  public:
102  using Label = typename Arc::Label;
103  using StateId = typename Arc::StateId;
104 
106  using StateHash = typename State::Hash;
107 
108  // Maps from state ID to reachable paren IDs from (to) that state.
109  using ParenMultimap = std::unordered_map<StateId, std::vector<Label>>;
110 
111  // Maps from paren ID and state ID to reachable state set ID.
112  using StateSetMap = std::unordered_map<State, ssize_t, StateHash>;
113 
114  // Maps from paren ID and state ID to arcs exiting that state with that
115  // Label.
116  using ParenArcMultimap =
117  std::unordered_map<State, std::vector<Arc>, StateHash>;
118 
119  using ParenIterator =
121 
122  using ParenArcIterator =
124 
126 
127  // Computes close (open) parenthesis reachability information for a PDT with
128  // bounded stack.
130  const std::vector<std::pair<Label, Label>> &parens,
131  bool close)
132  : fst_(fst), parens_(parens), close_(close), error_(false) {
133  paren_map_.reserve(2 * parens.size());
134  for (size_t i = 0; i < parens.size(); ++i) {
135  const auto &pair = parens[i];
136  paren_map_[pair.first] = i;
137  paren_map_[pair.second] = i;
138  }
139  if (close_) {
140  const auto start = fst.Start();
141  if (start == kNoStateId) return;
142  if (!DFSearch(start)) {
143  FSTERROR() << "PdtReachable: Underlying cyclicity not supported";
144  error_ = true;
145  }
146  } else {
147  FSTERROR() << "PdtParenReachable: Open paren info not implemented";
148  error_ = true;
149  }
150  }
151 
152  bool Error() const { return error_; }
153 
154  // Given a state ID, returns an iterator over paren IDs for close (open)
155  // parens reachable from that state along balanced paths.
157  const auto parens = paren_multimap_.find(s);
158  if (parens != paren_multimap_.end()) {
159  // Cannot dereference iterators if the vector is empty, but that never
160  // happens. ComputeStateSet always adds something to the vector,
161  // and never leaves an empty vector.
162  DCHECK(!parens->second.empty());
163  return ParenIterator(&*parens->second.begin(), &*parens->second.end());
164  } else {
165  return ParenIterator();
166  }
167  }
168 
169  // Given a paren ID and a state ID s, returns an iterator over states that can
170  // be reached along balanced paths from (to) s that have close (open)
171  // parentheses matching the paren ID exiting (entering) those states.
173  const State paren_state(paren_id, s);
174  const auto it = set_map_.find(paren_state);
175  if (it == set_map_.end()) {
176  return state_sets_.FindSet(-1);
177  } else {
178  return state_sets_.FindSet(it->second);
179  }
180  }
181 
182  // Given a paren ID and a state ID s, return an iterator over arcs that exit
183  // (enter) s and are labeled with a close (open) parenthesis matching the
184  // paren ID.
186  const State paren_state(paren_id, s);
187  const auto paren_arcs = paren_arc_multimap_.find(paren_state);
188  if (paren_arcs != paren_arc_multimap_.end()) {
189  // Cannot dereference iterators if the vector is empty, but that never
190  // happens. ComputeStateSet always adds something to the vector,
191  // and never leaves an empty vector.
192  DCHECK(!paren_arcs->second.empty());
193  return ParenArcIterator(&*paren_arcs->second.begin(),
194  &*paren_arcs->second.end());
195  } else {
196  return ParenArcIterator();
197  }
198  }
199 
200  private:
201  // Returns false when cycle detected during DFS gathering paren and state set
202  // information.
203  bool DFSearch(StateId s);
204 
205  // Unions state sets together gathered by the DFS.
206  void ComputeStateSet(StateId s);
207 
208  // Gathers state set(s) from state.
209  void UpdateStateSet(StateId nextstate, std::set<Label> *paren_set,
210  std::vector<std::set<StateId>> *state_sets) const;
211 
212  const Fst<Arc> &fst_;
213  // Paren IDs to labels.
214  const std::vector<std::pair<Label, Label>> &parens_;
215  // Close/open paren info?
216  const bool close_;
217  // Labels to paren IDs.
218  std::unordered_map<Label, Label> paren_map_;
219  // Paren reachability.
220  ParenMultimap paren_multimap_;
221  // Paren arcs.
222  ParenArcMultimap paren_arc_multimap_;
223  // DFS states.
224  std::vector<uint8_t> state_color_;
225  // Reachable states to IDs.
226  mutable Collection<ssize_t, StateId> state_sets_;
227  // IDs to reachable states.
228  StateSetMap set_map_;
229  bool error_;
230 
231  PdtParenReachable(const PdtParenReachable &) = delete;
232  PdtParenReachable &operator=(const PdtParenReachable &) = delete;
233 };
234 
235 // Gathers paren and state set information.
236 template <class Arc>
238  static constexpr uint8_t kWhiteState = 0x01; // Undiscovered.
239  static constexpr uint8_t kGreyState = 0x02; // Discovered & unfinished.
240  static constexpr uint8_t kBlackState = 0x04; // Finished.
241  if (s >= state_color_.size()) state_color_.resize(s + 1, kWhiteState);
242  if (state_color_[s] == kBlackState) return true;
243  if (state_color_[s] == kGreyState) return false;
244  state_color_[s] = kGreyState;
245  for (ArcIterator<Fst<Arc>> aiter(fst_, s); !aiter.Done(); aiter.Next()) {
246  const auto &arc = aiter.Value();
247  const auto it = paren_map_.find(arc.ilabel);
248  if (it != paren_map_.end()) { // Paren?
249  const auto paren_id = it->second;
250  if (arc.ilabel == parens_[paren_id].first) { // Open paren?
251  if (!DFSearch(arc.nextstate)) return false;
252  for (auto set_iter = FindStates(paren_id, arc.nextstate);
253  !set_iter.Done(); set_iter.Next()) {
254  // Recursive DFSearch call may modify paren_arc_multimap_ via
255  // ComputeStateSet, so save the paren arcs to avoid issues
256  // with iterator invalidation.
257  std::vector<StateId> cp_nextstates;
258  for (auto paren_arc_iter =
259  FindParenArcs(paren_id, set_iter.Element());
260  !paren_arc_iter.Done(); paren_arc_iter.Next()) {
261  cp_nextstates.push_back(paren_arc_iter.Value().nextstate);
262  }
263  for (const StateId cp_nextstate : cp_nextstates) {
264  if (!DFSearch(cp_nextstate)) return false;
265  }
266  }
267  }
268  } else if (!DFSearch(arc.nextstate)) { // Non-paren.
269  return false;
270  }
271  }
272  ComputeStateSet(s);
273  state_color_[s] = kBlackState;
274  return true;
275 }
276 
277 // Unions state sets.
278 template <class Arc>
280  std::set<Label> paren_set;
281  std::vector<std::set<StateId>> state_sets(parens_.size());
282  for (ArcIterator<Fst<Arc>> aiter(fst_, s); !aiter.Done(); aiter.Next()) {
283  const auto &arc = aiter.Value();
284  const auto it = paren_map_.find(arc.ilabel);
285  if (it != paren_map_.end()) { // Paren?
286  const auto paren_id = it->second;
287  if (arc.ilabel == parens_[paren_id].first) { // Open paren?
288  for (auto set_iter = FindStates(paren_id, arc.nextstate);
289  !set_iter.Done(); set_iter.Next()) {
290  for (auto paren_arc_iter =
291  FindParenArcs(paren_id, set_iter.Element());
292  !paren_arc_iter.Done(); paren_arc_iter.Next()) {
293  const auto &cparc = paren_arc_iter.Value();
294  UpdateStateSet(cparc.nextstate, &paren_set, &state_sets);
295  }
296  }
297  } else { // Close paren.
298  paren_set.insert(paren_id);
299  state_sets[paren_id].insert(s);
300  const State paren_state(paren_id, s);
301  paren_arc_multimap_[paren_state].push_back(arc);
302  }
303  } else { // Non-paren.
304  UpdateStateSet(arc.nextstate, &paren_set, &state_sets);
305  }
306  }
307  std::vector<StateId> state_vec;
308  for (const Label paren_id : paren_set) {
309  paren_multimap_[s].push_back(paren_id);
310 
311  const std::set<StateId> &state_set = state_sets[paren_id];
312  state_vec.assign(state_set.begin(), state_set.end());
313 
314  const State paren_state(paren_id, s);
315  set_map_[paren_state] = state_sets_.FindId(state_vec);
316  }
317 }
318 
319 // Gathers state sets.
320 template <class Arc>
322  StateId nextstate, std::set<Label> *paren_set,
323  std::vector<std::set<StateId>> *state_sets) const {
324  for (auto paren_iter = FindParens(nextstate); !paren_iter.Done();
325  paren_iter.Next()) {
326  const auto paren_id = paren_iter.Value();
327  paren_set->insert(paren_id);
328  for (auto set_iter = FindStates(paren_id, nextstate); !set_iter.Done();
329  set_iter.Next()) {
330  (*state_sets)[paren_id].insert(set_iter.Element());
331  }
332  }
333 }
334 
335 // Stores balancing parenthesis data for a PDT. Unlike PdtParenReachable above
336 // this allows on-the-fly construction (e.g., in PdtShortestPath).
337 template <class Arc>
339  public:
340  using Label = typename Arc::Label;
341  using StateId = typename Arc::StateId;
342 
344  using StateHash = typename State::Hash;
345 
346  // Set for open parens.
347  using OpenParenSet = std::unordered_set<State, StateHash>;
348 
349  // Maps from open paren destination state to parenthesis ID.
350  using OpenParenMap = std::unordered_map<StateId, std::vector<Label>>;
351 
352  // Maps from open paren state to source states of matching close parens
353  using CloseParenMap =
354  std::unordered_map<State, std::vector<StateId>, StateHash>;
355 
356  // Maps from open paren state to close source set ID.
357  using CloseSourceMap = std::unordered_map<State, ssize_t, StateHash>;
358 
360 
361  PdtBalanceData() = default;
362 
363  void Clear() {
364  open_paren_map_.clear();
365  close_paren_map_.clear();
366  }
367 
368  // Adds an open parenthesis with destination state open_dest.
369  void OpenInsert(Label paren_id, StateId open_dest) {
370  const State key(paren_id, open_dest);
371  if (open_paren_set_.insert(key).second) {
372  open_paren_map_[open_dest].push_back(paren_id);
373  }
374  }
375 
376  // Adds a matching closing parenthesis with source state close_source
377  // balancing an open_parenthesis with destination state open_dest if
378  // OpenInsert() previously called.
379  void CloseInsert(Label paren_id, StateId open_dest, StateId close_source) {
380  const State key(paren_id, open_dest);
381  if (open_paren_set_.count(key)) {
382  close_paren_map_[key].push_back(close_source);
383  }
384  }
385 
386  // Finds close paren source states matching an open parenthesis. The following
387  // methods are then used to iterate through those matching states. Should be
388  // called only after FinishInsert(open_dest).
390  const State key(paren_id, open_dest);
391  const auto it = close_source_map_.find(key);
392  if (it == close_source_map_.end()) {
393  return close_source_sets_.FindSet(-1);
394  } else {
395  return close_source_sets_.FindSet(it->second);
396  }
397  }
398 
399  // Called when all open and close parenthesis insertions (w.r.t. open
400  // parentheses entering state open_dest) are finished. Must be called before
401  // Find(open_dest).
402  void FinishInsert(StateId open_dest) {
403  const auto open_parens = open_paren_map_.find(open_dest);
404  if (open_parens != open_paren_map_.end()) {
405  for (const Label paren_id : open_parens->second) {
406  const State key(paren_id, open_dest);
407  open_paren_set_.erase(key);
408  const auto close_paren_it = close_paren_map_.find(key);
409  CHECK(close_paren_it != close_paren_map_.end());
410  std::vector<StateId> &close_sources = close_paren_it->second;
411  std::sort(close_sources.begin(), close_sources.end());
412  auto unique_end =
413  std::unique(close_sources.begin(), close_sources.end());
414  close_sources.resize(unique_end - close_sources.begin());
415  if (!close_sources.empty()) {
416  close_source_map_[key] = close_source_sets_.FindId(close_sources);
417  }
418  close_paren_map_.erase(close_paren_it);
419  }
420  open_paren_map_.erase(open_parens);
421  }
422  }
423 
424  // Returns a new balance data object representing the reversed balance
425  // information.
426  PdtBalanceData<Arc> *Reverse(StateId num_states, StateId num_split,
427  StateId state_id_shift) const;
428 
429  private:
430  // Open paren at destination state?
431  OpenParenSet open_paren_set_;
432  // Open parens per state.
433  OpenParenMap open_paren_map_;
434  // Current open destination state.
435  State open_dest_;
436  // Current open paren/state.
437  typename OpenParenMap::const_iterator open_iter_;
438  // Close states to (open paren, state).
439  CloseParenMap close_paren_map_;
440  // (Paren, state) to set ID.
441  CloseSourceMap close_source_map_;
442  mutable Collection<ssize_t, StateId> close_source_sets_;
443 };
444 
445 // Return a new balance data object representing the reversed balance
446 // information.
447 template <class Arc>
449  StateId num_states, StateId num_split, StateId state_id_shift) const {
450  auto bd = fst::make_unique_for_overwrite<PdtBalanceData<Arc>>();
451  std::unordered_set<StateId> close_sources;
452  const auto split_size = num_states / num_split;
453  for (StateId i = 0; i < num_states; i += split_size) {
454  close_sources.clear();
455  for (auto it = close_source_map_.begin(); it != close_source_map_.end();
456  ++it) {
457  const auto &okey = it->first;
458  const auto open_dest = okey.state_id;
459  const auto paren_id = okey.paren_id;
460  for (auto set_iter = close_source_sets_.FindSet(it->second);
461  !set_iter.Done(); set_iter.Next()) {
462  const auto close_source = set_iter.Element();
463  if ((close_source < i) || (close_source >= i + split_size)) continue;
464  close_sources.insert(close_source + state_id_shift);
465  bd->OpenInsert(paren_id, close_source + state_id_shift);
466  bd->CloseInsert(paren_id, close_source + state_id_shift,
467  open_dest + state_id_shift);
468  }
469  }
470  for (auto it = close_sources.begin(); it != close_sources.end(); ++it) {
471  bd->FinishInsert(*it);
472  }
473  }
474  return bd.release();
475 }
476 
477 } // namespace internal
478 } // namespace fst
479 
480 #endif // FST_EXTENSIONS_PDT_PAREN_H_
std::unordered_map< State, std::vector< StateId >, StateHash > CloseParenMap
Definition: paren.h:354
size_t operator()(const ParenState< Arc > &pstate) const
Definition: paren.h:68
constexpr int kNoLabel
Definition: fst.h:195
bool Done() const
Definition: paren.h:86
void CloseInsert(Label paren_id, StateId open_dest, StateId close_source)
Definition: paren.h:379
SpanIterator(ValueType *begin, ValueType *end)
Definition: paren.h:83
constexpr int kNoStateId
Definition: fst.h:196
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
Definition: reverse.h:36
#define FSTERROR()
Definition: util.h:56
typename Collection< ssize_t, StateId >::SetIterator SetIterator
Definition: paren.h:359
typename Arc::Label Label
Definition: paren.h:340
std::unordered_map< StateId, std::vector< Label >> ParenMultimap
Definition: paren.h:109
std::unordered_set< State, StateHash > OpenParenSet
Definition: paren.h:347
std::unordered_map< StateId, std::vector< Label >> OpenParenMap
Definition: paren.h:350
ValueType Value() const
Definition: paren.h:87
typename State::Hash StateHash
Definition: paren.h:106
typename Arc::StateId StateId
Definition: paren.h:103
virtual StateId Start() const =0
typename Arc::StateId StateId
Definition: paren.h:50
std::unordered_map< State, std::vector< Arc >, StateHash > ParenArcMultimap
Definition: paren.h:117
void OpenInsert(Label paren_id, StateId open_dest)
Definition: paren.h:369
PdtParenReachable(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, bool close)
Definition: paren.h:129
bool operator!=(const ParenState< Arc > &other) const
Definition: paren.h:63
typename Collection< ssize_t, StateId >::SetIterator SetIterator
Definition: paren.h:125
ParenArcIterator FindParenArcs(Label paren_id, StateId s) const
Definition: paren.h:185
typename Arc::Label Label
Definition: paren.h:49
ParenState(Label paren_id=kNoLabel, StateId state_id=kNoStateId)
Definition: paren.h:55
PdtBalanceData< Arc > * Reverse(StateId num_states, StateId num_split, StateId state_id_shift) const
Definition: paren.h:448
bool operator==(const ParenState< Arc > &other) const
Definition: paren.h:58
std::unordered_map< State, ssize_t, StateHash > StateSetMap
Definition: paren.h:112
#define DCHECK(x)
Definition: log.h:74
#define CHECK(x)
Definition: log.h:65
typename Arc::Label Label
Definition: paren.h:102
typename State::Hash StateHash
Definition: paren.h:344
void FinishInsert(StateId open_dest)
Definition: paren.h:402
typename Arc::StateId StateId
Definition: paren.h:341
SetIterator Find(Label paren_id, StateId open_dest)
Definition: paren.h:389
ParenIterator FindParens(StateId s) const
Definition: paren.h:156
SetIterator FindStates(Label paren_id, StateId s) const
Definition: paren.h:172
std::unordered_map< State, ssize_t, StateHash > CloseSourceMap
Definition: paren.h:357