20 #ifndef FST_EXTENSIONS_MPDT_EXPAND_H_ 21 #define FST_EXTENSIONS_MPDT_EXPAND_H_ 47 :
CacheOptions(opts), keep_parentheses(kp), stack(s), state_table(st) {}
61 using Label =
typename Arc::Label;
84 const std::vector<std::pair<Label, Label>> &parens,
85 const std::vector<Label> &assignments,
89 stack_(opts.stack ? opts.stack : new
ParenStack(parens, assignments)),
90 state_table_(opts.state_table ? opts.state_table
92 own_stack_(!opts.stack),
93 own_state_table_(!opts.state_table),
94 keep_parentheses_(opts.keep_parentheses) {
104 fst_(impl.fst_->Copy(true)),
108 own_state_table_(true),
109 keep_parentheses_(impl.keep_parentheses_) {
117 if (own_stack_)
delete stack_;
118 if (own_state_table_)
delete state_table_;
123 const auto s = fst_->Start();
126 const auto start = state_table_->FindState(tuple);
134 const auto &tuple = state_table_->Tuple(s);
135 const auto weight = fst_->Final(tuple.state_id);
136 SetFinal(s, (weight != Weight::Zero() && tuple.stack_id == 0)
144 if (!HasArcs(s)) ExpandState(s);
149 if (!HasArcs(s)) ExpandState(s);
154 if (!HasArcs(s)) ExpandState(s);
159 if (!HasArcs(s)) ExpandState(s);
166 const auto tuple = state_table_->Tuple(s);
169 auto arc = aiter.Value();
170 const auto stack_id = stack_->Find(tuple.stack_id, arc.ilabel);
171 if (stack_id == -1) {
173 }
else if ((stack_id != tuple.stack_id) && !keep_parentheses_) {
174 arc.ilabel = arc.olabel = 0;
176 const StateTuple ntuple(arc.nextstate, stack_id);
177 arc.nextstate = state_table_->FindState(ntuple);
186 return *state_table_;
190 std::unique_ptr<const Fst<Arc>> fst_;
193 const bool own_stack_;
194 const bool own_state_table_;
195 const bool keep_parentheses_;
231 const std::vector<std::pair<Label, Label>> &parens,
232 const std::vector<Label> &assignments)
237 const std::vector<std::pair<Label, Label>> &parens,
238 const std::vector<Label> &assignments,
241 std::make_shared<
Impl>(fst, parens, assignments, opts)) {}
255 GetMutableImpl()->InitArcIterator(s, data);
261 return GetImpl()->GetStateTable();
296 data->
base = std::make_unique<StateIterator<MPdtExpandFst<Arc>>>(*this);
304 : connect(connect), keep_parentheses(keep_parentheses) {}
318 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
320 const std::vector<typename Arc::Label> &assignments,
MutableFst<Arc> *ofst,
340 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
342 const std::vector<typename Arc::Label> &assignments,
MutableFst<Arc> *ofst,
343 bool connect =
true,
bool keep_parentheses =
false) {
345 Expand(ifst, parens, assignments, ofst, opts);
350 #endif // FST_EXTENSIONS_MPDT_EXPAND_H_ ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
CacheOptions(bool gc=FST_FLAGS_fst_default_cache_gc, size_t gc_limit=FST_FLAGS_fst_default_cache_gc_limit)
virtual uint64_t Properties(uint64_t mask, bool test) const =0
typename Store::State State
void InitStateIterator(StateIteratorData< Arc > *data) const override
const ParenStack & GetStack() const
typename Arc::Weight Weight
MPdtExpandFstImpl(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, const std::vector< Label > &assignments, const MPdtExpandFstOptions< Arc > &opts)
size_t NumInputEpsilons(StateId s)
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
const SymbolTable * OutputSymbols() const
constexpr uint64_t kInitialAcyclic
typename MPdtExpandFst< Arc >::Arc Arc
void Connect(MutableFst< Arc > *fst)
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
typename Arc::Label Label
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
StateIterator(const MPdtExpandFst< Arc > &fst)
virtual uint64_t Properties() const
~MPdtExpandFstImpl() override
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
constexpr uint64_t kCopyProperties
constexpr uint64_t kAcyclic
uint64_t MPdtExpandProperties(uint64_t inprops)
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
const PdtStateTable< StateId, StackId > & GetStateTable() const
PdtStateTable< typename Arc::StateId, typename Arc::StateId > * state_table
const ParenStack & GetStack() const
std::unique_ptr< StateIteratorBase< Arc > > base
typename Arc::StateId StateId
constexpr uint64_t kFstProperties
constexpr uint64_t kUnweighted
void ExpandState(StateId s)
MPdtExpandOptions(bool connect=true, bool keep_parentheses=false)
size_t NumArcs(StateId s)
internal::MPdtStack< typename Arc::StateId, typename Arc::Label > * stack
typename MPdtExpandFst< Arc >::Arc Arc
virtual const SymbolTable * InputSymbols() const =0
const SymbolTable * InputSymbols() const
MPdtExpandFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, const std::vector< Label > &assignments)
MPdtExpandFst< Arc > * Copy(bool safe=false) const override
ArcIterator(const MPdtExpandFst< Arc > &fst, StateId s)
size_t NumOutputEpsilons(StateId s)
typename Arc::Weight Weight
typename Arc::StateId StateId
MPdtExpandFstOptions(const CacheOptions &opts=CacheOptions(), bool kp=false, internal::MPdtStack< typename Arc::StateId, typename Arc::Label > *s=nullptr, PdtStateTable< typename Arc::StateId, typename Arc::StateId > *st=nullptr)
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
typename CacheState< Arc >::Arc Arc
Impl * GetMutableImpl() const
MPdtExpandFstImpl(const MPdtExpandFstImpl &impl)
typename Arc::StateId StateId
MPdtExpandFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, const std::vector< Label > &assignments, const MPdtExpandFstOptions< Arc > &opts)
typename Arc::Label Label
const PdtStateTable< StateId, StackId > & GetStateTable() const
MPdtExpandFst(const MPdtExpandFst< Arc > &fst, bool safe=false)
const Impl * GetImpl() const
constexpr uint64_t kAcceptor
virtual const SymbolTable * OutputSymbols() const =0