20 #ifndef FST_COMPOSE_H_ 21 #define FST_COMPOSE_H_ 50 template <
class Arc,
class M = Matcher<Fst<Arc>>,
51 class Filter = SequenceComposeFilter<M>,
53 GenericComposeStateTable<Arc,
typename Filter::FilterState>>
61 M *matcher1 =
nullptr, M *matcher2 =
nullptr,
62 Filter *filter =
nullptr,
63 StateTable *state_table =
nullptr)
68 state_table(state_table) {}
72 template <
class C,
class F,
class T>
86 template <
class M1,
class M2,
class Filter = SequenceComposeFilter<M1, M2>,
87 class StateTable = GenericComposeStateTable<
88 typename M1::Arc,
typename Filter::FilterState>,
89 class CacheStore = DefaultCacheStore<
typename M1::Arc>>
100 M1 *matcher1 =
nullptr, M2 *matcher2 =
nullptr,
101 Filter *filter =
nullptr,
102 StateTable *state_table =
nullptr)
107 state_table(state_table),
108 own_state_table(true),
109 allow_noncommute(false) {}
112 M1 *matcher1 =
nullptr, M2 *matcher2 =
nullptr,
113 Filter *filter =
nullptr,
114 StateTable *state_table =
nullptr)
119 state_table(state_table),
120 own_state_table(true),
121 allow_noncommute(false) {}
127 state_table(nullptr),
128 own_state_table(true),
129 allow_noncommute(false) {}
136 template <
class Arc,
class CacheStore = DefaultCacheStore<Arc>,
137 class F = ComposeFst<Arc, CacheStore>>
139 :
public CacheBaseImpl<typename CacheStore::State, CacheStore> {
146 using State =
typename CacheStore::State;
155 using CacheImpl::HasArcs;
156 using CacheImpl::HasFinal;
157 using CacheImpl::HasStart;
158 using CacheImpl::SetFinal;
159 using CacheImpl::SetStart;
179 const auto start = ComputeStart();
182 return CacheImpl::Start();
186 if (!HasFinal(s)) SetFinal(s, ComputeFinal(s));
193 if (!HasArcs(s))
Expand(s);
198 if (!HasArcs(s))
Expand(s);
203 if (!HasArcs(s))
Expand(s);
208 if (!HasArcs(s))
Expand(s);
209 CacheImpl::InitArcIterator(s, data);
219 virtual StateId ComputeStart() = 0;
226 template <
class CacheStore,
class Filter,
class StateTable>
236 using Arc =
typename CacheStore::Arc;
242 using State =
typename CacheStore::State;
255 template <
class M1,
class M2>
262 filter_(new Filter(*impl.filter_, true)),
263 matcher1_(filter_->GetMatcher1()),
264 matcher2_(filter_->GetMatcher2()),
265 fst1_(matcher1_->GetFst()),
266 fst2_(matcher2_->GetFst()),
267 state_table_(new StateTable(*impl.state_table_)),
268 own_state_table_(true),
269 match_type_(impl.match_type_) {}
272 if (own_state_table_)
delete state_table_;
282 (fst1_.Properties(kError,
false) || fst2_.Properties(kError,
false) ||
283 (matcher1_->Properties(0) &
kError) ||
284 (matcher2_->Properties(0) &
kError) |
285 (filter_->Properties(0) &
kError) ||
286 state_table_->Error())) {
287 SetProperties(kError, kError);
295 const auto &tuple = state_table_->Tuple(s);
296 const auto s1 = tuple.StateId1();
297 const auto s2 = tuple.StateId2();
298 filter_->SetState(s1, s2, tuple.GetFilterState());
299 if (MatchInput(s1, s2)) {
300 OrderedExpand(s, fst2_, s2, fst1_, s1, matcher2_,
true);
302 OrderedExpand(s, fst1_, s1, fst2_, s2, matcher1_,
false);
318 const Filter *
GetFilter()
const {
return filter_.get(); }
335 if ((matcher1_->Type(
false) == match_type) &&
336 (matcher2_->Type(
false) == match_type) &&
337 (filter_->Properties(test_props) == test_props)) {
350 template <
class FST,
class Matcher>
355 const Arc loop(match_input ? 0 :
kNoLabel, match_input ? kNoLabel : 0,
357 MatchArc(s, matchera, loop, match_input);
360 MatchArc(s, matchera, iterb.
Value(), match_input);
362 CacheImpl::SetArcs(s);
366 template <
class Matcher>
369 if (matchera->
Find(match_input ? arc.olabel : arc.ilabel)) {
370 for (; !matchera->
Done(); matchera->
Next()) {
371 auto arca = matchera->
Value();
374 const auto &fs = filter_->FilterArc(&arcb, &arca);
375 if (fs != FilterState::NoState()) AddArc(s, arcb, arca, fs);
377 const auto &fs = filter_->FilterArc(&arca, &arcb);
378 if (fs != FilterState::NoState()) AddArc(s, arca, arcb, fs);
387 const StateTuple tuple(arc1.nextstate, arc2.nextstate, f);
388 CacheImpl::EmplaceArc(s, arc1.ilabel, arc2.olabel,
389 Times(arc1.weight, arc2.weight),
390 state_table_->FindState(tuple));
393 StateId ComputeStart()
override {
394 const auto s1 = fst1_.Start();
396 const auto s2 = fst2_.Start();
398 const auto &fs = filter_->Start();
400 return state_table_->FindState(tuple);
404 const auto &tuple = state_table_->Tuple(s);
405 const auto s1 = tuple.StateId1();
406 auto final1 = matcher1_->Final(s1);
407 if (final1 == Weight::Zero())
return final1;
408 const auto s2 = tuple.StateId2();
409 auto final2 = matcher2_->Final(s2);
410 if (final2 == Weight::Zero())
return final2;
411 filter_->SetState(s1, s2, tuple.GetFilterState());
412 filter_->FilterFinal(&final1, &final2);
413 return Times(final1, final2);
418 switch (match_type_) {
424 const auto priority1 = matcher1_->Priority(s1);
425 const auto priority2 = matcher2_->Priority(s2);
427 FSTERROR() <<
"ComposeFst: Both sides can't require match";
435 return priority1 <= priority2;
443 std::unique_ptr<Filter> filter_;
448 StateTable *state_table_;
449 bool own_state_table_;
454 template <
class CacheStore,
class Filter,
class StateTable>
455 template <
class M1,
class M2>
462 : new Filter(fst1, fst2, opts.matcher1, opts.matcher2)),
463 matcher1_(filter_->GetMatcher1()),
464 matcher2_(filter_->GetMatcher2()),
465 fst1_(matcher1_->GetFst()),
466 fst2_(matcher2_->GetFst()),
467 state_table_(opts.state_table ? opts.state_table
468 : new StateTable(fst1_, fst2_)),
469 own_state_table_(opts.state_table ? opts.own_state_table : true) {
471 if (!
CompatSymbols(fst2.InputSymbols(), fst1.OutputSymbols())) {
472 FSTERROR() <<
"ComposeFst: Output symbol table of 1st argument " 473 <<
"does not match input symbol table of 2nd argument";
479 VLOG(2) <<
"ComposeFstImpl: Match type: " << match_type_;
483 const auto mprops1 = matcher1_->Properties(fprops1);
484 const auto mprops2 = matcher2_->Properties(fprops2);
490 template <
class CacheStore,
class Filter,
class StateTable>
495 FSTERROR() <<
"ComposeFst: 1st argument cannot perform required matching " 502 FSTERROR() <<
"ComposeFst: 2nd argument cannot perform required matching " 508 const auto type1 = matcher1_->Type(
false);
509 const auto type2 = matcher2_->Type(
false);
521 FSTERROR() <<
"ComposeFst: 1st argument cannot match on output labels " 522 <<
"and 2nd argument cannot match on input labels (sort?).";
564 template <
class A,
class CacheStore >
566 :
public ImplToFst<internal::ComposeFstImplBase<A, CacheStore>> {
573 using State =
typename CacheStore::State;
579 template <
class,
class,
class>
590 template <
class Matcher,
class Filter,
class StateTuple>
600 template <
class Matcher1,
class Matcher2,
class Filter,
class StateTuple>
610 : fst.GetSharedImpl()) {}
620 GetMutableImpl()->InitArcIterator(s, data);
624 return GetImpl()->InitMatcher(*
this, match_type);
634 template <
class Matcher1,
class Matcher2,
class Filter,
class StateTuple>
639 auto impl = std::make_shared<
642 if (!(Weight::Properties() &
kCommutative) && !opts.allow_noncommute) {
644 const auto props2 = fst2.Properties(
kUnweighted,
true);
645 if (!(props1 &
kUnweighted) && !(props2 & kUnweighted)) {
646 FSTERROR() <<
"ComposeFst: Weights must be a commutative semiring: " 656 template <
class Matcher,
class Filter,
class StateTuple>
663 return CreateBase2(fst1, fst2, nopts);
674 return CreateBase1(fst1, fst2, nopts);
680 return CreateBase1(fst1, fst2, nopts);
686 return CreateBase1(fst1, fst2, nopts);
696 template <
class Arc,
class CacheStore>
702 fst.GetMutableImpl()) {}
706 template <
class Arc,
class CacheStore>
718 template <
class Arc,
class CacheStore>
722 std::make_unique<StateIterator<ComposeFst<Arc, CacheStore>>>(*this);
728 template <
class CacheStore,
class Filter,
class StateTable>
731 using Arc =
typename CacheStore::Arc;
747 : owned_fst_(fst.
Copy()),
751 match_type_(match_type),
752 matcher1_(impl_->matcher1_->
Copy()),
753 matcher2_(impl_->matcher2_->
Copy()),
754 current_loop_(false),
756 if (match_type_ ==
MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
766 match_type_(match_type),
767 matcher1_(impl_->matcher1_->
Copy()),
768 matcher2_(impl_->matcher2_->
Copy()),
769 current_loop_(false),
771 if (match_type_ ==
MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
778 : owned_fst_(matcher.fst_.
Copy(safe)),
782 match_type_(matcher.match_type_),
783 matcher1_(matcher.matcher1_->
Copy(safe)),
784 matcher2_(matcher.matcher2_->
Copy(safe)),
785 current_loop_(false),
787 if (match_type_ ==
MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
802 (matcher2_->Type(test) == match_type_)) ||
803 ((matcher1_->Type(test) == match_type_) &&
807 if ((matcher1_->Type(test) == match_type_) &&
808 (matcher2_->Type(test) == match_type_)) {
816 uint64_t
Properties(uint64_t inprops)
const override {
return inprops; }
821 const auto &tuple = impl_->state_table_->Tuple(s);
822 matcher1_->SetState(tuple.StateId1());
823 matcher2_->SetState(tuple.StateId2());
824 loop_.nextstate = s_;
829 current_loop_ =
false;
831 current_loop_ =
true;
835 found = found || FindLabel(label, matcher1_.get(), matcher2_.get());
837 found = found || FindLabel(label, matcher2_.get(), matcher1_.get());
843 return !current_loop_ && matcher1_->Done() && matcher2_->Done();
846 const Arc &
Value() const final {
return current_loop_ ? loop_ : arc_; }
850 current_loop_ =
false;
852 FindNext(matcher1_.get(), matcher2_.get());
854 FindNext(matcher2_.get(), matcher1_.get());
863 const auto &fs = impl_->filter_->FilterArc(arc1, arc2);
864 if (fs == FilterState::NoState())
return false;
865 const StateTuple tuple(arc1->nextstate, arc2->nextstate, fs);
866 arc_.ilabel = arc1->ilabel;
867 arc_.olabel = arc2->olabel;
868 arc_.weight =
Times(arc1->weight, arc2->weight);
869 arc_.nextstate = impl_->state_table_->FindState(tuple);
874 template <
class MatcherA,
class MatcherB>
875 bool FindLabel(
Label label, MatcherA *matchera, MatcherB *matcherb) {
876 if (matchera->Find(label)) {
877 matcherb->Find(match_type_ ==
MATCH_INPUT ? matchera->Value().olabel
878 : matchera->Value().ilabel);
879 return FindNext(matchera, matcherb);
886 template <
class MatcherA,
class MatcherB>
887 bool FindNext(MatcherA *matchera, MatcherB *matcherb) {
891 while (!matchera->Done() || !matcherb->Done()) {
892 if (matcherb->Done()) {
897 while (!matchera->Done() &&
899 ? matchera->Value().olabel
900 : matchera->Value().ilabel)) {
904 while (!matcherb->Done()) {
910 auto arca = matchera->Value();
911 auto arcb = matcherb->Value();
918 return MatchArc(s_, &arca, &arcb);
920 return MatchArc(s_, &arcb, &arca);
928 std::unique_ptr<const ComposeFst<Arc, CacheStore>> owned_fst_;
933 std::unique_ptr<Matcher1> matcher1_;
934 std::unique_ptr<Matcher2> matcher2_;
959 : connect(connect), filter_type(filter_type) {}
1000 switch (opts.filter_type) {
1044 if (opts.connect)
Connect(ofst);
1049 #endif // FST_COMPOSE_H_
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
uint64_t Properties() const override
typename Filter::Matcher1 Matcher1
ssize_t Priority(StateId s) final
void SetProperties(uint64_t props)
ComposeFst(const typename Matcher1::FST &fst1, const typename Matcher2::FST &fst2, const ComposeFstImplOptions< Matcher1, Matcher2, Filter, StateTuple, CacheStore > &opts)
typename Filter::Matcher2 Matcher2
const StateTable * GetStateTable() const
size_t NumArcs(StateId s)
constexpr ssize_t kRequirePriority
void Expand(StateId s) override
CacheOptions(bool gc=FST_FLAGS_fst_default_cache_gc, size_t gc_limit=FST_FLAGS_fst_default_cache_gc_limit)
typename Filter::Matcher2 Matcher2
typename StateTable::StateTuple StateTuple
static std::shared_ptr< Impl > CreateBase1(const Fst< Arc > &fst1, const Fst< Arc > &fst2, const ComposeFstOptions< Arc, Matcher, Filter, StateTuple > &opts)
typename StateTable::StateTuple StateTuple
typename Filter::Matcher1 Matcher1
typename Arc::Label Label
uint64_t ComposeProperties(uint64_t inprops1, uint64_t inprops2)
ComposeFstImplOptions(const CacheImplOptions< CacheStore > &opts, M1 *matcher1=nullptr, M2 *matcher2=nullptr, Filter *filter=nullptr, StateTable *state_table=nullptr)
size_t NumOutputEpsilons(StateId s)
~ComposeFstImplBase() override
ComposeFilter filter_type
MatchType LookAheadMatchType(const Matcher1 &m1, const Matcher2 &m2)
~ComposeFstImpl() override
ComposeOptions(bool connect=true, ComposeFilter filter_type=AUTO_FILTER)
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
const SymbolTable * OutputSymbols() const
ComposeFstMatcher(const ComposeFst< Arc, CacheStore > *fst, MatchType match_type)
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
constexpr uint64_t kError
ComposeFst(const Fst< Arc > &fst1, const Fst< Arc > &fst2, const CacheOptions &opts=CacheOptions())
static std::shared_ptr< Impl > CreateBase(const Fst< Arc > &fst1, const Fst< Arc > &fst2, const CacheOptions &opts)
void SetOutputSymbols(const SymbolTable *osyms)
typename Matcher2::FST FST2
typename ComposeFst< Arc, CacheStore >::Arc Arc
void Connect(MutableFst< Arc > *fst)
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
typename CacheStore::State State
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
MatcherBase< Arc > * InitMatcher(const ComposeFst< Arc, CacheStore > &fst, MatchType match_type) const override
const Matcher1 * GetMatcher1() const
ComposeFstMatcher(const ComposeFst< Arc, CacheStore > &fst, MatchType match_type)
typename Arc::StateId StateId
typename Arc::Weight Weight
StateTable * GetStateTable()
typename Arc::Weight Weight
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
MatchType Type(bool test) const override
typename CacheStore::Arc Arc
virtual uint64_t Properties() const
constexpr uint64_t kCopyProperties
constexpr uint64_t kCommutative
void InitStateIterator(StateIteratorData< Arc > *data) const override
const Filter * GetFilter() const
void Compose(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const ComposeOptions &opts=ComposeOptions())
ComposeFst * Copy(bool safe=false) const override
ComposeFstImplBase(const CacheOptions &opts)
typename Arc::StateId StateId
ComposeFst(std::shared_ptr< Impl > impl)
size_t NumInputEpsilons(StateId s)
std::unique_ptr< StateIteratorBase< Arc > > base
const FST1 & GetFst1() const
typename Arc::StateId StateId
ComposeFst(const ComposeFst &fst, bool safe=false)
ComposeFstMatcher(const ComposeFstMatcher< CacheStore, Filter, StateTable > &matcher, bool safe=false)
constexpr uint64_t kOLabelInvariantProperties
ComposeFstImpl * Copy() const override
const FST2 & GetFst2() const
ComposeFstImpl(const ComposeFstImpl &impl)
const Matcher2 * GetMatcher2() const
static std::shared_ptr< Impl > CreateBase2(const typename Matcher1::FST &fst1, const typename Matcher2::FST &fst2, const ComposeFstImplOptions< Matcher1, Matcher2, Filter, StateTuple, CacheStore > &opts)
void SetInputSymbols(const SymbolTable *isyms)
typename CacheStore::Arc Arc
const std::string & Type() const
virtual MatcherBase< Arc > * InitMatcher(const F &fst, MatchType match_type) const
constexpr uint64_t kILabelInvariantProperties
MatcherBase< Arc > * InitMatcher(MatchType match_type) const override
typename Filter::FilterState FilterState
ComposeFst(const Fst< Arc > &fst1, const Fst< Arc > &fst2, const ComposeFstOptions< Arc, Matcher, Filter, StateTuple > &opts)
void SetState(StateId s) final
typename Matcher1::FST FST1
uint64_t Properties(uint64_t inprops) const override
typename Arc::StateId StateId
constexpr uint64_t kFstProperties
ComposeFstOptions(const CacheOptions &opts=CacheOptions(), M *matcher1=nullptr, M *matcher2=nullptr, Filter *filter=nullptr, StateTable *state_table=nullptr)
const Arc & Value() const final
constexpr uint64_t kUnweighted
const Fst< Arc > & GetFst() const override
typename Arc::Weight Weight
typename ComposeFst< Arc, CacheStore >::Arc Arc
const SymbolTable * InputSymbols() const
ComposeFstImplBase(const CacheImplOptions< CacheStore > &opts)
void SetType(std::string_view type)
typename Filter::FilterState FilterState
uint64_t Properties(uint64_t mask) const override
ComposeFstMatcher * Copy(bool safe=false) const override
ArcIterator(const ComposeFst< Arc, CacheStore > &fst, StateId s)
ComposeFstImplBase(const ComposeFstImplBase &impl)
typename CacheStore::State State
typename CacheStore::Arc::Label Label
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, bool warning=true)
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
const Arc & Value() const
Impl * GetMutableImpl() const
typename Arc::Label Label
StateIterator(const ComposeFst< Arc, CacheStore > &fst)
typename Arc::Weight Weight
typename Arc::StateId StateId
typename CacheStore::State State
bool Find(Label label) final
const Impl * GetImpl() const
ComposeFstImplOptions(const CacheOptions &opts, M1 *matcher1=nullptr, M2 *matcher2=nullptr, Filter *filter=nullptr, StateTable *state_table=nullptr)
constexpr uint32_t kRequireMatch