FST  openfst-1.8.3
OpenFst Library
pdtscript.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Convenience file for including all PDT operations at once, and/or
19 // registering them for new arc types.
20 
21 #ifndef FST_EXTENSIONS_PDT_PDTSCRIPT_H_
22 #define FST_EXTENSIONS_PDT_PDTSCRIPT_H_
23 
24 #include <algorithm>
25 #include <cstddef>
26 #include <cstdint>
27 #include <string>
28 #include <tuple>
29 #include <utility>
30 #include <vector>
31 
32 #include <fst/log.h>
39 #include <fst/compose.h> // for ComposeOptions
40 #include <fst/fst.h>
41 #include <fst/mutable-fst.h>
42 #include <fst/queue.h>
43 #include <fst/util.h>
44 #include <fst/script/arg-packs.h>
45 #include <fst/script/fst-class.h>
46 #include <fst/script/fstscript.h>
47 #include <fst/script/script-impl.h>
50 
51 namespace fst {
52 namespace script {
53 
54 using PdtComposeArgs =
55  std::tuple<const FstClass &, const FstClass &,
56  const std::vector<std::pair<int64_t, int64_t>> &,
58 
59 template <class Arc>
60 void Compose(PdtComposeArgs *args) {
61  const Fst<Arc> &ifst1 = *(std::get<0>(*args).GetFst<Arc>());
62  const Fst<Arc> &ifst2 = *(std::get<1>(*args).GetFst<Arc>());
63  MutableFst<Arc> *ofst = std::get<3>(*args)->GetMutableFst<Arc>();
64  // In case Arc::Label is not the same as FstClass::Label, we make a
65  // copy. Truncation may occur if FstClass::Label has more precision than
66  // Arc::Label.
67  std::vector<std::pair<typename Arc::Label, typename Arc::Label>> typed_parens(
68  std::get<2>(*args).size());
69  std::copy(std::get<2>(*args).begin(), std::get<2>(*args).end(),
70  typed_parens.begin());
71  if (std::get<5>(*args)) {
72  Compose(ifst1, typed_parens, ifst2, ofst, std::get<4>(*args));
73  } else {
74  Compose(ifst1, ifst2, typed_parens, ofst, std::get<4>(*args));
75  }
76 }
77 
78 void Compose(const FstClass &ifst1, const FstClass &ifst2,
79  const std::vector<std::pair<int64_t, int64_t>> &parens,
80  MutableFstClass *ofst, const PdtComposeOptions &opts,
81  bool left_pdt);
82 
84  bool connect;
87 
88  PdtExpandOptions(bool c, bool k, const WeightClass &w)
89  : connect(c), keep_parentheses(k), weight_threshold(w) {}
90 };
91 
92 using PdtExpandArgs =
93  std::tuple<const FstClass &,
94  const std::vector<std::pair<int64_t, int64_t>> &,
96 
97 template <class Arc>
98 void Expand(PdtExpandArgs *args) {
99  const Fst<Arc> &fst = *(std::get<0>(*args).GetFst<Arc>());
100  MutableFst<Arc> *ofst = std::get<2>(*args)->GetMutableFst<Arc>();
101  // In case Arc::Label is not the same as FstClass::Label, we make a
102  // copy. Truncation may occur if FstClass::Label has more precision than
103  // Arc::Label.
104  std::vector<std::pair<typename Arc::Label, typename Arc::Label>> typed_parens(
105  std::get<1>(*args).size());
106  std::copy(std::get<1>(*args).begin(), std::get<1>(*args).end(),
107  typed_parens.begin());
108  Expand(fst, typed_parens, ofst,
110  std::get<3>(*args).connect, std::get<3>(*args).keep_parentheses,
111  *(std::get<3>(*args)
112  .weight_threshold.GetWeight<typename Arc::Weight>())));
113 }
114 
115 void Expand(const FstClass &ifst,
116  const std::vector<std::pair<int64_t, int64_t>> &parens,
117  MutableFstClass *ofst, const PdtExpandOptions &opts);
118 
119 void Expand(const FstClass &ifst,
120  const std::vector<std::pair<int64_t, int64_t>> &parens,
121  MutableFstClass *ofst, bool connect, bool keep_parentheses,
123 
124 using PdtReplaceArgs =
125  std::tuple<const std::vector<std::pair<int64_t, const FstClass *>> &,
126  MutableFstClass *, std::vector<std::pair<int64_t, int64_t>> *,
127  int64_t, PdtParserType, int64_t, const std::string &,
128  const std::string &>;
129 
130 template <class Arc>
131 void Replace(PdtReplaceArgs *args) {
132  const auto &untyped_pairs = std::get<0>(*args);
133  auto size = untyped_pairs.size();
134  std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>> typed_pairs(
135  size);
136  for (size_t i = 0; i < size; ++i) {
137  typed_pairs[i].first = untyped_pairs[i].first;
138  typed_pairs[i].second = untyped_pairs[i].second->GetFst<Arc>();
139  }
140  MutableFst<Arc> *ofst = std::get<1>(*args)->GetMutableFst<Arc>();
141  std::vector<std::pair<typename Arc::Label, typename Arc::Label>> typed_parens;
142  const PdtReplaceOptions<Arc> opts(std::get<3>(*args), std::get<4>(*args),
143  std::get<5>(*args), std::get<6>(*args),
144  std::get<7>(*args));
145  Replace(typed_pairs, ofst, &typed_parens, opts);
146  // Copies typed parens into arg3.
147  std::get<2>(*args)->resize(typed_parens.size());
148  std::copy(typed_parens.begin(), typed_parens.end(),
149  std::get<2>(*args)->begin());
150 }
151 
152 void Replace(const std::vector<std::pair<int64_t, const FstClass *>> &pairs,
153  MutableFstClass *ofst,
154  std::vector<std::pair<int64_t, int64_t>> *parens, int64_t root,
155  PdtParserType parser_type = PdtParserType::LEFT,
156  int64_t start_paren_labels = kNoLabel,
157  const std::string &left_paren_prefix = "(_",
158  const std::string &right_paren_prefix = "_)");
159 
160 using PdtReverseArgs =
161  std::tuple<const FstClass &,
162  const std::vector<std::pair<int64_t, int64_t>> &,
163  MutableFstClass *>;
164 
165 template <class Arc>
166 void Reverse(PdtReverseArgs *args) {
167  const Fst<Arc> &fst = *(std::get<0>(*args).GetFst<Arc>());
168  MutableFst<Arc> *ofst = std::get<2>(*args)->GetMutableFst<Arc>();
169  // In case Arc::Label is not the same as FstClass::Label, we make a
170  // copy. Truncation may occur if FstClass::Label has more precision than
171  // Arc::Label.
172  std::vector<std::pair<typename Arc::Label, typename Arc::Label>> typed_parens(
173  std::get<1>(*args).size());
174  std::copy(std::get<1>(*args).begin(), std::get<1>(*args).end(),
175  typed_parens.begin());
176  Reverse(fst, typed_parens, ofst);
177 }
178 
179 void Reverse(const FstClass &ifst,
180  const std::vector<std::pair<int64_t, int64_t>> &,
181  MutableFstClass *ofst);
182 
183 // PDT SHORTESTPATH
184 
188  bool path_gc;
189 
190  explicit PdtShortestPathOptions(QueueType qt = FIFO_QUEUE, bool kp = false,
191  bool gc = true)
192  : queue_type(qt), keep_parentheses(kp), path_gc(gc) {}
193 };
194 
195 using PdtShortestPathArgs =
196  std::tuple<const FstClass &,
197  const std::vector<std::pair<int64_t, int64_t>> &,
198  MutableFstClass *, const PdtShortestPathOptions &>;
199 
200 template <class Arc>
202  const Fst<Arc> &fst = *(std::get<0>(*args).GetFst<Arc>());
203  MutableFst<Arc> *ofst = std::get<2>(*args)->GetMutableFst<Arc>();
204  const PdtShortestPathOptions &opts = std::get<3>(*args);
205  // In case Arc::Label is not the same as FstClass::Label, we make a
206  // copy. Truncation may occur if FstClass::Label has more precision than
207  // Arc::Label.
208  std::vector<std::pair<typename Arc::Label, typename Arc::Label>> typed_parens(
209  std::get<1>(*args).size());
210  std::copy(std::get<1>(*args).begin(), std::get<1>(*args).end(),
211  typed_parens.begin());
212  switch (opts.queue_type) {
213  default:
214  FSTERROR() << "Unknown queue type: " << opts.queue_type;
215  [[fallthrough]];
216  case FIFO_QUEUE: {
217  using Queue = FifoQueue<typename Arc::StateId>;
218  fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
219  opts.path_gc);
220  ShortestPath(fst, typed_parens, ofst, spopts);
221  return;
222  }
223  case LIFO_QUEUE: {
224  using Queue = LifoQueue<typename Arc::StateId>;
225  fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
226  opts.path_gc);
227  ShortestPath(fst, typed_parens, ofst, spopts);
228  return;
229  }
230  case STATE_ORDER_QUEUE: {
232  fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
233  opts.path_gc);
234  ShortestPath(fst, typed_parens, ofst, spopts);
235  return;
236  }
237  }
238 }
239 
240 void ShortestPath(
241  const FstClass &ifst,
242  const std::vector<std::pair<int64_t, int64_t>> &parens,
243  MutableFstClass *ofst,
244  const PdtShortestPathOptions &opts = PdtShortestPathOptions());
245 
246 // PRINT INFO
247 
248 using PdtInfoArgs = std::pair<const FstClass &,
249  const std::vector<std::pair<int64_t, int64_t>> &>;
250 
251 template <class Arc>
252 void Info(PdtInfoArgs *args) {
253  const Fst<Arc> &fst = *(std::get<0>(*args).GetFst<Arc>());
254  // In case Arc::Label is not the same as FstClass::Label, we make a
255  // copy. Truncation may occur if FstClass::Label has more precision than
256  // Arc::Label.
257  std::vector<std::pair<typename Arc::Label, typename Arc::Label>> typed_parens(
258  std::get<1>(*args).size());
259  std::copy(std::get<1>(*args).begin(), std::get<1>(*args).end(),
260  typed_parens.begin());
261  PdtInfo<Arc> pdtinfo(fst, typed_parens);
262  pdtinfo.Print();
263 }
264 
265 void Info(const FstClass &ifst,
266  const std::vector<std::pair<int64_t, int64_t>> &parens);
267 
268 } // namespace script
269 } // namespace fst
270 
// Registers every scripting-level PDT operation declared in this header for
// ArcType. Each entry pairs a function template from this header with the
// argument pack it accepts.
// NOTE(review): assumes REGISTER_FST_OPERATION(Op, Arc, ArgPack) expects Op
// to name the function template being registered — confirm against
// script-impl.h.
#define REGISTER_FST_PDT_OPERATIONS(ArcType)                         \
  REGISTER_FST_OPERATION(Compose, ArcType, PdtComposeArgs);          \
  REGISTER_FST_OPERATION(Expand, ArcType, PdtExpandArgs);            \
  REGISTER_FST_OPERATION(Replace, ArcType, PdtReplaceArgs);          \
  REGISTER_FST_OPERATION(Reverse, ArcType, PdtReverseArgs);          \
  REGISTER_FST_OPERATION(ShortestPath, ArcType, PdtShortestPathArgs); \
  REGISTER_FST_OPERATION(Info, ArcType, PdtInfoArgs)
278 #endif // FST_EXTENSIONS_PDT_PDTSCRIPT_H_
PdtShortestPathOptions(QueueType qt=FIFO_QUEUE, bool kp=false, bool gc=true)
Definition: pdtscript.h:190
constexpr int kNoLabel
Definition: fst.h:195
MutableFst< Arc > * GetMutableFst()
Definition: fst-class.h:534
PdtExpandOptions(bool c, bool k, const WeightClass &w)
Definition: pdtscript.h:88
QueueType
Definition: queue.h:76
void Reverse(const FstClass &ifst, const std::vector< std::pair< int64_t, int64_t >> &parens, std::vector< int64_t > *assignments, MutableFstClass *ofst)
Definition: mpdtscript.cc:72
void Replace(const std::vector< std::pair< int64_t, const FstClass * >> &pairs, MutableFstClass *ofst, std::vector< std::pair< int64_t, int64_t >> *parens, int64_t root, PdtParserType parser_type, int64_t start_paren_labels, const std::string &left_paren_prefix, const std::string &right_paren_prefix)
Definition: pdtscript.cc:75
PdtParserType
Definition: replace.h:69
void ShortestPath(const FstClass &ifst, const std::vector< std::pair< int64_t, int64_t >> &parens, MutableFstClass *ofst, const PdtShortestPathOptions &opts)
Definition: pdtscript.cc:109
std::tuple< const std::vector< std::pair< int64_t, const FstClass * >> &, MutableFstClass *, std::vector< std::pair< int64_t, int64_t >> *, int64_t, PdtParserType, int64_t, const std::string &, const std::string & > PdtReplaceArgs
Definition: pdtscript.h:128
std::tuple< const FstClass &, const std::vector< std::pair< int64_t, int64_t >> &, MutableFstClass * > PdtReverseArgs
Definition: pdtscript.h:163
void Info(const std::vector< std::string > &sources, const std::string &arc_type, const std::string &begin_key, const std::string &end_key, bool list_fsts)
Definition: farscript.cc:145
#define FSTERROR()
Definition: util.h:56
void Expand(const FstClass &ifst, const std::vector< std::pair< int64_t, int64_t >> &parens, const std::vector< int64_t > &assignments, MutableFstClass *ofst, const MPdtExpandOptions &opts)
Definition: mpdtscript.cc:55
std::pair< const FstClass &, const std::vector< std::pair< int64_t, int64_t >> & > PdtInfoArgs
Definition: pdtscript.h:249
const WeightClass & weight_threshold
Definition: pdtscript.h:86
std::tuple< const FstClass &, const std::vector< std::pair< int64_t, int64_t >> &, MutableFstClass *, const PdtShortestPathOptions & > PdtShortestPathArgs
Definition: pdtscript.h:198
std::tuple< const FstClass &, const FstClass &, const std::vector< std::pair< int64_t, int64_t >> &, MutableFstClass *, const PdtComposeOptions &, bool > PdtComposeArgs
Definition: pdtscript.h:57
void Compose(const FstClass &ifst1, const FstClass &ifst2, const std::vector< std::pair< int64_t, int64_t >> &parens, const std::vector< int64_t > &assignments, MutableFstClass *ofst, const MPdtComposeOptions &copts, bool left_pdt)
Definition: mpdtscript.cc:41
std::tuple< const FstClass &, const std::vector< std::pair< int64_t, int64_t >> &, MutableFstClass *, const PdtExpandOptions & > PdtExpandArgs
Definition: pdtscript.h:95