20 #ifndef FST_EXTENSIONS_PDT_PAREN_H_ 21 #define FST_EXTENSIONS_PDT_PAREN_H_ 33 #include <unordered_map> 34 #include <unordered_set> 44 using Label =
typename Arc::Label;
51 : paren_id(paren_id), state_id(state_id) {}
54 if (&other ==
this)
return true;
59 return !(other == *
this);
64 static constexpr
auto prime = 7853;
79 : begin_(begin), end_(end), it_(begin) {}
81 bool Done()
const {
return it_ == end_; }
97 using Label =
typename Arc::Label;
107 using StateSetMap = std::unordered_map<State, ssize_t, StateHash>;
125 const std::vector<std::pair<Label, Label>> &parens,
127 : fst_(fst), parens_(parens), close_(close), error_(false) {
128 paren_map_.reserve(2 * parens.size());
129 for (
size_t i = 0; i < parens.size(); ++i) {
130 const auto &pair = parens[i];
131 paren_map_[pair.first] = i;
132 paren_map_[pair.second] = i;
135 const auto start = fst.
Start();
137 if (!DFSearch(start)) {
138 FSTERROR() <<
"PdtReachable: Underlying cyclicity not supported";
142 FSTERROR() <<
"PdtParenReachable: Open paren info not implemented";
147 bool Error()
const {
return error_; }
152 const auto parens = paren_multimap_.find(s);
153 if (parens != paren_multimap_.end()) {
157 DCHECK(!parens->second.empty());
158 return ParenIterator(&*parens->second.begin(), &*parens->second.end());
168 const State paren_state(paren_id, s);
169 const auto it = set_map_.find(paren_state);
170 if (it == set_map_.end()) {
171 return state_sets_.FindSet(-1);
173 return state_sets_.FindSet(it->second);
181 const State paren_state(paren_id, s);
182 const auto paren_arcs = paren_arc_multimap_.find(paren_state);
183 if (paren_arcs != paren_arc_multimap_.end()) {
187 DCHECK(!paren_arcs->second.empty());
189 &*paren_arcs->second.end());
201 void ComputeStateSet(
StateId s);
204 void UpdateStateSet(
StateId nextstate, std::set<Label> *paren_set,
205 std::vector<std::set<StateId>> *state_sets)
const;
209 const std::vector<std::pair<Label, Label>> &parens_;
213 std::unordered_map<Label, Label> paren_map_;
219 std::vector<uint8_t> state_color_;
233 static constexpr uint8_t kWhiteState = 0x01;
234 static constexpr uint8_t kGreyState = 0x02;
235 static constexpr uint8_t kBlackState = 0x04;
236 if (s >= state_color_.size()) state_color_.resize(s + 1, kWhiteState);
237 if (state_color_[s] == kBlackState)
return true;
238 if (state_color_[s] == kGreyState)
return false;
239 state_color_[s] = kGreyState;
241 const auto &arc = aiter.Value();
242 const auto it = paren_map_.find(arc.ilabel);
243 if (it != paren_map_.end()) {
245 if (arc.ilabel == parens_[
paren_id].first) {
246 if (!DFSearch(arc.nextstate))
return false;
247 for (
auto set_iter = FindStates(paren_id, arc.nextstate);
248 !set_iter.Done(); set_iter.Next()) {
252 std::vector<StateId> cp_nextstates;
253 for (
auto paren_arc_iter =
254 FindParenArcs(paren_id, set_iter.Element());
255 !paren_arc_iter.Done(); paren_arc_iter.Next()) {
256 cp_nextstates.push_back(paren_arc_iter.Value().nextstate);
258 for (
const StateId cp_nextstate : cp_nextstates) {
259 if (!DFSearch(cp_nextstate))
return false;
263 }
else if (!DFSearch(arc.nextstate)) {
268 state_color_[s] = kBlackState;
275 std::set<Label> paren_set;
276 std::vector<std::set<StateId>> state_sets(parens_.size());
278 const auto &arc = aiter.Value();
279 const auto it = paren_map_.find(arc.ilabel);
280 if (it != paren_map_.end()) {
282 if (arc.ilabel == parens_[
paren_id].first) {
283 for (
auto set_iter = FindStates(paren_id, arc.nextstate);
284 !set_iter.Done(); set_iter.Next()) {
285 for (
auto paren_arc_iter =
286 FindParenArcs(paren_id, set_iter.Element());
287 !paren_arc_iter.Done(); paren_arc_iter.Next()) {
288 const auto &cparc = paren_arc_iter.Value();
289 UpdateStateSet(cparc.nextstate, &paren_set, &state_sets);
293 paren_set.insert(paren_id);
295 const State paren_state(paren_id, s);
296 paren_arc_multimap_[paren_state].push_back(arc);
299 UpdateStateSet(arc.nextstate, &paren_set, &state_sets);
302 std::vector<StateId> state_vec;
304 paren_multimap_[s].push_back(
paren_id);
306 const std::set<StateId> &state_set = state_sets[
paren_id];
307 state_vec.assign(state_set.begin(), state_set.end());
310 set_map_[paren_state] = state_sets_.FindId(state_vec);
317 StateId nextstate, std::set<Label> *paren_set,
318 std::vector<std::set<StateId>> *state_sets)
const {
319 for (
auto paren_iter = FindParens(nextstate); !paren_iter.Done();
321 const auto paren_id = paren_iter.Value();
323 for (
auto set_iter = FindStates(
paren_id, nextstate); !set_iter.Done();
325 (*state_sets)[
paren_id].insert(set_iter.Element());
349 std::unordered_map<State, std::vector<StateId>,
StateHash>;
359 open_paren_map_.clear();
360 close_paren_map_.clear();
365 const State key(paren_id, open_dest);
366 if (open_paren_set_.insert(key).second) {
367 open_paren_map_[open_dest].push_back(paren_id);
375 const State key(paren_id, open_dest);
376 if (open_paren_set_.count(key)) {
377 close_paren_map_[key].push_back(close_source);
385 const State key(paren_id, open_dest);
386 const auto it = close_source_map_.find(key);
387 if (it == close_source_map_.end()) {
388 return close_source_sets_.FindSet(-1);
390 return close_source_sets_.FindSet(it->second);
398 const auto open_parens = open_paren_map_.find(open_dest);
399 if (open_parens != open_paren_map_.end()) {
401 const State key(paren_id, open_dest);
402 open_paren_set_.erase(key);
403 const auto close_paren_it = close_paren_map_.find(key);
404 CHECK(close_paren_it != close_paren_map_.end());
405 std::vector<StateId> &close_sources = close_paren_it->second;
406 std::sort(close_sources.begin(), close_sources.end());
408 std::unique(close_sources.begin(), close_sources.end());
409 close_sources.resize(unique_end - close_sources.begin());
410 if (!close_sources.empty()) {
411 close_source_map_[key] = close_source_sets_.FindId(close_sources);
413 close_paren_map_.erase(close_paren_it);
415 open_paren_map_.erase(open_parens);
432 typename OpenParenMap::const_iterator open_iter_;
446 std::unordered_set<StateId> close_sources;
447 const auto split_size = num_states / num_split;
448 for (
StateId i = 0; i < num_states; i += split_size) {
449 close_sources.clear();
450 for (
auto it = close_source_map_.begin(); it != close_source_map_.end();
452 const auto &okey = it->first;
453 const auto open_dest = okey.state_id;
454 const auto paren_id = okey.paren_id;
455 for (
auto set_iter = close_source_sets_.FindSet(it->second);
456 !set_iter.Done(); set_iter.Next()) {
457 const auto close_source = set_iter.Element();
458 if ((close_source < i) || (close_source >= i + split_size))
continue;
459 close_sources.insert(close_source + state_id_shift);
460 bd->OpenInsert(
paren_id, close_source + state_id_shift);
461 bd->CloseInsert(
paren_id, close_source + state_id_shift,
462 open_dest + state_id_shift);
465 for (
auto it = close_sources.begin(); it != close_sources.end(); ++it) {
466 bd->FinishInsert(*it);
475 #endif // FST_EXTENSIONS_PDT_PAREN_H_
std::unordered_map< State, std::vector< StateId >, StateHash > CloseParenMap
size_t operator()(const ParenState< Arc > &pstate) const
void CloseInsert(Label paren_id, StateId open_dest, StateId close_source)
SpanIterator(ValueType *begin, ValueType *end)
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
typename Collection< ssize_t, StateId >::SetIterator SetIterator
typename Arc::Label Label
std::unordered_map< StateId, std::vector< Label >> ParenMultimap
std::unordered_set< State, StateHash > OpenParenSet
std::unordered_map< StateId, std::vector< Label >> OpenParenMap
typename State::Hash StateHash
typename Arc::StateId StateId
virtual StateId Start() const =0
typename Arc::StateId StateId
std::unordered_map< State, std::vector< Arc >, StateHash > ParenArcMultimap
void OpenInsert(Label paren_id, StateId open_dest)
PdtParenReachable(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, bool close)
bool operator!=(const ParenState< Arc > &other) const
typename Collection< ssize_t, StateId >::SetIterator SetIterator
ParenArcIterator FindParenArcs(Label paren_id, StateId s) const
typename Arc::Label Label
ParenState(Label paren_id=kNoLabel, StateId state_id=kNoStateId)
PdtBalanceData< Arc > * Reverse(StateId num_states, StateId num_split, StateId state_id_shift) const
bool operator==(const ParenState< Arc > &other) const
std::unordered_map< State, ssize_t, StateHash > StateSetMap
typename Arc::Label Label
typename State::Hash StateHash
void FinishInsert(StateId open_dest)
typename Arc::StateId StateId
SetIterator Find(Label paren_id, StateId open_dest)
ParenIterator FindParens(StateId s) const
SetIterator FindStates(Label paren_id, StateId s) const
std::unordered_map< State, ssize_t, StateHash > CloseSourceMap