20 #ifndef FST_DISAMBIGUATE_H_ 21 #define FST_DISAMBIGUATE_H_ 48 using Label =
typename Arc::Label;
53 Weight weight = Weight::Zero(),
63 template <
class Arc,
class Relation>
66 using Label =
typename Arc::Label;
74 using LabelMap = std::multimap<Label, DeterminizeArc<StateTuple>>;
87 r_(std::make_unique<Relation>()),
91 std::vector<StateId> *head =
nullptr)
92 : fst_(fst.Copy()), head_(head), r_(std::move(r)), s_(
kNoStateId) {}
95 template <
class Filter>
98 head_(filter->GetHeadStates()),
99 r_(std::move(*filter).GetRelation()),
105 : fst_(
fst ?
fst->Copy() : filter.fst_->Copy()),
107 r_(std::make_unique<Relation>(*filter.r_)),
117 is_final_ = fst_->Final(head) != Weight::Zero();
119 if (head_->size() <= s) head_->resize(s + 1,
kNoStateId);
127 bool FilterArc(
const Arc &arc,
const Element &src_element,
132 return is_final_ ? final_weight : Weight::Zero();
139 std::unique_ptr<Relation>
GetRelation() && {
return std::move(r_); }
145 void InitLabelMap(
LabelMap *label_map)
const;
147 std::unique_ptr<Fst<Arc>> fst_;
148 std::vector<StateId> *head_;
150 std::unique_ptr<Relation> r_;
156 template <
class Arc,
class Relation>
158 const Arc &arc,
const Element &src_element,
const Element &dest_element,
161 if (label_map->empty()) InitLabelMap(label_map);
163 for (
auto liter = label_map->lower_bound(arc.ilabel);
164 liter != label_map->end() && liter->first == arc.ilabel; ++liter) {
165 const auto &dest_tuple = liter->second.dest_tuple;
166 const auto dest_head = dest_tuple->filter_state.GetState();
167 if ((*r_)(dest_element.state_id, dest_head)) {
168 dest_tuple->subset.push_front(dest_element);
175 template <
class Arc,
class Relation>
178 const auto src_head = tuple_->filter_state.GetState();
183 const auto &arc = aiter.Value();
185 if (arc.ilabel == label && arc.nextstate == nextstate)
continue;
188 label_map->emplace(arc.ilabel, std::move(det_arc));
190 nextstate = arc.nextstate;
204 using ArcId = std::pair<StateId, ssize_t>;
214 PreDisambiguate(sfst, ofst, opts);
216 FindAmbiguities(*ofst);
219 RemoveAmbiguities(ofst);
228 bool operator()(
const Arc &arc1,
const Arc &arc2)
const {
229 return arc1.ilabel < arc2.ilabel ||
230 (arc1.ilabel == arc2.ilabel && arc1.nextstate < arc2.nextstate);
233 uint64_t Properties(uint64_t props)
const {
243 explicit ArcIdCompare(
const std::vector<StateId> &head) : head_(head) {}
245 bool operator()(
const ArcId &a1,
const ArcId &a2)
const {
247 const auto src1 = a1.first;
248 const auto src2 = a2.first;
249 const auto head1 = head_[src1];
250 const auto head2 = head_[src2];
251 if (head1 < head2)
return true;
252 if (head2 < head1)
return false;
254 if (src1 < src2)
return true;
255 if (src2 < src1)
return false;
257 return a1.second < a2.second;
261 const std::vector<StateId> &head_;
268 using StateTuple =
typename StateTable::StateTuple;
272 FSTERROR() <<
"Disambiguate::CommonFuture: FST not provided";
275 explicit CommonFuture(
const Fst<Arc> &ifst) {
284 std::vector<bool> coaccess;
288 for (
StateId s = 0; s < coaccess.size(); ++s) {
290 related_.insert(opts.
state_table->Tuple(s).StatePair());
293 if (trans)
delete fsa;
297 return related_.count(std::make_pair(s1, s2)) > 0;
304 std::set<std::pair<StateId, StateId>> related_;
307 using ArcIdMap = std::multimap<ArcId, ArcId, ArcIdCompare>;
312 candidates_->insert(head_[s1] > head_[s2] ? std::make_pair(a1, a2)
313 : std::make_pair(a2, a1));
318 if (aid.second == -1) {
322 aiter.
Seek(aid.second);
323 return aiter.
Value();
341 void MarkAmbiguities();
352 std::set<std::pair<StateId, StateId>> coreachable_;
356 std::list<std::pair<StateId, StateId>> queue_;
359 std::vector<StateId> head_;
364 std::unique_ptr<ArcIdMap> candidates_;
367 std::set<ArcId> ambiguous_;
370 std::unique_ptr<UnionFind<StateId>> merge_;
387 auto common_future = std::make_unique<CommonFuture>(ifst);
391 nopts.filter =
new Filter(ifst, std::move(common_future), &head_);
413 FSTERROR() <<
"Disambiguate: Weight must have path property to use " 414 <<
"pruning options: " << Weight::Type();
426 candidates_ = std::make_unique<ArcIdMap>(ArcIdCompare(head_));
427 const auto start_pr = std::make_pair(fst.
Start(), fst.
Start());
428 coreachable_.insert(start_pr);
429 queue_.push_back(start_pr);
430 while (!queue_.empty()) {
431 const auto &pr = queue_.front();
432 const auto s1 = pr.first;
433 const auto s2 = pr.second;
435 FindAmbiguousPairs(fst, s1, s2);
442 if (fst.
NumArcs(s2) > fst.
NumArcs(s1)) FindAmbiguousPairs(fst, s2, s1);
446 const auto &arc1 = aiter.Value();
447 const ArcId a1(s1, aiter.Position());
448 if (matcher.
Find(arc1.ilabel)) {
449 for (; !matcher.
Done(); matcher.
Next()) {
450 const auto &arc2 = matcher.
Value();
452 if (arc2.ilabel ==
kNoLabel)
continue;
455 if (s1 != s2 && arc1.nextstate == arc2.nextstate) {
456 InsertCandidate(s1, s2, a1, a2);
458 const auto spr = arc1.nextstate <= arc2.nextstate
459 ? std::make_pair(arc1.nextstate, arc2.nextstate)
460 : std::make_pair(arc2.nextstate, arc1.nextstate);
462 if (coreachable_.insert(spr).second) {
464 if (spr.first != spr.second &&
465 head_[spr.first] == head_[spr.second]) {
467 merge_ = std::make_unique<UnionFind<StateId>>(fst.
NumStates(),
471 merge_->Union(spr.first, spr.second);
473 queue_.push_back(spr);
480 if (s1 != s2 && fst.
Final(s1) != Weight::Zero() &&
481 fst.
Final(s2) != Weight::Zero()) {
482 const ArcId a1(s1, -1);
483 const ArcId a2(s2, -1);
484 InsertCandidate(s1, s2, a1, a2);
490 if (!candidates_)
return;
491 for (
auto it = candidates_->begin(); it != candidates_->end(); ++it) {
492 const auto a = it->first;
493 const auto b = it->second;
495 if (ambiguous_.count(b) == 0) ambiguous_.insert(a);
497 coreachable_.clear();
508 !aiter.Done(); aiter.Next()) {
509 auto arc = aiter.Value();
510 const auto nextstate = merge_->FindSet(arc.nextstate);
511 if (nextstate != arc.nextstate) {
512 arc.nextstate = nextstate;
518 coreachable_.clear();
521 FindAmbiguities(*ofst);
523 FSTERROR() <<
"Disambiguate: Unable to remove spurious ambiguities";
531 if (ambiguous_.empty())
return;
534 for (
auto it = ambiguous_.begin(); it != ambiguous_.end(); ++it) {
535 const auto pos = it->second;
539 auto arc = aiter.
Value();
540 arc.nextstate = dead;
543 ofst->
SetFinal(it->first, Weight::Zero());
582 #endif // FST_DISAMBIGUATE_H_ static uint64_t Properties(uint64_t props)
typename Arc::Label Label
void SetState(StateId s) final
FilterState Start() const
constexpr uint64_t kArcSortProperties
virtual uint64_t Properties(uint64_t mask, bool test) const =0
RelationDeterminizeFilter(const Fst< Arc > &fst)
typename Arc::StateId StateId
virtual size_t NumArcs(StateId) const =0
RelationDeterminizeFilter(const RelationDeterminizeFilter &filter, const Fst< Arc > *fst=nullptr)
const Arc & Value() const final
std::multimap< Label, DeterminizeArc< StateTuple >> LabelMap
typename Arc::Weight Weight
constexpr uint64_t kError
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
virtual Weight Final(StateId) const =0
typename Arc::Label Label
void Connect(MutableFst< Arc > *fst)
void SetValue(const Arc &arc)
void SetState(StateId s, const StateTuple &tuple)
RelationDeterminizeFilter(const Fst< Arc > &fst, std::unique_ptr< Relation > r, std::vector< StateId > *head=nullptr)
constexpr uint64_t kODeterministic
const Arc & Value() const
const Arc & Value() const
void Prune(MutableFst< Arc > *fst, const PruneOptions< Arc, ArcFilter > &opts=PruneOptions< Arc, ArcFilter >())
typename StateTuple::Subset Subset
RelationDeterminizeFilter(const Fst< Arc > &fst, std::unique_ptr< Filter > filter)
std::unique_ptr< Relation > GetRelation()&&
std::pair< StateId, ssize_t > ArcId
bool Find(Label match_label) final
void ArcSort(MutableFst< Arc > *fst, Compare comp)
constexpr uint64_t kOLabelSorted
constexpr uint64_t kNotAcceptor
virtual void SetProperties(uint64_t props, uint64_t mask)=0
typename Arc::Weight Weight
Label subsequential_label
std::forward_list< Element > Subset
virtual StateId Start() const =0
constexpr uint64_t kIDeterministic
typename Arc::Label Label
Weight FilterFinal(const Weight final_weight, const Element &element) const
typename StateTuple::Element Element
constexpr uint64_t kILabelSorted
typename Arc::StateId StateId
std::vector< StateId > * GetHeadStates()
typename Arc::Weight Weight
virtual StateId AddState()=0
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
std::unique_ptr< StateTuple > dest_tuple
DisambiguateOptions(float delta=kDelta, Weight weight=Weight::Zero(), StateId n=kNoStateId, Label label=0)
virtual StateId NumStates() const =0
typename Arc::StateId StateId
void Disambiguate(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, const DisambiguateOptions< Arc > &opts=DisambiguateOptions< Arc >())
void Disambiguate(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, const DisambiguateOptions< Arc > &opts=DisambiguateOptions< Arc >())
std::bool_constant<(W::Properties()&kPath)!=0 > IsPath
constexpr uint64_t kAcceptor