20 #ifndef FST_DISAMBIGUATE_H_ 21 #define FST_DISAMBIGUATE_H_ 23 #include <sys/types.h> 61 using Label =
typename Arc::Label;
66 Weight weight = Weight::Zero(),
76 template <
class Arc,
class Relation>
77 class RelationDeterminizeFilter {
79 using Label =
typename Arc::Label;
87 using LabelMap = std::multimap<Label, DeterminizeArc<StateTuple>>;
100 r_(std::make_unique<Relation>()),
104 std::vector<StateId> *head =
nullptr)
105 : fst_(fst.Copy()), head_(head), r_(std::move(r)), s_(
kNoStateId) {}
108 template <
class Filter>
111 head_(filter->GetHeadStates()),
112 r_(std::move(*filter).GetRelation()),
118 : fst_(
fst ?
fst->Copy() : filter.fst_->Copy()),
120 r_(std::make_unique<Relation>(*filter.r_)),
130 is_final_ = fst_->Final(head) != Weight::Zero();
132 if (head_->size() <= s) head_->resize(s + 1,
kNoStateId);
140 bool FilterArc(
const Arc &arc,
const Element &src_element,
145 return is_final_ ? final_weight : Weight::Zero();
152 std::unique_ptr<Relation>
GetRelation() && {
return std::move(r_); }
158 void InitLabelMap(
LabelMap *label_map)
const;
160 std::unique_ptr<Fst<Arc>> fst_;
161 std::vector<StateId> *head_;
163 std::unique_ptr<Relation> r_;
169 template <
class Arc,
class Relation>
171 const Arc &arc,
const Element &src_element,
const Element &dest_element,
174 if (label_map->empty()) InitLabelMap(label_map);
176 for (
auto liter = label_map->lower_bound(arc.ilabel);
177 liter != label_map->end() && liter->first == arc.ilabel; ++liter) {
178 const auto &dest_tuple = liter->second.dest_tuple;
179 const auto dest_head = dest_tuple->filter_state.GetState();
180 if ((*r_)(dest_element.state_id, dest_head)) {
181 dest_tuple->subset.push_front(dest_element);
188 template <
class Arc,
class Relation>
191 const auto src_head = tuple_->filter_state.GetState();
196 const auto &arc = aiter.Value();
198 if (arc.ilabel == label && arc.nextstate == nextstate)
continue;
201 label_map->emplace(arc.ilabel, std::move(det_arc));
203 nextstate = arc.nextstate;
217 using ArcId = std::pair<StateId, ssize_t>;
227 PreDisambiguate(sfst, ofst, opts);
229 FindAmbiguities(*ofst);
232 RemoveAmbiguities(ofst);
241 bool operator()(
const Arc &arc1,
const Arc &arc2)
const {
242 return arc1.ilabel < arc2.ilabel ||
243 (arc1.ilabel == arc2.ilabel && arc1.nextstate < arc2.nextstate);
246 uint64_t Properties(uint64_t props)
const {
256 explicit ArcIdCompare(
const std::vector<StateId> &head) : head_(head) {}
258 bool operator()(
const ArcId &a1,
const ArcId &a2)
const {
260 const auto src1 = a1.first;
261 const auto src2 = a2.first;
262 const auto head1 = head_[src1];
263 const auto head2 = head_[src2];
264 if (head1 < head2)
return true;
265 if (head2 < head1)
return false;
267 if (src1 < src2)
return true;
268 if (src2 < src1)
return false;
270 return a1.second < a2.second;
274 const std::vector<StateId> &head_;
281 using StateTuple =
typename StateTable::StateTuple;
285 FSTERROR() <<
"Disambiguate::CommonFuture: FST not provided";
288 explicit CommonFuture(
const Fst<Arc> &ifst) {
297 std::vector<bool> coaccess;
301 for (
StateId s = 0; s < coaccess.size(); ++s) {
303 related_.insert(opts.
state_table->Tuple(s).StatePair());
306 if (trans)
delete fsa;
310 return related_.count(std::make_pair(s1, s2)) > 0;
317 std::set<std::pair<StateId, StateId>> related_;
320 using ArcIdMap = std::multimap<ArcId, ArcId, ArcIdCompare>;
325 candidates_->insert(head_[s1] > head_[s2] ? std::make_pair(a1, a2)
326 : std::make_pair(a2, a1));
331 if (aid.second == -1) {
335 aiter.
Seek(aid.second);
336 return aiter.
Value();
354 void MarkAmbiguities();
365 std::set<std::pair<StateId, StateId>> coreachable_;
369 std::list<std::pair<StateId, StateId>> queue_;
372 std::vector<StateId> head_;
377 std::unique_ptr<ArcIdMap> candidates_;
380 std::set<ArcId> ambiguous_;
383 std::unique_ptr<UnionFind<StateId>> merge_;
400 auto common_future = std::make_unique<CommonFuture>(ifst);
404 nopts.filter =
new Filter(ifst, std::move(common_future), &head_);
426 FSTERROR() <<
"Disambiguate: Weight must have path property to use " 427 <<
"pruning options: " << Weight::Type();
439 candidates_ = std::make_unique<ArcIdMap>(ArcIdCompare(head_));
440 const auto start_pr = std::make_pair(fst.
Start(), fst.
Start());
441 coreachable_.insert(start_pr);
442 queue_.push_back(start_pr);
443 while (!queue_.empty()) {
444 const auto &pr = queue_.front();
445 const auto s1 = pr.first;
446 const auto s2 = pr.second;
448 FindAmbiguousPairs(fst, s1, s2);
455 if (fst.
NumArcs(s2) > fst.
NumArcs(s1)) FindAmbiguousPairs(fst, s2, s1);
459 const auto &arc1 = aiter.Value();
460 const ArcId a1(s1, aiter.Position());
461 if (matcher.
Find(arc1.ilabel)) {
462 for (; !matcher.
Done(); matcher.
Next()) {
463 const auto &arc2 = matcher.
Value();
465 if (arc2.ilabel ==
kNoLabel)
continue;
468 if (s1 != s2 && arc1.nextstate == arc2.nextstate) {
469 InsertCandidate(s1, s2, a1, a2);
471 const auto spr = arc1.nextstate <= arc2.nextstate
472 ? std::make_pair(arc1.nextstate, arc2.nextstate)
473 : std::make_pair(arc2.nextstate, arc1.nextstate);
475 if (coreachable_.insert(spr).second) {
477 if (spr.first != spr.second &&
478 head_[spr.first] == head_[spr.second]) {
480 merge_ = std::make_unique<UnionFind<StateId>>(fst.
NumStates(),
484 merge_->Union(spr.first, spr.second);
486 queue_.push_back(spr);
493 if (s1 != s2 && fst.
Final(s1) != Weight::Zero() &&
494 fst.
Final(s2) != Weight::Zero()) {
495 const ArcId a1(s1, -1);
496 const ArcId a2(s2, -1);
497 InsertCandidate(s1, s2, a1, a2);
503 if (!candidates_)
return;
504 for (
auto it = candidates_->begin(); it != candidates_->end(); ++it) {
505 const auto a = it->first;
506 const auto b = it->second;
508 if (ambiguous_.count(b) == 0) ambiguous_.insert(a);
510 coreachable_.clear();
521 !aiter.Done(); aiter.Next()) {
522 auto arc = aiter.Value();
523 const auto nextstate = merge_->FindSet(arc.nextstate);
524 if (nextstate != arc.nextstate) {
525 arc.nextstate = nextstate;
531 coreachable_.clear();
534 FindAmbiguities(*ofst);
536 FSTERROR() <<
"Disambiguate: Unable to remove spurious ambiguities";
544 if (ambiguous_.empty())
return;
547 for (
auto it = ambiguous_.begin(); it != ambiguous_.end(); ++it) {
548 const auto pos = it->second;
552 auto arc = aiter.
Value();
553 arc.nextstate = dead;
556 ofst->
SetFinal(it->first, Weight::Zero());
595 #endif // FST_DISAMBIGUATE_H_
static uint64_t Properties(uint64_t props)
typename Arc::Label Label
void SetState(StateId s) final
FilterState Start() const
constexpr uint64_t kArcSortProperties
virtual uint64_t Properties(uint64_t mask, bool test) const =0
RelationDeterminizeFilter(const Fst< Arc > &fst)
typename Arc::StateId StateId
virtual size_t NumArcs(StateId) const =0
RelationDeterminizeFilter(const RelationDeterminizeFilter &filter, const Fst< Arc > *fst=nullptr)
const Arc & Value() const final
std::multimap< Label, DeterminizeArc< StateTuple >> LabelMap
typename Arc::Weight Weight
constexpr uint64_t kError
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
virtual Weight Final(StateId) const =0
typename Arc::Label Label
void Connect(MutableFst< Arc > *fst)
void SetValue(const Arc &arc)
void SetState(StateId s, const StateTuple &tuple)
RelationDeterminizeFilter(const Fst< Arc > &fst, std::unique_ptr< Relation > r, std::vector< StateId > *head=nullptr)
constexpr uint64_t kODeterministic
const Arc & Value() const
const Arc & Value() const
void Prune(MutableFst< Arc > *fst, const PruneOptions< Arc, ArcFilter > &opts=PruneOptions< Arc, ArcFilter >())
typename StateTuple::Subset Subset
RelationDeterminizeFilter(const Fst< Arc > &fst, std::unique_ptr< Filter > filter)
std::unique_ptr< Relation > GetRelation()&&
std::pair< StateId, ssize_t > ArcId
bool Find(Label match_label) final
void ArcSort(MutableFst< Arc > *fst, Compare comp)
constexpr uint64_t kOLabelSorted
constexpr uint64_t kNotAcceptor
virtual void SetProperties(uint64_t props, uint64_t mask)=0
typename Arc::Weight Weight
Label subsequential_label
std::forward_list< Element > Subset
virtual StateId Start() const =0
constexpr uint64_t kIDeterministic
typename Arc::Label Label
Weight FilterFinal(const Weight final_weight, const Element &element) const
typename StateTuple::Element Element
constexpr uint64_t kILabelSorted
typename Arc::StateId StateId
std::vector< StateId > * GetHeadStates()
typename Arc::Weight Weight
virtual StateId AddState()=0
virtual void SetFinal(StateId s, Weight weight=Weight::One())=0
std::unique_ptr< StateTuple > dest_tuple
DisambiguateOptions(float delta=kDelta, Weight weight=Weight::Zero(), StateId n=kNoStateId, Label label=0)
virtual StateId NumStates() const =0
typename Arc::StateId StateId
void Disambiguate(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, const DisambiguateOptions< Arc > &opts=DisambiguateOptions< Arc >())
void Disambiguate(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, const DisambiguateOptions< Arc > &opts=DisambiguateOptions< Arc >())
std::bool_constant<(W::Properties()&kPath)!=0 > IsPath
constexpr uint64_t kAcceptor