20 #ifndef FST_EXTENSIONS_PDT_PAREN_H_ 21 #define FST_EXTENSIONS_PDT_PAREN_H_ 23 #include <sys/types.h> 38 #include <unordered_map> 39 #include <unordered_set> 49 using Label =
typename Arc::Label;
// Member-initializer list: copies the constructor arguments into the
// same-named members paren_id and state_id; the constructor body is empty.
56 : paren_id(paren_id), state_id(state_id) {}
// Identity fast path: when `other` is this very object the result is true
// without looking at any member.
59 if (&other ==
this)
return true;
// Inequality is expressed as the negation of the equality comparison, so the
// two operators cannot disagree.
64 return !(other == *
this);
// Compile-time constant local to the enclosing function.
// NOTE(review): presumably the multiplier used to combine the two fields in
// the hash functor -- confirm against the expression that consumes it.
69 static constexpr
auto prime = 7853;
// Member-initializer list: remembers both range bounds and positions the
// cursor it_ at `begin`; the constructor body is empty.
84 : begin_(begin), end_(end), it_(begin) {}
// Returns true when the cursor it_ equals end_, i.e. there is no further
// element to visit. Does not modify the iterator.
86 bool Done()
const {
return it_ == end_; }
// Hash map keyed by State (hashed with StateHash) whose mapped value is an
// ssize_t.
112 using StateSetMap = std::unordered_map<State, ssize_t, StateHash>;
// Constructor (partial view). Stores its arguments, builds the
// label -> pair-index map, then runs the depth-first search from the start
// state. error_ starts out false.
// Note: parens_ is initialized from `parens` without copying (see the member
// declaration), so the caller's vector must stay alive.
130 const std::vector<std::pair<Label, Label>> &parens,
132 : fst_(fst), parens_(parens), close_(close), error_(false) {
// Every pair contributes two labels, hence room for 2 * parens.size() keys.
133 paren_map_.reserve(2 * parens.size());
// Both labels of pair i map to the index i, so either label can be resolved
// to its pair's position in `parens`.
134 for (
size_t i = 0; i < parens.size(); ++i) {
135 const auto &pair = parens[i];
136 paren_map_[pair.first] = i;
137 paren_map_[pair.second] = i;
140 const auto start = fst.
Start();
// A false result from DFSearch is reported as unsupported cyclicity.
142 if (!DFSearch(start)) {
143 FSTERROR() <<
"PdtReachable: Underlying cyclicity not supported";
// NOTE(review): the condition guarding this message is not visible here;
// presumably it is taken when `close` is false -- confirm.
147 FSTERROR() <<
"PdtParenReachable: Open paren info not implemented";
// Returns the current value of the error_ flag. Does not modify the object.
152 bool Error()
const {
return error_; }
// Looks up the entry stored for state s in paren_multimap_.
157 const auto parens = paren_multimap_.find(s);
158 if (parens != paren_multimap_.end()) {
// The iterator is built from a [begin, end) pair of raw pointers into the
// mapped container; the DCHECK documents that the container is non-empty.
// NOTE(review): `&*parens->second.end()` dereferences the past-the-end
// iterator just to take its address. For a std::vector that is undefined
// behavior and trips checked-iterator builds. `second.data()` and
// `second.data() + second.size()` would give the same pointer range without
// the dereference -- confirm the mapped type and consider switching.
162 DCHECK(!parens->second.empty());
163 return ParenIterator(&*parens->second.begin(), &*parens->second.end());
// Builds the (paren_id, s) key and looks it up in set_map_.
173 const State paren_state(paren_id, s);
174 const auto it = set_map_.find(paren_state);
175 if (it == set_map_.end()) {
// No entry for this key: fall back to the set with id -1.
// NOTE(review): presumably -1 denotes the empty set in state_sets_ -- confirm
// against the collection's FindSet contract.
176 return state_sets_.FindSet(-1);
// Entry found: return the set whose id is stored in the map.
178 return state_sets_.FindSet(it->second);
// Builds the (paren_id, s) key and looks up the arcs stored under it in
// paren_arc_multimap_.
186 const State paren_state(paren_id, s);
187 const auto paren_arcs = paren_arc_multimap_.find(paren_state);
188 if (paren_arcs != paren_arc_multimap_.end()) {
// NOTE(review): as with the paren lookup, `&*paren_arcs->second.end()`
// dereferences the past-the-end iterator to form the end pointer, which is
// undefined behavior for std::vector. `data() + size()` would avoid it --
// confirm and consider switching.
192 DCHECK(!paren_arcs->second.empty());
194 &*paren_arcs->second.end());
// Declaration only: non-const member taking a single state id.
206 void ComputeStateSet(
StateId s);
// Declaration only: const member taking a state id plus two pointer
// parameters (a set of labels and a vector of per-index state sets). Because
// the method is const, any results must be delivered through those pointers.
209 void UpdateStateSet(
StateId nextstate, std::set<Label> *paren_set,
210 std::vector<std::set<StateId>> *state_sets)
const;
// Reference member, not a copy: the referenced vector of label pairs must
// outlive this object.
214 const std::vector<std::pair<Label, Label>> &parens_;
// Label-keyed hash map with Label-typed values.
218 std::unordered_map<Label, Label> paren_map_;
// One byte of state per entry; indexed by state id and grown on demand by
// the depth-first search.
224 std::vector<uint8_t> state_color_;
// Depth-first search from state s (partial view). Returns false as soon as a
// state that is still being explored is reached again; callers treat that
// as an error.
// Colors: white = default for newly allocated slots, grey = entered but not
// finished, black = finished.
238 static constexpr uint8_t kWhiteState = 0x01;
239 static constexpr uint8_t kGreyState = 0x02;
240 static constexpr uint8_t kBlackState = 0x04;
// Grow the color table so index s exists; new slots start white.
// NOTE(review): this compares s with an unsigned size and resizes to s + 1,
// so it assumes s is a valid non-negative state id -- confirm callers never
// pass a negative sentinel.
241 if (s >= state_color_.size()) state_color_.resize(s + 1, kWhiteState);
// Already finished: nothing to do.
242 if (state_color_[s] == kBlackState)
return true;
// Re-entered while still in progress: report failure.
243 if (state_color_[s] == kGreyState)
return false;
244 state_color_[s] = kGreyState;
// For each arc leaving s (loop header not visible here):
246 const auto &arc = aiter.Value();
247 const auto it = paren_map_.find(arc.ilabel);
248 if (it != paren_map_.end()) {
// The arc label belongs to a registered pair; paren_id is defined on a line
// not visible here (presumably it->second -- confirm).
// This branch handles the pair's first label.
250 if (arc.ilabel == parens_[
paren_id].first) {
251 if (!DFSearch(arc.nextstate))
return false;
// For every state in the set found for (paren_id, arc.nextstate) ...
252 for (
auto set_iter = FindStates(paren_id, arc.nextstate);
253 !set_iter.Done(); set_iter.Next()) {
// ... first copy the destinations of that state's paren arcs into a local
// vector, then recurse over the copy.
// NOTE(review): presumably the copy exists so the recursion does not run
// while paren_arc_iter is live -- confirm.
257 std::vector<StateId> cp_nextstates;
258 for (
auto paren_arc_iter =
259 FindParenArcs(paren_id, set_iter.Element());
260 !paren_arc_iter.Done(); paren_arc_iter.Next()) {
261 cp_nextstates.push_back(paren_arc_iter.Value().nextstate);
263 for (
const StateId cp_nextstate : cp_nextstates) {
264 if (!DFSearch(cp_nextstate))
return false;
// Label not registered in paren_map_: just recurse into the destination.
268 }
else if (!DFSearch(arc.nextstate)) {
// All arcs processed: mark s finished.
273 state_color_[s] = kBlackState;
// Body of the per-state set computation (partial view).
// paren_set collects labels; state_sets holds one ordered set of states per
// entry of parens_, indexed by paren_id.
280 std::set<Label> paren_set;
281 std::vector<std::set<StateId>> state_sets(parens_.size());
// For each arc leaving the state (loop header not visible here):
283 const auto &arc = aiter.Value();
284 const auto it = paren_map_.find(arc.ilabel);
285 if (it != paren_map_.end()) {
// Arc carries a registered label; paren_id is defined on a line not visible
// here. This branch handles the pair's first label: for every state in the
// set for (paren_id, arc.nextstate), each paren arc stored for that state
// has its destination folded in via UpdateStateSet.
287 if (arc.ilabel == parens_[
paren_id].first) {
288 for (
auto set_iter = FindStates(paren_id, arc.nextstate);
289 !set_iter.Done(); set_iter.Next()) {
290 for (
auto paren_arc_iter =
291 FindParenArcs(paren_id, set_iter.Element());
292 !paren_arc_iter.Done(); paren_arc_iter.Next()) {
293 const auto &cparc = paren_arc_iter.Value();
294 UpdateStateSet(cparc.nextstate, &paren_set, &state_sets);
// NOTE(review): the lines that open this part are not visible; presumably it
// is the branch for the pair's other label -- confirm. It records paren_id
// and stores the arc under the (paren_id, s) key.
298 paren_set.insert(paren_id);
300 const State paren_state(paren_id, s);
301 paren_arc_multimap_[paren_state].push_back(arc);
// Arc with an unregistered label (presumably the outer else -- confirm):
// fold in the destination state.
304 UpdateStateSet(arc.nextstate, &paren_set, &state_sets);
// Publishing step; the loop header supplying paren_id and the definition of
// paren_state used below are not visible here.
307 std::vector<StateId> state_vec;
309 paren_multimap_[s].push_back(
paren_id);
// state_set is a std::set, so state_vec receives the states in ascending
// order without duplicates.
311 const std::set<StateId> &state_set = state_sets[
paren_id];
312 state_vec.assign(state_set.begin(), state_set.end());
// Store the id that state_sets_ returns for this vector under paren_state.
315 set_map_[paren_state] = state_sets_.FindId(state_vec);
// Definition of the const update helper (partial view): for each paren found
// at `nextstate`, every state in the set for (paren_id, nextstate) is
// inserted into (*state_sets)[paren_id]. Lines between the visible ones are
// missing, so other effects (e.g. on *paren_set) cannot be seen here.
322 StateId nextstate, std::set<Label> *paren_set,
323 std::vector<std::set<StateId>> *state_sets)
const {
324 for (
auto paren_iter = FindParens(nextstate); !paren_iter.Done();
326 const auto paren_id = paren_iter.Value();
328 for (
auto set_iter = FindStates(
paren_id, nextstate); !set_iter.Done();
330 (*state_sets)[
paren_id].insert(set_iter.Element());
// Right-hand side of a type alias (its name is on a line not visible here):
// a hash map from State, hashed with StateHash, to a vector of state ids.
354 std::unordered_map<State, std::vector<StateId>,
StateHash>;
// Empties both bookkeeping maps (open parens per state, close sources per
// key). Other members are not touched by the visible lines.
364 open_paren_map_.clear();
365 close_paren_map_.clear();
// Registers the (paren_id, open_dest) key. paren_id is appended to
// open_dest's list only when the key was not already in the set, so a
// repeated insertion does not create a duplicate list entry.
370 const State key(paren_id, open_dest);
371 if (open_paren_set_.insert(key).second) {
372 open_paren_map_[open_dest].push_back(paren_id);
// Records close_source under the (paren_id, open_dest) key, but only if that
// key is currently in open_paren_set_; otherwise the call has no visible
// effect. Duplicates may be appended here; they are removed later when the
// list is sorted and deduplicated.
380 const State key(paren_id, open_dest);
381 if (open_paren_set_.count(key)) {
382 close_paren_map_[key].push_back(close_source);
// Looks up the finalized close-source set for the (paren_id, open_dest) key.
390 const State key(paren_id, open_dest);
391 const auto it = close_source_map_.find(key);
392 if (it == close_source_map_.end()) {
// No entry: return the set with id -1.
// NOTE(review): presumably -1 denotes the empty set -- confirm against the
// collection's FindSet contract.
393 return close_source_sets_.FindSet(-1);
// Entry found: return the set whose id is stored in the map.
395 return close_source_sets_.FindSet(it->second);
// Finalizes everything recorded for open_dest (partial view): each pending
// close-source list is sorted, deduplicated, stored in close_source_sets_,
// and the pending bookkeeping for it is erased.
403 const auto open_parens = open_paren_map_.find(open_dest);
404 if (open_parens != open_paren_map_.end()) {
// The loop header that supplies paren_id is not visible here.
406 const State key(paren_id, open_dest);
407 open_paren_set_.erase(key);
408 const auto close_paren_it = close_paren_map_.find(key);
// NOTE(review): this CHECK requires a close entry for every open paren
// registered at open_dest. Registering the open side alone does not create
// one, so an open paren with no matching close insertion would abort here
// -- confirm callers guarantee the pairing.
409 CHECK(close_paren_it != close_paren_map_.end());
410 std::vector<StateId> &close_sources = close_paren_it->second;
// Sort, then unique + resize, so the list is ascending and duplicate-free.
// (unique_end is assigned on a line not visible here.)
411 std::sort(close_sources.begin(), close_sources.end());
413 std::unique(close_sources.begin(), close_sources.end());
414 close_sources.resize(unique_end - close_sources.begin());
415 if (!close_sources.empty()) {
// Store the id that close_source_sets_ returns for this list under the key.
416 close_source_map_[key] = close_source_sets_.FindId(close_sources);
// The pending list is no longer needed once it has been processed.
418 close_paren_map_.erase(close_paren_it);
// All parens of open_dest handled: drop its pending entry as well.
420 open_paren_map_.erase(open_parens);
// Read-only iterator into an OpenParenMap; its uses are not visible here.
437 typename OpenParenMap::const_iterator open_iter_;
// Builds a new balance-data object with the open/close roles swapped and
// every state id offset by state_id_shift (partial view). The state range
// [0, num_states) is processed in windows of split_size states.
450 auto bd = fst::make_unique_for_overwrite<PdtBalanceData<Arc>>();
451 std::unordered_set<StateId> close_sources;
// NOTE(review): if StateId is an integer type, num_split == 0 divides by
// zero here. num_split > num_states makes split_size 0, in which case the
// outer loop's `i += split_size` never advances. Confirm that callers
// guarantee 1 <= num_split <= num_states, or add a guard.
452 const auto split_size = num_states / num_split;
453 for (
StateId i = 0; i < num_states; i += split_size) {
// Reset the per-window collection, then rescan the whole close_source_map_
// for this window.
454 close_sources.clear();
455 for (
auto it = close_source_map_.begin(); it != close_source_map_.end();
457 const auto &okey = it->first;
458 const auto open_dest = okey.state_id;
459 const auto paren_id = okey.paren_id;
460 for (
auto set_iter = close_source_sets_.FindSet(it->second);
461 !set_iter.Done(); set_iter.Next()) {
462 const auto close_source = set_iter.Element();
// Only close sources inside the current window [i, i + split_size) are
// handled in this pass.
463 if ((close_source < i) || (close_source >= i + split_size))
continue;
// Role swap: the shifted close source is inserted where an open destination
// is expected, and the shifted open destination is passed as the close
// source.
464 close_sources.insert(close_source + state_id_shift);
465 bd->OpenInsert(
paren_id, close_source + state_id_shift);
466 bd->CloseInsert(
paren_id, close_source + state_id_shift,
467 open_dest + state_id_shift);
// Finalize every shifted state collected in this window.
470 for (
auto it = close_sources.begin(); it != close_sources.end(); ++it) {
471 bd->FinishInsert(*it);
480 #endif // FST_EXTENSIONS_PDT_PAREN_H_
std::unordered_map< State, std::vector< StateId >, StateHash > CloseParenMap
size_t operator()(const ParenState< Arc > &pstate) const
void CloseInsert(Label paren_id, StateId open_dest, StateId close_source)
SpanIterator(ValueType *begin, ValueType *end)
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
typename Collection< ssize_t, StateId >::SetIterator SetIterator
typename Arc::Label Label
std::unordered_map< StateId, std::vector< Label >> ParenMultimap
std::unordered_set< State, StateHash > OpenParenSet
std::unordered_map< StateId, std::vector< Label >> OpenParenMap
typename State::Hash StateHash
typename Arc::StateId StateId
virtual StateId Start() const =0
typename Arc::StateId StateId
std::unordered_map< State, std::vector< Arc >, StateHash > ParenArcMultimap
void OpenInsert(Label paren_id, StateId open_dest)
PdtParenReachable(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, bool close)
bool operator!=(const ParenState< Arc > &other) const
typename Collection< ssize_t, StateId >::SetIterator SetIterator
ParenArcIterator FindParenArcs(Label paren_id, StateId s) const
typename Arc::Label Label
ParenState(Label paren_id=kNoLabel, StateId state_id=kNoStateId)
PdtBalanceData< Arc > * Reverse(StateId num_states, StateId num_split, StateId state_id_shift) const
bool operator==(const ParenState< Arc > &other) const
std::unordered_map< State, ssize_t, StateHash > StateSetMap
typename Arc::Label Label
typename State::Hash StateHash
void FinishInsert(StateId open_dest)
typename Arc::StateId StateId
SetIterator Find(Label paren_id, StateId open_dest)
ParenIterator FindParens(StateId s) const
SetIterator FindStates(Label paren_id, StateId s) const
std::unordered_map< State, ssize_t, StateHash > CloseSourceMap