20 #ifndef FST_EXTENSIONS_PDT_COMPOSE_H_ 21 #define FST_EXTENSIONS_PDT_COMPOSE_H_ 23 #include <sys/types.h> 59 using Label =
typename Arc::Label;
65 uint32_t flags = (kParenLoop | kParenList))
66 : matcher_(fst, match_type), match_type_(match_type), flags_(flags) {
74 loop_.weight = Weight::One();
80 uint32_t flags = (kParenLoop | kParenList))
81 : matcher_(fst, match_type), match_type_(match_type), flags_(flags) {
89 loop_.weight = Weight::One();
95 : matcher_(matcher.matcher_, safe),
96 match_type_(matcher.match_type_),
97 flags_(matcher.flags_),
98 open_parens_(matcher.open_parens_),
99 close_parens_(matcher.close_parens_),
100 loop_(matcher.loop_) {
117 bool Done()
const {
return done_; }
137 FSTERROR() <<
"ParenMatcher: Bad open paren label: 0";
139 open_parens_.
Insert(label);
145 FSTERROR() <<
"ParenMatcher: Bad close paren label: 0";
147 close_parens_.
Insert(label);
153 FSTERROR() <<
"ParenMatcher: Bad open paren label: 0";
155 open_parens_.
Erase(label);
161 FSTERROR() <<
"ParenMatcher: Bad close paren label: 0";
163 close_parens_.
Erase(label);
177 bool NextOpenParen();
180 bool NextCloseParen();
189 bool open_paren_list_;
190 bool close_paren_list_;
200 open_paren_list_ =
false;
201 close_paren_list_ =
false;
205 if (match_label ==
kNoLabel && (flags_ & kParenList)) {
208 open_paren_list_ = NextOpenParen();
209 if (open_paren_list_)
return true;
213 close_paren_list_ = NextCloseParen();
214 if (close_paren_list_)
return true;
218 if (match_label > 0 && (flags_ & kParenLoop) &&
224 if (matcher_.
Find(match_label))
return true;
234 }
else if (open_paren_list_) {
236 open_paren_list_ = NextOpenParen();
237 if (open_paren_list_)
return;
240 close_paren_list_ = NextCloseParen();
241 if (close_paren_list_)
return;
244 }
else if (close_paren_list_) {
246 close_paren_list_ = NextCloseParen();
247 if (close_paren_list_)
return;
251 done_ = matcher_.
Done();
258 for (; !matcher_.
Done(); matcher_.
Next()) {
260 : matcher_.
Value().olabel;
261 if (label > open_parens_.
UpperBound())
return false;
270 for (; !matcher_.
Done(); matcher_.
Next()) {
272 : matcher_.
Value().olabel;
273 if (label > close_parens_.
UpperBound())
return false;
279 template <
class Filter>
282 using FST1 =
typename Filter::FST1;
283 using FST2 =
typename Filter::FST2;
284 using Arc =
typename Filter::Arc;
300 const std::vector<std::pair<Label, Label>> *parens =
nullptr,
301 bool expand =
false,
bool keep_parens =
true)
302 : filter_(fst1, fst2, matcher1, matcher2),
303 parens_(parens ? *parens : std::vector<std::pair<
Label,
Label>>()),
305 keep_parens_(keep_parens),
310 for (
const auto &pair : *parens) {
311 parens_.push_back(pair);
312 GetMatcher1()->AddOpenParen(pair.first);
313 GetMatcher2()->AddOpenParen(pair.first);
315 GetMatcher1()->AddCloseParen(pair.second);
316 GetMatcher2()->AddCloseParen(pair.second);
323 : filter_(filter.filter_, safe),
324 parens_(filter.parens_),
325 expand_(filter.expand_),
326 keep_parens_(filter.keep_parens_),
328 stack_(filter.parens_),
337 filter_.SetState(s1, s2, fs_.GetState1());
338 if (!expand_)
return;
339 ssize_t paren_id = stack_.Top(fs.
GetState2().GetState());
340 if (paren_id != paren_id_) {
341 if (paren_id_ != -1) {
342 GetMatcher1()->RemoveCloseParen(parens_[paren_id_].second);
343 GetMatcher2()->RemoveCloseParen(parens_[paren_id_].second);
345 paren_id_ = paren_id;
346 if (paren_id_ != -1) {
347 GetMatcher1()->AddCloseParen(parens_[paren_id_].second);
348 GetMatcher2()->AddCloseParen(parens_[paren_id_].second);
354 const auto fs1 = filter_.FilterArc(arc1, arc2);
356 if (fs1 == FilterState1::NoState())
return FilterState::NoState();
357 if (arc1->olabel ==
kNoLabel && arc2->ilabel) {
359 arc1->ilabel = arc2->ilabel;
360 }
else if (arc2->ilabel) {
361 arc2->olabel = arc1->ilabel;
363 return FilterParen(arc2->ilabel, fs1, fs2);
364 }
else if (arc2->ilabel ==
kNoLabel && arc1->olabel) {
366 arc2->olabel = arc1->olabel;
368 arc1->ilabel = arc2->olabel;
370 return FilterParen(arc1->olabel, fs1, fs2);
377 if (fs_.GetState2().GetState() != 0) *w1 = Weight::Zero();
378 filter_.FilterFinal(w1, w2);
396 const auto stack_id = stack_.Find(fs2.
GetState(), label);
398 return FilterState::NoState();
405 std::vector<std::pair<Label, Label>> parens_;
415 template <
class Arc,
bool left_pdt = true>
418 Arc, ParenMatcher<Fst<Arc>>,
419 ParenFilter<AltSequenceComposeFilter<ParenMatcher<Fst<Arc>>>>> {
430 const std::vector<std::pair<Label, Label>> &parens,
431 const Fst<Arc> &ifst2,
bool expand =
false,
432 bool keep_parens =
true) {
435 filter =
new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens, expand,
445 Arc, ParenMatcher<Fst<Arc>>,
446 ParenFilter<SequenceComposeFilter<ParenMatcher<Fst<Arc>>>>> {
457 const std::vector<std::pair<Label, Label>> &parens,
458 bool expand =
false,
bool keep_parens =
true) {
461 filter =
new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens, expand,
478 : connect(connect), filter_type(filter_type) {}
489 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
499 if (opts.connect)
Connect(ofst);
510 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
520 if (opts.connect)
Connect(ofst);
525 #endif // FST_EXTENSIONS_PDT_COMPOSE_H_ void SetState(StateId s1, StateId s2, const FilterState &fs)
void RemoveOpenParen(Label label)
constexpr uint32_t kParenLoop
void SetState(StateId s) final
FilterState Start() const
PdtComposeFstOptions(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, const std::vector< std::pair< Label, Label >> &parens, bool expand=false, bool keep_parens=true)
typename Arc::StateId StateId
bool IsOpenParen(Label label) const
typename AltSequenceComposeFilter< ParenMatcher< Fst< Arc > > >::Arc Arc
ParenMatcher(const FST &fst, MatchType match_type, uint32_t flags=(kParenLoop|kParenList))
void LowerBound(Label label)
void RemoveCloseParen(Label label)
ParenMatcher(const FST *fst, MatchType match_type, uint32_t flags=(kParenLoop|kParenList))
const Arc & Value() const final
const Arc & Value() const
PdtComposeOptions(bool connect=true, PdtComposeFilter filter_type=PdtComposeFilter::PAREN)
void Connect(MutableFst< Arc > *fst)
constexpr uint32_t kParenList
MatchType Type(bool test) const override
FilterState FilterArc(Arc *arc1, Arc *arc2) const
typename AltSequenceComposeFilter< ParenMatcher< Fst< Arc > > >::FST1 FST1
PdtComposeFstOptions(const Fst< Arc > &ifst1, const std::vector< std::pair< Label, Label >> &parens, const Fst< Arc > &ifst2, bool expand=false, bool keep_parens=true)
typename Arc::StateId StateId
virtual uint32_t Flags() const
bool Find(Label match_label) final
MatchType Type(bool test) const
typename Arc::Label Label
void Compose(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const ComposeOptions &opts=ComposeOptions())
uint64_t Properties(uint64_t props) const
Weight Final(StateId s) const final
ParenMatcher< FST > * Copy(bool safe=false) const
bool IsCloseParen(Label label) const
ParenFilter(const ParenFilter &filter, bool safe=false)
uint64_t Properties(uint64_t iprops) const
void FilterFinal(Weight *w1, Weight *w2) const
void AddCloseParen(Label label)
const FST & GetFst() const
constexpr uint64_t kOLabelInvariantProperties
PdtComposeFilter filter_type
typename AltSequenceComposeFilter< ParenMatcher< Fst< Arc > > >::FST2 FST2
ssize_t Priority(StateId s)
constexpr uint64_t kILabelInvariantProperties
ssize_t Priority(StateId s) final
typename Arc::Weight Weight
bool Member(Key key) const
typename Arc::Weight Weight
ParenFilter(const FST1 &fst1, const FST2 &fst2, Matcher1 *matcher1=nullptr, Matcher2 *matcher2=nullptr, const std::vector< std::pair< Label, Label >> *parens=nullptr, bool expand=false, bool keep_parens=true)
const FST & GetFst() const override
const FilterState2 & GetState2() const
bool Find(Label match_label)
ParenMatcher(const ParenMatcher< FST > &matcher, bool safe=false)
typename AltSequenceComposeFilter< ParenMatcher< Fst< Arc > > >::FilterState FilterState1
typename Arc::Label Label
typename AltSequenceComposeFilter< ParenMatcher< Fst< Arc > > >::Matcher1 Matcher1
typename Arc::Label Label
typename Arc::Label Label
uint64_t Properties(uint64_t inprops) const override
typename AltSequenceComposeFilter< ParenMatcher< Fst< Arc > > >::Matcher2 Matcher2
void AddOpenParen(Label label)