20 #ifndef FST_COMPOSE_H_ 21 #define FST_COMPOSE_H_ 23 #include <sys/types.h> 60 template <
class Arc,
class M = Matcher<Fst<Arc>>,
61 class Filter = SequenceComposeFilter<M>,
63 GenericComposeStateTable<Arc,
typename Filter::FilterState>>
71 M *matcher1 =
nullptr, M *matcher2 =
nullptr,
72 Filter *filter =
nullptr,
73 StateTable *state_table =
nullptr)
78 state_table(state_table) {}
82 template <
class C,
class F,
class T>
96 template <
class M1,
class M2,
class Filter = SequenceComposeFilter<M1, M2>,
97 class StateTable = GenericComposeStateTable<
98 typename M1::Arc,
typename Filter::FilterState>,
99 class CacheStore = DefaultCacheStore<
typename M1::Arc>>
110 M1 *matcher1 =
nullptr, M2 *matcher2 =
nullptr,
111 Filter *filter =
nullptr,
112 StateTable *state_table =
nullptr)
117 state_table(state_table),
118 own_state_table(true),
119 allow_noncommute(false) {}
122 M1 *matcher1 =
nullptr, M2 *matcher2 =
nullptr,
123 Filter *filter =
nullptr,
124 StateTable *state_table =
nullptr)
129 state_table(state_table),
130 own_state_table(true),
131 allow_noncommute(false) {}
137 state_table(nullptr),
138 own_state_table(true),
139 allow_noncommute(false) {}
146 template <
class Arc,
class CacheStore = DefaultCacheStore<Arc>,
147 class F = ComposeFst<Arc, CacheStore>>
149 :
public CacheBaseImpl<typename CacheStore::State, CacheStore> {
156 using State =
typename CacheStore::State;
165 using CacheImpl::HasArcs;
166 using CacheImpl::HasFinal;
167 using CacheImpl::HasStart;
168 using CacheImpl::SetFinal;
169 using CacheImpl::SetStart;
189 const auto start = ComputeStart();
192 return CacheImpl::Start();
196 if (!HasFinal(s)) SetFinal(s, ComputeFinal(s));
203 if (!HasArcs(s))
Expand(s);
208 if (!HasArcs(s))
Expand(s);
213 if (!HasArcs(s))
Expand(s);
218 if (!HasArcs(s))
Expand(s);
219 CacheImpl::InitArcIterator(s, data);
229 virtual StateId ComputeStart() = 0;
236 template <
class CacheStore,
class Filter,
class StateTable>
246 using Arc =
typename CacheStore::Arc;
252 using State =
typename CacheStore::State;
265 template <
class M1,
class M2>
272 filter_(new Filter(*impl.filter_, true)),
273 matcher1_(filter_->GetMatcher1()),
274 matcher2_(filter_->GetMatcher2()),
275 fst1_(matcher1_->GetFst()),
276 fst2_(matcher2_->GetFst()),
277 state_table_(new StateTable(*impl.state_table_)),
278 own_state_table_(true),
279 match_type_(impl.match_type_) {}
282 if (own_state_table_)
delete state_table_;
292 (fst1_.Properties(kError,
false) || fst2_.Properties(kError,
false) ||
293 (matcher1_->Properties(0) &
kError) ||
294 (matcher2_->Properties(0) &
kError) |
295 (filter_->Properties(0) &
kError) ||
296 state_table_->Error())) {
297 SetProperties(kError, kError);
305 const auto &tuple = state_table_->Tuple(s);
306 const auto s1 = tuple.StateId1();
307 const auto s2 = tuple.StateId2();
308 filter_->SetState(s1, s2, tuple.GetFilterState());
309 if (MatchInput(s1, s2)) {
310 OrderedExpand(s, fst2_, s2, fst1_, s1, matcher2_,
true);
312 OrderedExpand(s, fst1_, s1, fst2_, s2, matcher1_,
false);
328 const Filter *
GetFilter()
const {
return filter_.get(); }
345 if ((matcher1_->Type(
false) == match_type) &&
346 (matcher2_->Type(
false) == match_type) &&
347 (filter_->Properties(test_props) == test_props)) {
360 template <
class FST,
class Matcher>
365 const Arc loop(match_input ? 0 :
kNoLabel, match_input ? kNoLabel : 0,
367 MatchArc(s, matchera, loop, match_input);
370 MatchArc(s, matchera, iterb.
Value(), match_input);
372 CacheImpl::SetArcs(s);
376 template <
class Matcher>
379 if (matchera->
Find(match_input ? arc.olabel : arc.ilabel)) {
380 for (; !matchera->
Done(); matchera->
Next()) {
381 auto arca = matchera->
Value();
384 const auto &fs = filter_->FilterArc(&arcb, &arca);
385 if (fs != FilterState::NoState()) AddArc(s, arcb, arca, fs);
387 const auto &fs = filter_->FilterArc(&arca, &arcb);
388 if (fs != FilterState::NoState()) AddArc(s, arca, arcb, fs);
397 const StateTuple tuple(arc1.nextstate, arc2.nextstate, f);
398 CacheImpl::EmplaceArc(s, arc1.ilabel, arc2.olabel,
399 Times(arc1.weight, arc2.weight),
400 state_table_->FindState(tuple));
403 StateId ComputeStart()
override {
404 const auto s1 = fst1_.Start();
406 const auto s2 = fst2_.Start();
408 const auto &fs = filter_->Start();
410 return state_table_->FindState(tuple);
414 const auto &tuple = state_table_->Tuple(s);
415 const auto s1 = tuple.StateId1();
416 auto final1 = matcher1_->Final(s1);
417 if (final1 == Weight::Zero())
return final1;
418 const auto s2 = tuple.StateId2();
419 auto final2 = matcher2_->Final(s2);
420 if (final2 == Weight::Zero())
return final2;
421 filter_->SetState(s1, s2, tuple.GetFilterState());
422 filter_->FilterFinal(&final1, &final2);
423 return Times(final1, final2);
428 switch (match_type_) {
434 const auto priority1 = matcher1_->Priority(s1);
435 const auto priority2 = matcher2_->Priority(s2);
437 FSTERROR() <<
"ComposeFst: Both sides can't require match";
445 return priority1 <= priority2;
453 std::unique_ptr<Filter> filter_;
458 StateTable *state_table_;
459 bool own_state_table_;
464 template <
class CacheStore,
class Filter,
class StateTable>
465 template <
class M1,
class M2>
472 : new Filter(fst1, fst2, opts.matcher1, opts.matcher2)),
473 matcher1_(filter_->GetMatcher1()),
474 matcher2_(filter_->GetMatcher2()),
475 fst1_(matcher1_->GetFst()),
476 fst2_(matcher2_->GetFst()),
477 state_table_(opts.state_table ? opts.state_table
478 : new StateTable(fst1_, fst2_)),
479 own_state_table_(opts.state_table ? opts.own_state_table : true) {
481 if (!
CompatSymbols(fst2.InputSymbols(), fst1.OutputSymbols())) {
482 FSTERROR() <<
"ComposeFst: Output symbol table of 1st argument " 483 <<
"does not match input symbol table of 2nd argument";
489 VLOG(2) <<
"ComposeFstImpl: Match type: " << match_type_;
493 const auto mprops1 = matcher1_->Properties(fprops1);
494 const auto mprops2 = matcher2_->Properties(fprops2);
500 template <
class CacheStore,
class Filter,
class StateTable>
505 FSTERROR() <<
"ComposeFst: 1st argument cannot perform required matching " 512 FSTERROR() <<
"ComposeFst: 2nd argument cannot perform required matching " 518 const auto type1 = matcher1_->Type(
false);
519 const auto type2 = matcher2_->Type(
false);
531 FSTERROR() <<
"ComposeFst: 1st argument cannot match on output labels " 532 <<
"and 2nd argument cannot match on input labels (sort?).";
574 template <
class A,
class CacheStore >
576 :
public ImplToFst<internal::ComposeFstImplBase<A, CacheStore>> {
583 using State =
typename CacheStore::State;
589 template <
class,
class,
class>
600 template <
class Matcher,
class Filter,
class StateTuple>
610 template <
class Matcher1,
class Matcher2,
class Filter,
class StateTuple>
620 : fst.GetSharedImpl()) {}
630 GetMutableImpl()->InitArcIterator(s, data);
634 return GetImpl()->InitMatcher(*
this, match_type);
644 template <
class Matcher1,
class Matcher2,
class Filter,
class StateTuple>
649 auto impl = std::make_shared<
652 if (!(Weight::Properties() &
kCommutative) && !opts.allow_noncommute) {
654 const auto props2 = fst2.Properties(
kUnweighted,
true);
655 if (!(props1 &
kUnweighted) && !(props2 & kUnweighted)) {
656 FSTERROR() <<
"ComposeFst: Weights must be a commutative semiring: " 666 template <
class Matcher,
class Filter,
class StateTuple>
673 return CreateBase2(fst1, fst2, nopts);
684 return CreateBase1(fst1, fst2, nopts);
690 return CreateBase1(fst1, fst2, nopts);
696 return CreateBase1(fst1, fst2, nopts);
706 template <
class Arc,
class CacheStore>
712 fst.GetMutableImpl()) {}
716 template <
class Arc,
class CacheStore>
728 template <
class Arc,
class CacheStore>
732 std::make_unique<StateIterator<ComposeFst<Arc, CacheStore>>>(*this);
738 template <
class CacheStore,
class Filter,
class StateTable>
741 using Arc =
typename CacheStore::Arc;
757 : owned_fst_(fst.
Copy()),
761 match_type_(match_type),
762 matcher1_(impl_->matcher1_->
Copy()),
763 matcher2_(impl_->matcher2_->
Copy()),
764 current_loop_(false),
766 if (match_type_ ==
MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
776 match_type_(match_type),
777 matcher1_(impl_->matcher1_->
Copy()),
778 matcher2_(impl_->matcher2_->
Copy()),
779 current_loop_(false),
781 if (match_type_ ==
MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
788 : owned_fst_(matcher.fst_.
Copy(safe)),
792 match_type_(matcher.match_type_),
793 matcher1_(matcher.matcher1_->
Copy(safe)),
794 matcher2_(matcher.matcher2_->
Copy(safe)),
795 current_loop_(false),
797 if (match_type_ ==
MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
812 (matcher2_->Type(test) == match_type_)) ||
813 ((matcher1_->Type(test) == match_type_) &&
817 if ((matcher1_->Type(test) == match_type_) &&
818 (matcher2_->Type(test) == match_type_)) {
826 uint64_t
Properties(uint64_t inprops)
const override {
return inprops; }
831 const auto &tuple = impl_->state_table_->Tuple(s);
832 matcher1_->SetState(tuple.StateId1());
833 matcher2_->SetState(tuple.StateId2());
834 loop_.nextstate = s_;
839 current_loop_ =
false;
841 current_loop_ =
true;
845 found = found || FindLabel(label, matcher1_.get(), matcher2_.get());
847 found = found || FindLabel(label, matcher2_.get(), matcher1_.get());
853 return !current_loop_ && matcher1_->Done() && matcher2_->Done();
856 const Arc &
Value() const final {
return current_loop_ ? loop_ : arc_; }
860 current_loop_ =
false;
862 FindNext(matcher1_.get(), matcher2_.get());
864 FindNext(matcher2_.get(), matcher1_.get());
873 const auto &fs = impl_->filter_->FilterArc(arc1, arc2);
874 if (fs == FilterState::NoState())
return false;
875 const StateTuple tuple(arc1->nextstate, arc2->nextstate, fs);
876 arc_.ilabel = arc1->ilabel;
877 arc_.olabel = arc2->olabel;
878 arc_.weight =
Times(arc1->weight, arc2->weight);
879 arc_.nextstate = impl_->state_table_->FindState(tuple);
884 template <
class MatcherA,
class MatcherB>
885 bool FindLabel(
Label label, MatcherA *matchera, MatcherB *matcherb) {
886 if (matchera->Find(label)) {
887 matcherb->Find(match_type_ ==
MATCH_INPUT ? matchera->Value().olabel
888 : matchera->Value().ilabel);
889 return FindNext(matchera, matcherb);
896 template <
class MatcherA,
class MatcherB>
897 bool FindNext(MatcherA *matchera, MatcherB *matcherb) {
901 while (!matchera->Done() || !matcherb->Done()) {
902 if (matcherb->Done()) {
907 while (!matchera->Done() &&
909 ? matchera->Value().olabel
910 : matchera->Value().ilabel)) {
914 while (!matcherb->Done()) {
920 auto arca = matchera->Value();
921 auto arcb = matcherb->Value();
928 if (MatchArc(s_, &arca, &arcb))
return true;
930 if (MatchArc(s_, &arcb, &arca))
return true;
938 std::unique_ptr<const ComposeFst<Arc, CacheStore>> owned_fst_;
943 std::unique_ptr<Matcher1> matcher1_;
944 std::unique_ptr<Matcher2> matcher2_;
969 : connect(connect), filter_type(filter_type) {}
1004 template <
class Arc>
1010 switch (opts.filter_type) {
1054 if (opts.connect)
Connect(ofst);
1059 #endif // FST_COMPOSE_H_
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
uint64_t Properties() const override
typename Filter::Matcher1 Matcher1
ssize_t Priority(StateId s) final
void SetProperties(uint64_t props)
ComposeFst(const typename Matcher1::FST &fst1, const typename Matcher2::FST &fst2, const ComposeFstImplOptions< Matcher1, Matcher2, Filter, StateTuple, CacheStore > &opts)
typename Filter::Matcher2 Matcher2
const StateTable * GetStateTable() const
size_t NumArcs(StateId s)
constexpr ssize_t kRequirePriority
void Expand(StateId s) override
CacheOptions(bool gc=FST_FLAGS_fst_default_cache_gc, size_t gc_limit=FST_FLAGS_fst_default_cache_gc_limit)
typename Filter::Matcher2 Matcher2
typename StateTable::StateTuple StateTuple
static std::shared_ptr< Impl > CreateBase1(const Fst< Arc > &fst1, const Fst< Arc > &fst2, const ComposeFstOptions< Arc, Matcher, Filter, StateTuple > &opts)
typename StateTable::StateTuple StateTuple
typename Filter::Matcher1 Matcher1
typename Arc::Label Label
uint64_t ComposeProperties(uint64_t inprops1, uint64_t inprops2)
ComposeFstImplOptions(const CacheImplOptions< CacheStore > &opts, M1 *matcher1=nullptr, M2 *matcher2=nullptr, Filter *filter=nullptr, StateTable *state_table=nullptr)
size_t NumOutputEpsilons(StateId s)
ComposeFilter filter_type
MatchType LookAheadMatchType(const Matcher1 &m1, const Matcher2 &m2)
~ComposeFstImpl() override
ComposeOptions(bool connect=true, ComposeFilter filter_type=AUTO_FILTER)
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
const SymbolTable * OutputSymbols() const
ComposeFstMatcher(const ComposeFst< Arc, CacheStore > *fst, MatchType match_type)
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
constexpr uint64_t kError
ComposeFst(const Fst< Arc > &fst1, const Fst< Arc > &fst2, const CacheOptions &opts=CacheOptions())
static std::shared_ptr< Impl > CreateBase(const Fst< Arc > &fst1, const Fst< Arc > &fst2, const CacheOptions &opts)
void SetOutputSymbols(const SymbolTable *osyms)
typename Matcher2::FST FST2
typename ComposeFst< Arc, CacheStore >::Arc Arc
void Connect(MutableFst< Arc > *fst)
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
typename CacheStore::State State
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
MatcherBase< Arc > * InitMatcher(const ComposeFst< Arc, CacheStore > &fst, MatchType match_type) const override
const Matcher1 * GetMatcher1() const
ComposeFstMatcher(const ComposeFst< Arc, CacheStore > &fst, MatchType match_type)
typename Arc::StateId StateId
typename Arc::Weight Weight
StateTable * GetStateTable()
typename Arc::Weight Weight
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
MatchType Type(bool test) const override
typename CacheStore::Arc Arc
virtual uint64_t Properties() const
constexpr uint64_t kCopyProperties
constexpr uint64_t kCommutative
void InitStateIterator(StateIteratorData< Arc > *data) const override
const Filter * GetFilter() const
void Compose(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const ComposeOptions &opts=ComposeOptions())
ComposeFst * Copy(bool safe=false) const override
ComposeFstImplBase(const CacheOptions &opts)
typename Arc::StateId StateId
ComposeFst(std::shared_ptr< Impl > impl)
size_t NumInputEpsilons(StateId s)
std::unique_ptr< StateIteratorBase< Arc > > base
const FST1 & GetFst1() const
typename Arc::StateId StateId
ComposeFst(const ComposeFst &fst, bool safe=false)
ComposeFstMatcher(const ComposeFstMatcher< CacheStore, Filter, StateTable > &matcher, bool safe=false)
constexpr uint64_t kOLabelInvariantProperties
ComposeFstImpl * Copy() const override
const FST2 & GetFst2() const
ComposeFstImpl(const ComposeFstImpl &impl)
const Matcher2 * GetMatcher2() const
static std::shared_ptr< Impl > CreateBase2(const typename Matcher1::FST &fst1, const typename Matcher2::FST &fst2, const ComposeFstImplOptions< Matcher1, Matcher2, Filter, StateTuple, CacheStore > &opts)
void SetInputSymbols(const SymbolTable *isyms)
typename CacheStore::Arc Arc
const std::string & Type() const
virtual MatcherBase< Arc > * InitMatcher(const F &fst, MatchType match_type) const
constexpr uint64_t kILabelInvariantProperties
MatcherBase< Arc > * InitMatcher(MatchType match_type) const override
typename Filter::FilterState FilterState
ComposeFst(const Fst< Arc > &fst1, const Fst< Arc > &fst2, const ComposeFstOptions< Arc, Matcher, Filter, StateTuple > &opts)
void SetState(StateId s) final
typename Matcher1::FST FST1
uint64_t Properties(uint64_t inprops) const override
typename Arc::StateId StateId
constexpr uint64_t kFstProperties
ComposeFstOptions(const CacheOptions &opts=CacheOptions(), M *matcher1=nullptr, M *matcher2=nullptr, Filter *filter=nullptr, StateTable *state_table=nullptr)
const Arc & Value() const final
constexpr uint64_t kUnweighted
const Fst< Arc > & GetFst() const override
typename Arc::Weight Weight
typename ComposeFst< Arc, CacheStore >::Arc Arc
const SymbolTable * InputSymbols() const
ComposeFstImplBase(const CacheImplOptions< CacheStore > &opts)
void SetType(std::string_view type)
typename Filter::FilterState FilterState
uint64_t Properties(uint64_t mask) const override
ComposeFstMatcher * Copy(bool safe=false) const override
ArcIterator(const ComposeFst< Arc, CacheStore > &fst, StateId s)
ComposeFstImplBase(const ComposeFstImplBase &impl)
typename CacheStore::State State
typename CacheStore::Arc::Label Label
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
const Arc & Value() const
Impl * GetMutableImpl() const
typename Arc::Label Label
StateIterator(const ComposeFst< Arc, CacheStore > &fst)
typename Arc::Weight Weight
typename Arc::StateId StateId
typename CacheStore::State State
bool Find(Label label) final
const Impl * GetImpl() const
ComposeFstImplOptions(const CacheOptions &opts, M1 *matcher1=nullptr, M2 *matcher2=nullptr, Filter *filter=nullptr, StateTable *state_table=nullptr)
constexpr uint32_t kRequireMatch