6 #ifndef FST_EXTENSIONS_PDT_PAREN_H_ 7 #define FST_EXTENSIONS_PDT_PAREN_H_ 11 #include <unordered_map> 12 #include <unordered_set> 30 using Label =
typename Arc::Label;
37 : paren_id(paren_id), state_id(state_id) {}
40 if (&other ==
this)
return true;
45 return !(other == *
this);
50 static constexpr
auto prime = 7853;
64 : begin_(it), end_(map.end()), it_(it) {}
66 bool Done()
const {
return it_ == end_ || it_->first != begin_->first; }
85 using Label =
typename Arc::Label;
95 using StateSetMap = std::unordered_map<State, ssize_t, StateHash>;
110 const std::vector<std::pair<Label, Label>> &parens,
112 : fst_(fst), parens_(parens), close_(close), error_(false) {
113 paren_map_.reserve(2 * parens.size());
114 for (
size_t i = 0; i < parens.size(); ++i) {
115 const auto &pair = parens[i];
116 paren_map_[pair.first] = i;
117 paren_map_[pair.second] = i;
120 const auto start = fst.
Start();
122 if (!DFSearch(start)) {
123 FSTERROR() <<
"PdtReachable: Underlying cyclicity not supported";
127 FSTERROR() <<
"PdtParenReachable: Open paren info not implemented";
132 bool Error()
const {
return error_; }
137 return ParenIterator(paren_multimap_, paren_multimap_.find(s));
144 const State paren_state(paren_id, s);
145 const auto it = set_map_.find(paren_state);
146 if (it == set_map_.end()) {
147 return state_sets_.FindSet(-1);
149 return state_sets_.FindSet(it->second);
157 const State paren_state(paren_id, s);
159 paren_arc_multimap_.find(paren_state));
168 void ComputeStateSet(
StateId s);
171 void UpdateStateSet(
StateId nextstate, std::set<Label> *paren_set,
172 std::vector<std::set<StateId>> *state_sets)
const;
176 const std::vector<std::pair<Label, Label>> &parens_;
180 std::unordered_map<Label, Label> paren_map_;
186 std::vector<uint8> state_color_;
200 static constexpr
uint8 kWhiteState = 0x01;
201 static constexpr
uint8 kGreyState = 0x02;
202 static constexpr
uint8 kBlackState = 0x04;
203 if (s >= state_color_.size()) state_color_.resize(s + 1, kWhiteState);
204 if (state_color_[s] == kBlackState)
return true;
205 if (state_color_[s] == kGreyState)
return false;
206 state_color_[s] = kGreyState;
208 const auto &arc = aiter.Value();
209 const auto it = paren_map_.find(arc.ilabel);
210 if (it != paren_map_.end()) {
212 if (arc.ilabel == parens_[
paren_id].first) {
213 if (!DFSearch(arc.nextstate))
return false;
214 for (
auto set_iter = FindStates(paren_id, arc.nextstate);
215 !set_iter.Done(); set_iter.Next()) {
216 for (
auto paren_arc_iter =
217 FindParenArcs(paren_id, set_iter.Element());
218 !paren_arc_iter.Done(); paren_arc_iter.Next()) {
219 const auto &cparc = paren_arc_iter.Value();
220 if (!DFSearch(cparc.nextstate))
return false;
224 }
else if (!DFSearch(arc.nextstate)) {
229 state_color_[s] = kBlackState;
236 std::set<Label> paren_set;
237 std::vector<std::set<StateId>> state_sets(parens_.size());
239 const auto &arc = aiter.Value();
240 const auto it = paren_map_.find(arc.ilabel);
241 if (it != paren_map_.end()) {
243 if (arc.ilabel == parens_[
paren_id].first) {
244 for (
auto set_iter = FindStates(paren_id, arc.nextstate);
245 !set_iter.Done(); set_iter.Next()) {
246 for (
auto paren_arc_iter =
247 FindParenArcs(paren_id, set_iter.Element());
248 !paren_arc_iter.Done(); paren_arc_iter.Next()) {
249 const auto &cparc = paren_arc_iter.Value();
250 UpdateStateSet(cparc.nextstate, &paren_set, &state_sets);
254 paren_set.insert(paren_id);
256 const State paren_state(paren_id, s);
257 paren_arc_multimap_.insert(std::make_pair(paren_state, arc));
260 UpdateStateSet(arc.nextstate, &paren_set, &state_sets);
263 std::vector<StateId> state_set;
264 for (
auto paren_iter = paren_set.begin(); paren_iter != paren_set.end();
268 paren_multimap_.insert(std::make_pair(s,
paren_id));
269 for (
auto state_iter = state_sets[
paren_id].begin();
270 state_iter != state_sets[
paren_id].end(); ++state_iter) {
271 state_set.push_back(*state_iter);
274 set_map_[paren_state] = state_sets_.FindId(state_set);
281 StateId nextstate, std::set<Label> *paren_set,
282 std::vector<std::set<StateId>> *state_sets)
const {
283 for (
auto paren_iter = FindParens(nextstate); !paren_iter.Done();
285 const auto paren_id = paren_iter.Value();
287 for (
auto set_iter = FindStates(
paren_id, nextstate); !set_iter.Done();
289 (*state_sets)[
paren_id].insert(set_iter.Element());
322 open_paren_map_.clear();
323 close_paren_map_.clear();
328 const State key(paren_id, open_dest);
329 if (!open_paren_set_.count(key)) {
330 open_paren_set_.insert(key);
331 open_paren_map_.emplace(open_dest, paren_id);
339 const State key(paren_id, open_dest);
340 if (open_paren_set_.count(key)) {
341 close_paren_map_.emplace(key, close_source);
349 const State key(paren_id, open_dest);
350 const auto it = close_source_map_.find(key);
351 if (it == close_source_map_.end()) {
352 return close_source_sets_.FindSet(-1);
354 return close_source_sets_.FindSet(it->second);
362 std::vector<StateId> close_sources;
363 for (
auto oit = open_paren_map_.find(open_dest);
364 oit != open_paren_map_.end() && oit->first == open_dest;) {
366 close_sources.clear();
368 open_paren_set_.erase(open_paren_set_.find(key));
369 for (
auto cit = close_paren_map_.find(key);
370 cit != close_paren_map_.end() && cit->first == key;) {
371 close_sources.push_back(cit->second);
372 close_paren_map_.erase(cit++);
374 std::sort(close_sources.begin(), close_sources.end());
375 auto unique_end = std::unique(close_sources.begin(), close_sources.end());
376 close_sources.resize(unique_end - close_sources.begin());
377 if (!close_sources.empty()) {
378 close_source_map_[key] = close_source_sets_.FindId(close_sources);
380 open_paren_map_.erase(oit++);
397 typename OpenParenMap::const_iterator open_iter_;
411 std::unordered_set<StateId> close_sources;
412 const auto split_size = num_states / num_split;
413 for (
StateId i = 0; i < num_states; i += split_size) {
414 close_sources.clear();
415 for (
auto it = close_source_map_.begin(); it != close_source_map_.end();
417 const auto &okey = it->first;
418 const auto open_dest = okey.state_id;
419 const auto paren_id = okey.paren_id;
420 for (
auto set_iter = close_source_sets_.FindSet(it->second);
421 !set_iter.Done(); set_iter.Next()) {
422 const auto close_source = set_iter.Element();
423 if ((close_source < i) || (close_source >= i + split_size))
continue;
424 close_sources.insert(close_source + state_id_shift);
425 bd->OpenInsert(
paren_id, close_source + state_id_shift);
426 bd->CloseInsert(
paren_id, close_source + state_id_shift,
427 open_dest + state_id_shift);
430 for (
auto it = close_sources.begin(); it != close_sources.end(); ++it) {
431 bd->FinishInsert(*it);
440 #endif // FST_EXTENSIONS_PDT_PAREN_H_ size_t operator()(const ParenState< Arc > &pstate) const
std::unordered_map< State, Arc, StateHash > ParenArcMultimap
std::unordered_multimap< State, StateId, StateHash > CloseParenMap
void CloseInsert(Label paren_id, StateId open_dest, StateId close_source)
std::unordered_multimap< StateId, Label > ParenMultimap
std::unordered_multimap< StateId, Label > OpenParenMap
typename Map::const_iterator StlIterator
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
void Map(MutableFst< A > *fst, C *mapper)
typename Collection< ssize_t, StateId >::SetIterator SetIterator
typename Arc::Label Label
typename Map::mapped_type ValueType
std::unordered_set< State, StateHash > OpenParenSet
typename State::Hash StateHash
typename Arc::StateId StateId
virtual StateId Start() const =0
typename Arc::StateId StateId
void OpenInsert(Label paren_id, StateId open_dest)
PdtParenReachable(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, bool close)
MapIterator(const Map &map, StlIterator it)
bool operator!=(const ParenState< Arc > &other) const
typename Collection< ssize_t, StateId >::SetIterator SetIterator
ParenArcIterator FindParenArcs(Label paren_id, StateId s) const
typename Arc::Label Label
ParenState(Label paren_id=kNoLabel, StateId state_id=kNoStateId)
PdtBalanceData< Arc > * Reverse(StateId num_states, StateId num_split, StateId state_id_shift) const
bool operator==(const ParenState< Arc > &other) const
std::unordered_map< State, ssize_t, StateHash > StateSetMap
typename Arc::Label Label
typename State::Hash StateHash
void FinishInsert(StateId open_dest)
typename Arc::StateId StateId
SetIterator Find(Label paren_id, StateId open_dest)
ParenIterator FindParens(StateId s) const
SetIterator FindStates(Label paren_id, StateId s) const
std::unordered_map< State, ssize_t, StateHash > CloseSourceMap