20 #ifndef FST_EXTENSIONS_PDT_COMPOSE_H_ 21 #define FST_EXTENSIONS_PDT_COMPOSE_H_ 46 using Label =
typename Arc::Label;
52 uint32_t flags = (kParenLoop | kParenList))
53 : matcher_(fst, match_type), match_type_(match_type), flags_(flags) {
61 loop_.weight = Weight::One();
67 uint32_t flags = (kParenLoop | kParenList))
68 : matcher_(fst, match_type), match_type_(match_type), flags_(flags) {
76 loop_.weight = Weight::One();
82 : matcher_(matcher.matcher_, safe),
83 match_type_(matcher.match_type_),
84 flags_(matcher.flags_),
85 open_parens_(matcher.open_parens_),
86 close_parens_(matcher.close_parens_),
87 loop_(matcher.loop_) {
104 bool Done()
const {
return done_; }
124 FSTERROR() <<
"ParenMatcher: Bad open paren label: 0";
126 open_parens_.
Insert(label);
132 FSTERROR() <<
"ParenMatcher: Bad close paren label: 0";
134 close_parens_.
Insert(label);
140 FSTERROR() <<
"ParenMatcher: Bad open paren label: 0";
142 open_parens_.
Erase(label);
148 FSTERROR() <<
"ParenMatcher: Bad close paren label: 0";
150 close_parens_.
Erase(label);
164 bool NextOpenParen();
167 bool NextCloseParen();
176 bool open_paren_list_;
177 bool close_paren_list_;
187 open_paren_list_ =
false;
188 close_paren_list_ =
false;
192 if (match_label ==
kNoLabel && (flags_ & kParenList)) {
195 open_paren_list_ = NextOpenParen();
196 if (open_paren_list_)
return true;
200 close_paren_list_ = NextCloseParen();
201 if (close_paren_list_)
return true;
205 if (match_label > 0 && (flags_ & kParenLoop) &&
211 if (matcher_.
Find(match_label))
return true;
221 }
else if (open_paren_list_) {
223 open_paren_list_ = NextOpenParen();
224 if (open_paren_list_)
return;
227 close_paren_list_ = NextCloseParen();
228 if (close_paren_list_)
return;
231 }
else if (close_paren_list_) {
233 close_paren_list_ = NextCloseParen();
234 if (close_paren_list_)
return;
238 done_ = matcher_.
Done();
245 for (; !matcher_.
Done(); matcher_.
Next()) {
247 : matcher_.
Value().olabel;
248 if (label > open_parens_.
UpperBound())
return false;
257 for (; !matcher_.
Done(); matcher_.
Next()) {
259 : matcher_.
Value().olabel;
260 if (label > close_parens_.
UpperBound())
return false;
266 template <
class Filter>
269 using FST1 =
typename Filter::FST1;
270 using FST2 =
typename Filter::FST2;
271 using Arc =
typename Filter::Arc;
287 const std::vector<std::pair<Label, Label>> *parens =
nullptr,
288 bool expand =
false,
bool keep_parens =
true)
289 : filter_(fst1, fst2, matcher1, matcher2),
290 parens_(parens ? *parens : std::vector<std::pair<
Label,
Label>>()),
292 keep_parens_(keep_parens),
297 for (
const auto &pair : *parens) {
298 parens_.push_back(pair);
299 GetMatcher1()->AddOpenParen(pair.first);
300 GetMatcher2()->AddOpenParen(pair.first);
302 GetMatcher1()->AddCloseParen(pair.second);
303 GetMatcher2()->AddCloseParen(pair.second);
310 : filter_(filter.filter_, safe),
311 parens_(filter.parens_),
312 expand_(filter.expand_),
313 keep_parens_(filter.keep_parens_),
315 stack_(filter.parens_),
324 filter_.SetState(s1, s2, fs_.GetState1());
325 if (!expand_)
return;
326 ssize_t paren_id = stack_.Top(fs.
GetState2().GetState());
327 if (paren_id != paren_id_) {
328 if (paren_id_ != -1) {
329 GetMatcher1()->RemoveCloseParen(parens_[paren_id_].second);
330 GetMatcher2()->RemoveCloseParen(parens_[paren_id_].second);
332 paren_id_ = paren_id;
333 if (paren_id_ != -1) {
334 GetMatcher1()->AddCloseParen(parens_[paren_id_].second);
335 GetMatcher2()->AddCloseParen(parens_[paren_id_].second);
341 const auto fs1 = filter_.FilterArc(arc1, arc2);
343 if (fs1 == FilterState1::NoState())
return FilterState::NoState();
344 if (arc1->olabel ==
kNoLabel && arc2->ilabel) {
346 arc1->ilabel = arc2->ilabel;
347 }
else if (arc2->ilabel) {
348 arc2->olabel = arc1->ilabel;
350 return FilterParen(arc2->ilabel, fs1, fs2);
351 }
else if (arc2->ilabel ==
kNoLabel && arc1->olabel) {
353 arc2->olabel = arc1->olabel;
355 arc1->ilabel = arc2->olabel;
357 return FilterParen(arc1->olabel, fs1, fs2);
364 if (fs_.GetState2().GetState() != 0) *w1 = Weight::Zero();
365 filter_.FilterFinal(w1, w2);
383 const auto stack_id = stack_.Find(fs2.
GetState(), label);
385 return FilterState::NoState();
392 std::vector<std::pair<Label, Label>> parens_;
402 template <
class Arc,
bool left_pdt = true>
405 Arc, ParenMatcher<Fst<Arc>>,
406 ParenFilter<AltSequenceComposeFilter<ParenMatcher<Fst<Arc>>>>> {
417 const std::vector<std::pair<Label, Label>> &parens,
418 const Fst<Arc> &ifst2,
bool expand =
false,
419 bool keep_parens =
true) {
422 filter =
new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens, expand,
432 Arc, ParenMatcher<Fst<Arc>>,
433 ParenFilter<SequenceComposeFilter<ParenMatcher<Fst<Arc>>>>> {
444 const std::vector<std::pair<Label, Label>> &parens,
445 bool expand =
false,
bool keep_parens =
true) {
448 filter =
new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens, expand,
465 : connect(connect), filter_type(filter_type) {}
476 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
486 if (opts.connect)
Connect(ofst);
497 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
507 if (opts.connect)
Connect(ofst);
512 #endif // FST_EXTENSIONS_PDT_COMPOSE_H_ void SetState(StateId s1, StateId s2, const FilterState &fs)
void RemoveOpenParen(Label label)
constexpr uint32_t kParenLoop
void SetState(StateId s) final
FilterState Start() const
PdtComposeFstOptions(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, const std::vector< std::pair< Label, Label >> &parens, bool expand=false, bool keep_parens=true)
typename Arc::StateId StateId
bool IsOpenParen(Label label) const
typename AltSequenceComposeFilter< ParenMatcher< Fst< Arc > > >::Arc Arc
ParenMatcher(const FST &fst, MatchType match_type, uint32_t flags=(kParenLoop|kParenList))
void LowerBound(Label label)
void RemoveCloseParen(Label label)
ParenMatcher(const FST *fst, MatchType match_type, uint32_t flags=(kParenLoop|kParenList))
const Arc & Value() const final
const Arc & Value() const
PdtComposeOptions(bool connect=true, PdtComposeFilter filter_type=PdtComposeFilter::PAREN)
void Connect(MutableFst< Arc > *fst)
constexpr uint32_t kParenList
MatchType Type(bool test) const override
FilterState FilterArc(Arc *arc1, Arc *arc2) const
typename AltSequenceComposeFilter< ParenMatcher< Fst< Arc > > >::FST1 FST1
PdtComposeFstOptions(const Fst< Arc > &ifst1, const std::vector< std::pair< Label, Label >> &parens, const Fst< Arc > &ifst2, bool expand=false, bool keep_parens=true)
typename Arc::StateId StateId
virtual uint32_t Flags() const
bool Find(Label match_label) final
MatchType Type(bool test) const
typename Arc::Label Label
void Compose(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const ComposeOptions &opts=ComposeOptions())
uint64_t Properties(uint64_t props) const
Weight Final(StateId s) const final
ParenMatcher< FST > * Copy(bool safe=false) const
bool IsCloseParen(Label label) const
ParenFilter(const ParenFilter &filter, bool safe=false)
uint64_t Properties(uint64_t iprops) const
void FilterFinal(Weight *w1, Weight *w2) const
void AddCloseParen(Label label)
const FST & GetFst() const
constexpr uint64_t kOLabelInvariantProperties
PdtComposeFilter filter_type
typename AltSequenceComposeFilter< ParenMatcher< Fst< Arc > > >::FST2 FST2
ssize_t Priority(StateId s)
constexpr uint64_t kILabelInvariantProperties
ssize_t Priority(StateId s) final
typename Arc::Weight Weight
bool Member(Key key) const
typename Arc::Weight Weight
ParenFilter(const FST1 &fst1, const FST2 &fst2, Matcher1 *matcher1=nullptr, Matcher2 *matcher2=nullptr, const std::vector< std::pair< Label, Label >> *parens=nullptr, bool expand=false, bool keep_parens=true)
const FST & GetFst() const override
const FilterState2 & GetState2() const
bool Find(Label match_label)
ParenMatcher(const ParenMatcher< FST > &matcher, bool safe=false)
typename AltSequenceComposeFilter< ParenMatcher< Fst< Arc > > >::FilterState FilterState1
typename Arc::Label Label
typename AltSequenceComposeFilter< ParenMatcher< Fst< Arc > > >::Matcher1 Matcher1
typename Arc::Label Label
typename Arc::Label Label
uint64_t Properties(uint64_t inprops) const override
typename AltSequenceComposeFilter< ParenMatcher< Fst< Arc > > >::Matcher2 Matcher2
void AddOpenParen(Label label)