FST  openfst-1.8.3
OpenFst Library
expand.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Expands an MPDT to an FST.
19 
20 #ifndef FST_EXTENSIONS_MPDT_EXPAND_H_
21 #define FST_EXTENSIONS_MPDT_EXPAND_H_
22 
23 #include <cstddef>
24 #include <cstdint>
25 #include <memory>
26 #include <utility>
27 #include <vector>
28 
31 #include <fst/extensions/pdt/pdt.h>
32 #include <fst/cache.h>
33 #include <fst/connect.h>
34 #include <fst/fst.h>
35 #include <fst/impl-to-fst.h>
36 #include <fst/mutable-fst.h>
37 #include <fst/properties.h>
38 #include <fst/queue.h>
39 #include <fst/state-table.h>
40 
41 namespace fst {
42 
43 template <class Arc>
48 
50  const CacheOptions &opts = CacheOptions(), bool kp = false,
52  nullptr,
54  : CacheOptions(opts), keep_parentheses(kp), stack(s), state_table(st) {}
55 };
56 
57 // Properties for an expanded MPDT: masks the input property bits down to the four flags named below and clears every other bit.
58 inline uint64_t MPdtExpandProperties(uint64_t inprops) {
59  return inprops & (kAcceptor | kAcyclic | kInitialAcyclic | kUnweighted);
60 }
61 
62 namespace internal {
63 
64 // Implementation class for MPdtExpandFst.
65 template <class Arc>
66 class MPdtExpandFstImpl : public CacheImpl<Arc> {
67  public:
68  using Label = typename Arc::Label;
69  using StateId = typename Arc::StateId;
70  using Weight = typename Arc::Weight;
71 
72  using StackId = StateId;
75 
81 
82  using CacheBaseImpl<CacheState<Arc>>::PushArc;
83  using CacheBaseImpl<CacheState<Arc>>::HasArcs;
84  using CacheBaseImpl<CacheState<Arc>>::HasFinal;
85  using CacheBaseImpl<CacheState<Arc>>::HasStart;
86  using CacheBaseImpl<CacheState<Arc>>::SetArcs;
87  using CacheBaseImpl<CacheState<Arc>>::SetFinal;
88  using CacheBaseImpl<CacheState<Arc>>::SetStart;
89 
91  const std::vector<std::pair<Label, Label>> &parens,
92  const std::vector<Label> &assignments,
93  const MPdtExpandFstOptions<Arc> &opts)
94  : CacheImpl<Arc>(opts),
95  fst_(fst.Copy()),
96  stack_(opts.stack ? opts.stack : new ParenStack(parens, assignments)),
97  state_table_(opts.state_table ? opts.state_table
98  : new PdtStateTable<StateId, StackId>()),
99  own_stack_(!opts.stack),
100  own_state_table_(!opts.state_table),
101  keep_parentheses_(opts.keep_parentheses) {
102  SetType("expand");
103  const auto props = fst.Properties(kFstProperties, false);
104  SetProperties(MPdtExpandProperties(props), kCopyProperties);
105  SetInputSymbols(fst.InputSymbols());
106  SetOutputSymbols(fst.OutputSymbols());
107  }
108 
110  : CacheImpl<Arc>(impl),
111  fst_(impl.fst_->Copy(true)),
112  stack_(new ParenStack(*impl.stack_)),
113  state_table_(new PdtStateTable<StateId, StackId>()),
114  own_stack_(true),
115  own_state_table_(true),
116  keep_parentheses_(impl.keep_parentheses_) {
117  SetType("expand");
118  SetProperties(impl.Properties(), kCopyProperties);
119  SetInputSymbols(impl.InputSymbols());
120  SetOutputSymbols(impl.OutputSymbols());
121  }
122 
123  ~MPdtExpandFstImpl() override {  // Releases only the helper objects flagged as owned by this impl.
124  if (own_stack_) delete stack_;  // Skipped when the stack was supplied through the options.
125  if (own_state_table_) delete state_table_;  // Skipped when the state table was supplied through the options.
126  }
127 
129  if (!HasStart()) {
130  const auto s = fst_->Start();
131  if (s == kNoStateId) return kNoStateId;
132  const StateTuple tuple(s, 0);
133  const auto start = state_table_->FindState(tuple);
134  SetStart(start);
135  }
136  return CacheImpl<Arc>::Start();
137  }
138 
140  if (!HasFinal(s)) {
141  const auto &tuple = state_table_->Tuple(s);
142  const auto weight = fst_->Final(tuple.state_id);
143  SetFinal(s, (weight != Weight::Zero() && tuple.stack_id == 0)
144  ? weight
145  : Weight::Zero());
146  }
147  return CacheImpl<Arc>::Final(s);
148  }
149 
150  size_t NumArcs(StateId s) {  // Number of outgoing arcs of expanded state s.
151  if (!HasArcs(s)) ExpandState(s);  // Expand s first unless the cache already reports arcs for it.
152  return CacheImpl<Arc>::NumArcs(s);  // The count itself comes from the cache base class.
153  }
154 
156  if (!HasArcs(s)) ExpandState(s);
158  }
159 
161  if (!HasArcs(s)) ExpandState(s);
163  }
164 
166  if (!HasArcs(s)) ExpandState(s);
168  }
169 
170  // Computes the outgoing transitions from a state, creating new destination
171  // states as needed.
173  const auto tuple = state_table_->Tuple(s);
174  for (ArcIterator<Fst<Arc>> aiter(*fst_, tuple.state_id); !aiter.Done();
175  aiter.Next()) {
176  auto arc = aiter.Value();
177  const auto stack_id = stack_->Find(tuple.stack_id, arc.ilabel);
178  if (stack_id == -1) {
179  continue; // Non-matching close parenthesis.
180  } else if ((stack_id != tuple.stack_id) && !keep_parentheses_) {
181  arc.ilabel = arc.olabel = 0; // Stack push/pop.
182  }
183  const StateTuple ntuple(arc.nextstate, stack_id);
184  arc.nextstate = state_table_->FindState(ntuple);
185  PushArc(s, arc);
186  }
187  SetArcs(s);
188  }
189 
190  const ParenStack &GetStack() const { return *stack_; }  // Read-only access to the parenthesis stack this impl queries in ExpandState.
191 
193  return *state_table_;
194  }
195 
196  private:
197  std::unique_ptr<const Fst<Arc>> fst_;  // Copy of the input FST, obtained via Copy() in the constructors.
198  ParenStack *stack_;  // Parenthesis stack; deleted by the destructor only when own_stack_ is true.
199  PdtStateTable<StateId, StackId> *state_table_;  // Maps (state id, stack id) tuples to expanded state ids; deleted only when own_state_table_ is true.
200  const bool own_stack_;  // True when stack_ was allocated by this impl instead of being passed in through the options.
201  const bool own_state_table_;  // True when state_table_ was allocated by this impl instead of being passed in through the options.
202  const bool keep_parentheses_;  // When false, ExpandState sets both labels to 0 on arcs whose stack id differs from the source's.
203 
204  MPdtExpandFstImpl &operator=(const MPdtExpandFstImpl &) = delete;
205 };
206 
207 } // namespace internal
208 
209 // Expands a multi-pushdown transducer (MPDT) encoded as an FST into an FST.
210 // This version is a delayed FST. In the MPDT, some transitions are labeled with
211 // open or close parentheses. To be interpreted as an MPDT, the parens for each
212 // stack must balance on a path. The open-close parenthesis label
213 // pairs are passed using the parens argument, and the assignment of those pairs
214 // to stacks is passed using the assignments argument. Expansion enforces the
215 // parenthesis constraints. The MPDT must be
216 // expandable as an FST.
217 //
218 // This class attaches interface to implementation and handles
219 // reference counting, delegating most methods to ImplToFst.
220 template <class A>
221 class MPdtExpandFst : public ImplToFst<internal::MPdtExpandFstImpl<A>> {
222  public:
223  using Arc = A;
224  using Label = typename Arc::Label;
225  using StateId = typename Arc::StateId;
226  using Weight = typename Arc::Weight;
227 
228  using StackId = StateId;
231  using State = typename Store::State;
233 
234  friend class ArcIterator<MPdtExpandFst<Arc>>;
236 
238  const std::vector<std::pair<Label, Label>> &parens,
239  const std::vector<Label> &assignments)
240  : ImplToFst<Impl>(std::make_shared<Impl>(fst, parens, assignments,
241  MPdtExpandFstOptions<Arc>())) {}
242 
244  const std::vector<std::pair<Label, Label>> &parens,
245  const std::vector<Label> &assignments,
246  const MPdtExpandFstOptions<Arc> &opts)
247  : ImplToFst<Impl>(
248  std::make_shared<Impl>(fst, parens, assignments, opts)) {}
249 
250  // Copy constructor. See Fst<>::Copy() for doc on the meaning of `safe`.
251  MPdtExpandFst(const MPdtExpandFst<Arc> &fst, bool safe = false)
252  : ImplToFst<Impl>(fst, safe) {}  // Both arguments are forwarded unchanged to the ImplToFst base.
253 
254  // Gets a copy of this MPdtExpandFst. See Fst<>::Copy() for further doc.
255  MPdtExpandFst<Arc> *Copy(bool safe = false) const override {
256  return new MPdtExpandFst<A>(*this, safe);  // Heap-allocated via the copy constructor; returned as a raw pointer (presumably the caller takes ownership -- confirm against Fst<>::Copy()).
257  }
258 
259  inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
260 
261  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {  // Fills `data` for iterating over the arcs of state s.
262  GetMutableImpl()->InitArcIterator(s, data);  // Needs the mutable impl although this method is const: the impl may call ExpandState(s) first.
263  }
264 
265  const ParenStack &GetStack() const { return GetImpl()->GetStack(); }  // Forwards to the implementation's GetStack() accessor.
266 
268  return GetImpl()->GetStateTable();
269  }
270 
271  private:
274 
275  void operator=(const MPdtExpandFst &) = delete;
276 };
277 
278 // Specialization for MPdtExpandFst.
279 template <class Arc>
281  : public CacheStateIterator<MPdtExpandFst<Arc>> {
282  public:
284  : CacheStateIterator<MPdtExpandFst<Arc>>(fst, fst.GetMutableImpl()) {}
285 };
286 
287 // Specialization for MPdtExpandFst.
288 template <class Arc>
290  : public CacheArcIterator<MPdtExpandFst<Arc>> {
291  public:
292  using StateId = typename Arc::StateId;
293 
295  : CacheArcIterator<MPdtExpandFst<Arc>>(fst.GetMutableImpl(), s) {
296  if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->ExpandState(s);
297  }
298 };
299 
300 template <class Arc>
302  StateIteratorData<Arc> *data) const {
303  data->base = std::make_unique<StateIterator<MPdtExpandFst<Arc>>>(*this);
304 }
305 
307  bool connect;
309 
310  explicit MPdtExpandOptions(bool connect = true, bool keep_parentheses = false)
311  : connect(connect), keep_parentheses(keep_parentheses) {}
312 };
313 
314 // Expands a multi-pushdown transducer (MPDT) encoded as an FST into an FST.
315 // This version writes the expanded MPDT to a mutable FST. In the MPDT, some
316 // transitions are labeled with open or close parentheses. To be interpreted as
317 // an MPDT, the parens for each stack must balance on a path. The open-close
318 // parenthesis label pair sets are passed using the parens argument, and the
319 // assignment of those pairs to stacks is passed using the assignments argument.
320 // The expansion enforces the parenthesis constraints. The MPDT must be
321 // expandable as an FST.
322 template <class Arc>
323 void Expand(
324  const Fst<Arc> &ifst,
325  const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
326  &parens,
327  const std::vector<typename Arc::Label> &assignments, MutableFst<Arc> *ofst,
328  const MPdtExpandOptions &opts) {
330  eopts.gc_limit = 0;
331  eopts.keep_parentheses = opts.keep_parentheses;
332  *ofst = MPdtExpandFst<Arc>(ifst, parens, assignments, eopts);
333  if (opts.connect) Connect(ofst);
334 }
335 
336 // Expands a multi-pushdown transducer (MPDT) encoded as an FST into an FST.
337 // This version writes the expanded MPDT to a mutable FST. In the MPDT, some
338 // transitions are labeled with open or close parentheses. To be interpreted as
339 // an MPDT, the parens for each stack must balance on a path. The open-close
340 // parenthesis label pair sets are passed using the parens argument, and the
341 // assignment of those pairs to stacks is passed using the assignments argument.
342 // The expansion enforces the parenthesis constraints. The MPDT must be
343 // expandable as an FST. Convenience overload: takes the two flags as plain bools.
344 template <class Arc>
345 void Expand(
346  const Fst<Arc> &ifst,
347  const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
348  &parens,
349  const std::vector<typename Arc::Label> &assignments, MutableFst<Arc> *ofst,
350  bool connect = true, bool keep_parentheses = false) {
351  const MPdtExpandOptions opts(connect, keep_parentheses);  // Bundle the two flags into an options object.
352  Expand(ifst, parens, assignments, ofst, opts);  // Defer to the overload that takes MPdtExpandOptions.
353 }
354 
355 } // namespace fst
356 
357 #endif // FST_EXTENSIONS_MPDT_EXPAND_H_
ssize_t NumOutputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:124
CacheOptions(bool gc=FST_FLAGS_fst_default_cache_gc, size_t gc_limit=FST_FLAGS_fst_default_cache_gc_limit)
Definition: cache.h:54
virtual uint64_t Properties(uint64_t mask, bool test) const =0
typename Store::State State
Definition: expand.h:231
void InitStateIterator(StateIteratorData< Arc > *data) const override
Definition: expand.h:301
const ParenStack & GetStack() const
Definition: expand.h:190
typename Arc::Weight Weight
Definition: expand.h:226
MPdtExpandFstImpl(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, const std::vector< Label > &assignments, const MPdtExpandFstOptions< Arc > &opts)
Definition: expand.h:90
size_t NumInputEpsilons(StateId s)
Definition: expand.h:155
Arc::Weight Final(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:107
const SymbolTable * OutputSymbols() const
Definition: fst.h:761
constexpr uint64_t kInitialAcyclic
Definition: properties.h:116
SetType
Definition: set-weight.h:59
typename MPdtExpandFst< Arc >::Arc Arc
Definition: cache.h:1156
void Connect(MutableFst< Arc > *fst)
Definition: connect.h:47
ssize_t NumArcs(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:113
constexpr int kNoStateId
Definition: fst.h:196
typename Arc::Label Label
Definition: expand.h:68
ssize_t NumInputEpsilons(const ExpandedFst< Arc > &fst, typename Arc::StateId s)
Definition: expanded-fst.h:118
StateIterator(const MPdtExpandFst< Arc > &fst)
Definition: expand.h:283
virtual uint64_t Properties() const
Definition: fst.h:701
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data)
Definition: expand.h:165
constexpr uint64_t kCopyProperties
Definition: properties.h:163
constexpr uint64_t kAcyclic
Definition: properties.h:111
uint64_t MPdtExpandProperties(uint64_t inprops)
Definition: expand.h:58
typename FirstCacheStore< VectorCacheStore< CacheState< Arc > > >::State State
Definition: cache.h:673
void InitArcIterator(StateId s, ArcIteratorData< Arc > *data) const override
Definition: expand.h:261
const PdtStateTable< StateId, StackId > & GetStateTable() const
Definition: expand.h:192
PdtStateTable< typename Arc::StateId, typename Arc::StateId > * state_table
Definition: expand.h:47
const ParenStack & GetStack() const
Definition: expand.h:265
std::unique_ptr< StateIteratorBase< Arc > > base
Definition: fst.h:382
typename Arc::StateId StateId
Definition: expand.h:225
constexpr uint64_t kFstProperties
Definition: properties.h:326
constexpr uint64_t kUnweighted
Definition: properties.h:106
void ExpandState(StateId s)
Definition: expand.h:172
MPdtExpandOptions(bool connect=true, bool keep_parentheses=false)
Definition: expand.h:310
size_t NumArcs(StateId s)
Definition: expand.h:150
internal::MPdtStack< typename Arc::StateId, typename Arc::Label > * stack
Definition: expand.h:46
typename MPdtExpandFst< Arc >::Arc Arc
Definition: cache.h:1202
virtual const SymbolTable * InputSymbols() const =0
const SymbolTable * InputSymbols() const
Definition: fst.h:759
MPdtExpandFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, const std::vector< Label > &assignments)
Definition: expand.h:237
MPdtExpandFst< Arc > * Copy(bool safe=false) const override
Definition: expand.h:255
ArcIterator(const MPdtExpandFst< Arc > &fst, StateId s)
Definition: expand.h:294
size_t NumOutputEpsilons(StateId s)
Definition: expand.h:160
typename Arc::Weight Weight
Definition: expand.h:70
typename Arc::StateId StateId
Definition: expand.h:292
MPdtExpandFstOptions(const CacheOptions &opts=CacheOptions(), bool kp=false, internal::MPdtStack< typename Arc::StateId, typename Arc::Label > *s=nullptr, PdtStateTable< typename Arc::StateId, typename Arc::StateId > *st=nullptr)
Definition: expand.h:49
void Expand(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, const std::vector< typename Arc::Label > &assignments, MutableFst< Arc > *ofst, const MPdtExpandOptions &opts)
Definition: expand.h:323
typename CacheState< Arc >::Arc Arc
Definition: cache.h:859
Impl * GetMutableImpl() const
Definition: impl-to-fst.h:125
Weight Final(StateId s)
Definition: expand.h:139
MPdtExpandFstImpl(const MPdtExpandFstImpl &impl)
Definition: expand.h:109
typename Arc::StateId StateId
Definition: expand.h:69
size_t gc_limit
Definition: cache.h:52
MPdtExpandFst(const Fst< Arc > &fst, const std::vector< std::pair< Label, Label >> &parens, const std::vector< Label > &assignments, const MPdtExpandFstOptions< Arc > &opts)
Definition: expand.h:243
typename Arc::Label Label
Definition: expand.h:224
const PdtStateTable< StateId, StackId > & GetStateTable() const
Definition: expand.h:267
MPdtExpandFst(const MPdtExpandFst< Arc > &fst, bool safe=false)
Definition: expand.h:251
const Impl * GetImpl() const
Definition: impl-to-fst.h:123
constexpr uint64_t kAcceptor
Definition: properties.h:64
virtual const SymbolTable * OutputSymbols() const =0