20 #ifndef FST_MINIMIZE_H_ 21 #define FST_MINIMIZE_H_ 48 #include <unordered_map> 61 : fst_(fst), partition_(partition) {}
66 const auto xfinal = fst_.
Final(x).Hash();
67 const auto yfinal = fst_.
Final(y).Hash();
68 if (xfinal < yfinal) {
70 }
else if (xfinal > yfinal) {
78 !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) {
79 const auto &arc1 = aiter1.Value();
80 const auto &arc2 = aiter2.Value();
81 if (arc1.ilabel < arc2.ilabel)
return true;
82 if (arc1.ilabel > arc2.ilabel)
return false;
83 if (partition_.
ClassId(arc1.nextstate) <
84 partition_.
ClassId(arc2.nextstate))
86 if (partition_.
ClassId(arc1.nextstate) >
87 partition_.
ClassId(arc2.nextstate))
114 template <
class Arc,
class Queue>
140 class StateILabelHasher {
// Builds a hasher over `fst` by initializing the `fst_` member from it.
// Marked explicit so an Fst is never silently converted into a hasher.
// NOTE(review): if `fst_` is held by reference, `fst` must outlive this
// object -- the member declaration is not visible here; confirm.
142 explicit StateILabelHasher(
const Fst<Arc> &
fst) : fst_(fst) {}
144 using Label =
typename Arc::Label;
145 using StateId =
typename Arc::StateId;
148 const size_t p1 = 7603;
149 const size_t p2 = 433024223;
153 const Label this_ilabel = aiter.Value().ilabel;
154 if (this_ilabel != current_ilabel) {
155 result = p1 * result + this_ilabel;
156 current_ilabel = this_ilabel;
166 class ArcIterCompare {
170 const auto &xarc = x->Value();
171 const auto &yarc = y->Value();
172 return xarc.ilabel > yarc.ilabel;
177 std::priority_queue<RevArcIterPtr, std::vector<RevArcIterPtr>,
188 VLOG(5) <<
"PrePartition";
193 std::vector<StateId> state_to_initial_class(num_states);
199 using HashToClassMap = std::unordered_map<size_t, StateId>;
200 HashToClassMap hash_to_class_nonfinal;
201 HashToClassMap hash_to_class_final;
202 StateILabelHasher hasher(fst);
203 for (
StateId s = 0; s < num_states; ++s) {
204 size_t hash = hasher(s);
205 HashToClassMap &this_map =
206 (fst.
Final(s) != Weight::Zero() ? hash_to_class_final
207 : hash_to_class_nonfinal);
209 auto p = this_map.emplace(hash, next_class);
210 state_to_initial_class[s] = p.second ? next_class++ : p.first->second;
215 P_.AllocateClasses(next_class);
216 for (
StateId s = 0; s < num_states; ++s) {
217 P_.Add(s, state_to_initial_class[s]);
219 for (
StateId c = 0; c < next_class; ++c) L_.Enqueue(c);
220 VLOG(5) <<
"Initial Partition: " << P_.NumClasses();
233 P_.Initialize(Tr_.NumStates() - 1);
237 aiter_queue_ = std::make_unique<ArcIterQueue>();
244 const auto s = siter.Value();
245 if (Tr_.NumArcs(s + 1)) {
246 aiter_queue_->push(std::make_unique<RevArcIter>(Tr_, s + 1));
251 Label prev_label = -1;
252 while (!aiter_queue_->empty()) {
258 std::move(const_cast<RevArcIterPtr &>(aiter_queue_->top()));
260 if (aiter->Done())
continue;
261 const auto &arc = aiter->Value();
262 auto from_state = aiter->Value().nextstate - 1;
263 auto from_label = arc.ilabel;
264 if (prev_label != from_label) P_.FinalizeSplit(&L_);
265 auto from_class = P_.ClassId(from_state);
266 if (P_.ClassSize(from_class) > 1) P_.SplitOn(from_state);
267 prev_label = from_label;
269 if (!aiter->Done()) aiter_queue_->push(std::move(aiter));
271 P_.FinalizeSplit(&L_);
277 while (!L_.Empty()) {
278 const auto C = L_.Head();
293 std::unique_ptr<ArcIterQueue> aiter_queue_;
325 class HeightVisitor {
// Starts with no recorded height and no recorded states: both
// max_height_ and num_states_ are zero until the visit updates them.
327 HeightVisitor() : max_height_(0), num_states_(0) {}
335 for (
StateId i = height_.size(); i <= s; ++i) height_.push_back(-1);
336 if (s >= num_states_) num_states_ = s + 1;
// Visitor callback taking a state and an arc. It records nothing and
// always returns true.
// NOTE(review): presumably called by the DFS driver for tree arcs, with
// `true` meaning "keep visiting" -- confirm against the visitor interface.
341 bool TreeArc(
StateId s,
const Arc &arc) {
return true; }
// Visitor callback taking a state and an arc. It records nothing and
// always returns true.
// NOTE(review): presumably called by the DFS driver for back arcs, with
// `true` meaning "keep visiting" -- confirm against the visitor interface.
344 bool BackArc(
StateId s,
const Arc &arc) {
return true; }
347 bool ForwardOrCrossArc(
StateId s,
const Arc &arc) {
348 if (height_[arc.nextstate] + 1 > height_[s]) {
349 height_[s] = height_[arc.nextstate] + 1;
355 void FinishState(
StateId s,
StateId parent,
const Arc *parent_arc) {
356 if (height_[s] == -1) height_[s] = 0;
357 const auto h = height_[s] + 1;
359 if (h > height_[parent]) height_[parent] = h;
360 if (h > max_height_) max_height_ = h;
// Intentionally empty: this visitor has no work to do in this callback.
365 void FinishVisit() {}
// Returns max_height_, which FinishState raises to the largest value of
// height_[s] + 1 it has seen.
367 size_t max_height()
const {
return max_height_; }
// Read-only access to the height table, indexed by state id. The vector is
// returned by const reference, so nothing is copied.
369 const std::vector<StateId> &height()
const {
return height_; }
// Returns num_states_, which the visitor keeps at one past the largest
// state id it has observed.
371 size_t num_states()
const {
return num_states_; }
374 std::vector<StateId> height_;
383 HeightVisitor hvisitor;
386 partition_.Initialize(hvisitor.num_states());
387 partition_.AllocateClasses(hvisitor.max_height() + 1);
388 const auto &hstates = hvisitor.height();
389 for (
StateId s = 0; s < hstates.size(); ++s) partition_.Add(s, hstates[s]);
394 using EquivalenceMap = std::map<StateId, StateId, StateComparator<Arc>>;
397 auto height = partition_.NumClasses();
398 for (
StateId h = 0; h < height; ++h) {
399 EquivalenceMap equiv_classes(comp);
402 equiv_classes[siter.
Value()] = h;
405 if (insert_result.second) {
406 insert_result.first->second = partition_.AddClass();
411 const auto s = siter.
Value();
412 const auto old_class = partition_.ClassId(s);
413 const auto new_class = equiv_classes[s];
418 if (old_class != new_class) partition_.Move(s, new_class);
435 using StateId =
typename Arc::StateId;
436 std::vector<StateId> state_map(partition.
NumClasses());
439 state_map[i] = siter.
Value();
445 const auto s = siter.Value();
448 auto arc = aiter.Value();
449 arc.nextstate = state_map[partition.
ClassId(arc.nextstate)];
450 if (s == state_map[c]) {
453 fst->
AddArc(state_map[c], std::move(arc));
471 if (fst->
Properties(revuz_props,
true) == revuz_props) {
473 VLOG(2) <<
"Acyclic minimization";
482 VLOG(2) <<
"Cyclic minimization";
506 using Weight =
typename Arc::Weight;
507 static constexpr
auto minimize_props =
509 const auto props = fst->
Properties(minimize_props,
true);
517 FSTERROR() <<
"Cannot minimize a non-deterministic FST over a " 518 "non-idempotent semiring";
520 }
else if (!allow_nondet) {
522 FSTERROR() <<
"Refusing to minimize a non-deterministic FST with " 523 <<
"allow_nondet = false";
543 std::unique_ptr<SymbolTable> osyms(
550 ArcMap(gfst, fst, &mapper);
569 #endif // FST_MINIMIZE_H_ typename Arc::StateId ClassId
void ArcMap(MutableFst< A > *fst, C *mapper)
typename Arc::Weight Weight
void AcceptorMinimize(MutableFst< Arc > *fst)
typename Arc::StateId StateId
typename Arc::Label Label
typename Arc::StateId StateId
virtual uint64_t Properties(uint64_t mask, bool test) const =0
AcyclicMinimizer(const ExpandedFst< Arc > &fst)
void Encode(MutableFst< Arc > *fst, EncodeMapper< Arc > *mapper)
bool operator()(const StateId x, const StateId y) const
virtual size_t NumArcs(StateId) const =0
typename Arc::Weight Weight
virtual SymbolTable * Copy() const
constexpr uint64_t kError
const Partition< StateId > & GetPartition() const
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
void SetProperties(uint64_t props, uint64_t mask) override
virtual Weight Final(StateId) const =0
virtual void SetStart(StateId)=0
void Connect(MutableFst< Arc > *fst)
std::bool_constant<(W::Properties()&kIdempotent)!=0 > IsIdempotent
void MergeStates(const Partition< typename Arc::StateId > &partition, MutableFst< Arc > *fst)
const T NumClasses() const
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
const SymbolTable * OutputSymbols() const override=0
constexpr uint8_t kEncodeLabels
void ArcSort(MutableFst< Arc > *fst, Compare comp)
constexpr uint64_t kAcyclic
virtual void SetProperties(uint64_t props, uint64_t mask)=0
const Partition< StateId > & GetPartition()
virtual StateId Start() const =0
constexpr float kShortestDelta
constexpr uint64_t kIDeterministic
void StateMap(MutableFst< A > *fst, C *mapper)
const T ClassId(T element_id) const
StateComparator(const Fst< Arc > &fst, const Partition< StateId > &partition)
void Push(MutableFst< Arc > *fst, ReweightType type=REWEIGHT_TO_INITIAL, float delta=kShortestDelta, bool remove_total_weight=false)
constexpr uint8_t kEncodeWeights
virtual void AddArc(StateId, const Arc &)=0
constexpr uint64_t kUnweighted
void Minimize(MutableFst< Arc > *fst, MutableFst< Arc > *sfst=nullptr, float delta=kShortestDelta, bool allow_nondet=false)
CyclicMinimizer(const ExpandedFst< Arc > &fst)
typename Arc::Label Label
virtual void DeleteStates(const std::vector< StateId > &)=0
typename Arc::Weight Weight
virtual void SetOutputSymbols(const SymbolTable *osyms)=0
constexpr uint64_t kWeighted
virtual StateId NumStates() const =0
typename Arc::StateId ClassId
typename Arc::StateId StateId
void Decode(MutableFst< Arc > *fst, const EncodeMapper< Arc > &mapper)
constexpr uint64_t kAcceptor
std::unique_ptr< RevArcIter > RevArcIterPtr