20 #ifndef FST_MINIMIZE_H_ 21 #define FST_MINIMIZE_H_ 56 #include <unordered_map> 69 : fst_(fst), partition_(partition) {}
74 const auto xfinal = fst_.
Final(x).Hash();
75 const auto yfinal = fst_.
Final(y).Hash();
76 if (xfinal < yfinal) {
78 }
else if (xfinal > yfinal) {
86 !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) {
87 const auto &arc1 = aiter1.Value();
88 const auto &arc2 = aiter2.Value();
89 if (arc1.ilabel < arc2.ilabel)
return true;
90 if (arc1.ilabel > arc2.ilabel)
return false;
91 if (partition_.
ClassId(arc1.nextstate) <
92 partition_.
ClassId(arc2.nextstate))
94 if (partition_.
ClassId(arc1.nextstate) >
95 partition_.
ClassId(arc2.nextstate))
122 template <
class Arc,
class Queue>
148 class StateILabelHasher {
150 explicit StateILabelHasher(
const Fst<Arc> &
fst) : fst_(fst) {}
152 using Label =
typename Arc::Label;
153 using StateId =
typename Arc::StateId;
156 const size_t p1 = 7603;
157 const size_t p2 = 433024223;
161 const Label this_ilabel = aiter.Value().ilabel;
162 if (this_ilabel != current_ilabel) {
163 result = p1 * result + this_ilabel;
164 current_ilabel = this_ilabel;
174 class ArcIterCompare {
178 const auto &xarc = x->Value();
179 const auto &yarc = y->Value();
180 return xarc.ilabel > yarc.ilabel;
185 std::priority_queue<RevArcIterPtr, std::vector<RevArcIterPtr>,
196 VLOG(5) <<
"PrePartition";
201 std::vector<StateId> state_to_initial_class(num_states);
207 using HashToClassMap = std::unordered_map<size_t, StateId>;
208 HashToClassMap hash_to_class_nonfinal;
209 HashToClassMap hash_to_class_final;
210 StateILabelHasher hasher(fst);
211 for (
StateId s = 0; s < num_states; ++s) {
212 size_t hash = hasher(s);
213 HashToClassMap &this_map =
214 (fst.
Final(s) != Weight::Zero() ? hash_to_class_final
215 : hash_to_class_nonfinal);
217 auto p = this_map.emplace(hash, next_class);
218 state_to_initial_class[s] = p.second ? next_class++ : p.first->second;
223 P_.AllocateClasses(next_class);
224 for (
StateId s = 0; s < num_states; ++s) {
225 P_.Add(s, state_to_initial_class[s]);
227 for (
StateId c = 0; c < next_class; ++c) L_.Enqueue(c);
228 VLOG(5) <<
"Initial Partition: " << P_.NumClasses();
241 P_.Initialize(Tr_.NumStates() - 1);
245 aiter_queue_ = std::make_unique<ArcIterQueue>();
252 const auto s = siter.Value();
253 if (Tr_.NumArcs(s + 1)) {
254 aiter_queue_->push(std::make_unique<RevArcIter>(Tr_, s + 1));
259 Label prev_label = -1;
260 while (!aiter_queue_->empty()) {
266 std::move(const_cast<RevArcIterPtr &>(aiter_queue_->top()));
268 if (aiter->Done())
continue;
269 const auto &arc = aiter->Value();
270 auto from_state = aiter->Value().nextstate - 1;
271 auto from_label = arc.ilabel;
272 if (prev_label != from_label) P_.FinalizeSplit(&L_);
273 auto from_class = P_.ClassId(from_state);
274 if (P_.ClassSize(from_class) > 1) P_.SplitOn(from_state);
275 prev_label = from_label;
277 if (!aiter->Done()) aiter_queue_->push(std::move(aiter));
279 P_.FinalizeSplit(&L_);
285 while (!L_.Empty()) {
286 const auto C = L_.Head();
301 std::unique_ptr<ArcIterQueue> aiter_queue_;
333 class HeightVisitor {
335 HeightVisitor() : max_height_(0), num_states_(0) {}
343 for (
StateId i = height_.size(); i <= s; ++i) height_.push_back(-1);
344 if (s >= num_states_) num_states_ = s + 1;
349 bool TreeArc(
StateId s,
const Arc &arc) {
return true; }
352 bool BackArc(
StateId s,
const Arc &arc) {
return true; }
355 bool ForwardOrCrossArc(
StateId s,
const Arc &arc) {
356 if (height_[arc.nextstate] + 1 > height_[s]) {
357 height_[s] = height_[arc.nextstate] + 1;
363 void FinishState(
StateId s,
StateId parent,
const Arc *parent_arc) {
364 if (height_[s] == -1) height_[s] = 0;
365 const auto h = height_[s] + 1;
367 if (h > height_[parent]) height_[parent] = h;
368 if (h > max_height_) max_height_ = h;
// No post-processing is needed once the whole DFS finishes. Heights have
// already been settled state by state in FinishState().
void FinishVisit() {
}
375 size_t max_height()
const {
return max_height_; }
377 const std::vector<StateId> &height()
const {
return height_; }
379 size_t num_states()
const {
return num_states_; }
382 std::vector<StateId> height_;
391 HeightVisitor hvisitor;
394 partition_.Initialize(hvisitor.num_states());
395 partition_.AllocateClasses(hvisitor.max_height() + 1);
396 const auto &hstates = hvisitor.height();
397 for (
StateId s = 0; s < hstates.size(); ++s) partition_.Add(s, hstates[s]);
402 using EquivalenceMap = std::map<StateId, StateId, StateComparator<Arc>>;
405 auto height = partition_.NumClasses();
406 for (
StateId h = 0; h < height; ++h) {
407 EquivalenceMap equiv_classes(comp);
410 equiv_classes[siter.
Value()] = h;
413 if (insert_result.second) {
414 insert_result.first->second = partition_.AddClass();
419 const auto s = siter.
Value();
420 const auto old_class = partition_.ClassId(s);
421 const auto new_class = equiv_classes[s];
426 if (old_class != new_class) partition_.Move(s, new_class);
443 using StateId =
typename Arc::StateId;
444 std::vector<StateId> state_map(partition.
NumClasses());
447 state_map[i] = siter.
Value();
453 const auto s = siter.Value();
456 auto arc = aiter.Value();
457 arc.nextstate = state_map[partition.
ClassId(arc.nextstate)];
458 if (s == state_map[c]) {
461 fst->
AddArc(state_map[c], std::move(arc));
479 if (fst->
Properties(revuz_props,
true) == revuz_props) {
481 VLOG(2) <<
"Acyclic minimization";
490 VLOG(2) <<
"Cyclic minimization";
514 using Weight =
typename Arc::Weight;
515 static constexpr
auto minimize_props =
517 const auto props = fst->
Properties(minimize_props,
true);
525 FSTERROR() <<
"Cannot minimize a non-deterministic FST over a " 526 "non-idempotent semiring";
528 }
else if (!allow_nondet) {
530 FSTERROR() <<
"Refusing to minimize a non-deterministic FST with " 531 <<
"allow_nondet = false";
551 std::unique_ptr<SymbolTable> osyms(
558 ArcMap(gfst, fst, &mapper);
577 #endif // FST_MINIMIZE_H_
typename Arc::StateId ClassId
void ArcMap(MutableFst< A > *fst, C *mapper)
typename Arc::Weight Weight
void AcceptorMinimize(MutableFst< Arc > *fst)
typename Arc::StateId StateId
typename Arc::Label Label
typename Arc::StateId StateId
virtual uint64_t Properties(uint64_t mask, bool test) const =0
AcyclicMinimizer(const ExpandedFst< Arc > &fst)
void Encode(MutableFst< Arc > *fst, EncodeMapper< Arc > *mapper)
bool operator()(const StateId x, const StateId y) const
virtual size_t NumArcs(StateId) const =0
typename Arc::Weight Weight
virtual SymbolTable * Copy() const
constexpr uint64_t kError
const Partition< StateId > & GetPartition() const
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
virtual Weight Final(StateId) const =0
virtual void SetStart(StateId)=0
void Connect(MutableFst< Arc > *fst)
std::bool_constant<(W::Properties()&kIdempotent)!=0 > IsIdempotent
void MergeStates(const Partition< typename Arc::StateId > &partition, MutableFst< Arc > *fst)
const T NumClasses() const
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
const SymbolTable * OutputSymbols() const override=0
constexpr uint8_t kEncodeLabels
void ArcSort(MutableFst< Arc > *fst, Compare comp)
constexpr uint64_t kAcyclic
virtual void SetProperties(uint64_t props, uint64_t mask)=0
const Partition< StateId > & GetPartition()
virtual StateId Start() const =0
constexpr float kShortestDelta
constexpr uint64_t kIDeterministic
void StateMap(MutableFst< A > *fst, C *mapper)
void SetProperties(uint64_t props, uint64_t mask) override
const T ClassId(T element_id) const
StateComparator(const Fst< Arc > &fst, const Partition< StateId > &partition)
void Push(MutableFst< Arc > *fst, ReweightType type=REWEIGHT_TO_INITIAL, float delta=kShortestDelta, bool remove_total_weight=false)
constexpr uint8_t kEncodeWeights
virtual void AddArc(StateId, const Arc &)=0
constexpr uint64_t kUnweighted
void Minimize(MutableFst< Arc > *fst, MutableFst< Arc > *sfst=nullptr, float delta=kShortestDelta, bool allow_nondet=false)
CyclicMinimizer(const ExpandedFst< Arc > &fst)
typename Arc::Label Label
virtual void DeleteStates(const std::vector< StateId > &)=0
typename Arc::Weight Weight
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
constexpr uint64_t kWeighted
virtual StateId NumStates() const =0
typename Arc::StateId ClassId
typename Arc::StateId StateId
void Decode(MutableFst< Arc > *fst, const EncodeMapper< Arc > &mapper)
constexpr uint64_t kAcceptor
std::unique_ptr< RevArcIter > RevArcIterPtr