FST  openfst-1.8.3
OpenFst Library
getters.cc
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 
#include <fst/script/getters.h>

#include <cstdint>
#include <ctime>
#include <string>
#include <string_view>

#include <fst/compose.h>
#include <fst/determinize.h>
#include <fst/epsnormalize.h>
#include <fst/project.h>
#include <fst/queue.h>
#include <fst/rational.h>
#include <fst/replace-util.h>
#include <fst/reweight.h>
#include <fst/string.h>
#include <fst/script/arcsort.h>
#include <fst/script/map.h>
#include <fst/script/script-impl.h>
36 
37 namespace fst {
38 namespace script {
39 
40 bool GetArcFilterType(std::string_view str, ArcFilterType *arc_filter_type) {
41  if (str == "any") {
42  *arc_filter_type = ArcFilterType::ANY;
43  } else if (str == "epsilon") {
44  *arc_filter_type = ArcFilterType::EPSILON;
45  } else if (str == "iepsilon") {
46  *arc_filter_type = ArcFilterType::INPUT_EPSILON;
47  } else if (str == "oepsilon") {
48  *arc_filter_type = ArcFilterType::OUTPUT_EPSILON;
49  } else {
50  return false;
51  }
52  return true;
53 }
54 
55 bool GetArcSortType(std::string_view str, ArcSortType *sort_type) {
56  if (str == "ilabel") {
57  *sort_type = ArcSortType::ILABEL;
58  } else if (str == "olabel") {
59  *sort_type = ArcSortType::OLABEL;
60  } else {
61  return false;
62  }
63  return true;
64 }
65 
66 bool GetClosureType(std::string_view str, ClosureType *closure_type) {
67  if (str == "star") {
68  *closure_type = CLOSURE_STAR;
69  } else if (str == "plus") {
70  *closure_type = CLOSURE_PLUS;
71  } else {
72  return false;
73  }
74  return true;
75 }
76 
77 bool GetComposeFilter(std::string_view str, ComposeFilter *compose_filter) {
78  if (str == "alt_sequence") {
79  *compose_filter = ALT_SEQUENCE_FILTER;
80  } else if (str == "auto") {
81  *compose_filter = AUTO_FILTER;
82  } else if (str == "match") {
83  *compose_filter = MATCH_FILTER;
84  } else if (str == "no_match") {
85  *compose_filter = NO_MATCH_FILTER;
86  } else if (str == "null") {
87  *compose_filter = NULL_FILTER;
88  } else if (str == "sequence") {
89  *compose_filter = SEQUENCE_FILTER;
90  } else if (str == "trivial") {
91  *compose_filter = TRIVIAL_FILTER;
92  } else {
93  return false;
94  }
95  return true;
96 }
97 
98 bool GetDeterminizeType(std::string_view str, DeterminizeType *det_type) {
99  if (str == "functional") {
100  *det_type = DETERMINIZE_FUNCTIONAL;
101  } else if (str == "nonfunctional") {
102  *det_type = DETERMINIZE_NONFUNCTIONAL;
103  } else if (str == "disambiguate") {
104  *det_type = DETERMINIZE_DISAMBIGUATE;
105  } else {
106  return false;
107  }
108  return true;
109 }
110 
111 bool GetEpsNormalizeType(std::string_view str,
112  EpsNormalizeType *eps_norm_type) {
113  if (str == "input") {
114  *eps_norm_type = EPS_NORM_INPUT;
115  } else if (str == "output") {
116  *eps_norm_type = EPS_NORM_OUTPUT;
117  } else {
118  return false;
119  }
120  return true;
121 }
122 
123 bool GetMapType(std::string_view str, MapType *map_type) {
124  if (str == "arc_sum") {
125  *map_type = MapType::ARC_SUM;
126  } else if (str == "arc_unique") {
127  *map_type = MapType::ARC_UNIQUE;
128  } else if (str == "identity") {
129  *map_type = MapType::IDENTITY;
130  } else if (str == "input_epsilon") {
131  *map_type = MapType::INPUT_EPSILON;
132  } else if (str == "invert") {
133  *map_type = MapType::INVERT;
134  } else if (str == "output_epsilon") {
135  *map_type = MapType::OUTPUT_EPSILON;
136  } else if (str == "plus") {
137  *map_type = MapType::PLUS;
138  } else if (str == "power") {
139  *map_type = MapType::POWER;
140  } else if (str == "quantize") {
141  *map_type = MapType::QUANTIZE;
142  } else if (str == "rmweight") {
143  *map_type = MapType::RMWEIGHT;
144  } else if (str == "superfinal") {
145  *map_type = MapType::SUPERFINAL;
146  } else if (str == "times") {
147  *map_type = MapType::TIMES;
148  } else if (str == "to_log") {
149  *map_type = MapType::TO_LOG;
150  } else if (str == "to_log64") {
151  *map_type = MapType::TO_LOG64;
152  } else if (str == "to_std" || str == "to_standard") {
153  *map_type = MapType::TO_STD;
154  } else {
155  return false;
156  }
157  return true;
158 }
159 
160 bool GetProjectType(std::string_view str, ProjectType *project_type) {
161  if (str == "input") {
162  *project_type = ProjectType::INPUT;
163  } else if (str == "output") {
164  *project_type = ProjectType::OUTPUT;
165  } else {
166  return false;
167  }
168  return true;
169 }
170 
171 bool GetRandArcSelection(std::string_view str, RandArcSelection *ras) {
172  if (str == "uniform") {
174  } else if (str == "log_prob") {
176  } else if (str == "fast_log_prob") {
178  } else {
179  return false;
180  }
181  return true;
182 }
183 
184 bool GetQueueType(std::string_view str, QueueType *queue_type) {
185  if (str == "auto") {
186  *queue_type = AUTO_QUEUE;
187  } else if (str == "fifo") {
188  *queue_type = FIFO_QUEUE;
189  } else if (str == "lifo") {
190  *queue_type = LIFO_QUEUE;
191  } else if (str == "shortest") {
192  *queue_type = SHORTEST_FIRST_QUEUE;
193  } else if (str == "state") {
194  *queue_type = STATE_ORDER_QUEUE;
195  } else if (str == "top") {
196  *queue_type = TOP_ORDER_QUEUE;
197  } else {
198  return false;
199  }
200  return true;
201 }
202 
203 bool GetReplaceLabelType(std::string_view str, bool epsilon_on_replace,
204  ReplaceLabelType *rlt) {
205  if (epsilon_on_replace || str == "neither") {
206  *rlt = REPLACE_LABEL_NEITHER;
207  } else if (str == "input") {
208  *rlt = REPLACE_LABEL_INPUT;
209  } else if (str == "output") {
210  *rlt = REPLACE_LABEL_OUTPUT;
211  } else if (str == "both") {
212  *rlt = REPLACE_LABEL_BOTH;
213  } else {
214  return false;
215  }
216  return true;
217 }
218 
219 bool GetReweightType(std::string_view str, ReweightType *reweight_type) {
220  if (str == "to_initial") {
221  *reweight_type = REWEIGHT_TO_INITIAL;
222  } else if (str == "to_final") {
223  *reweight_type = REWEIGHT_TO_FINAL;
224  } else {
225  return false;
226  }
227  return true;
228 }
229 
230 uint64_t GetSeed(uint64_t seed) {
231  return seed == kDefaultSeed ? time(nullptr) : seed;
232 }
233 
234 bool GetTokenType(std::string_view str, TokenType *token_type) {
235  if (str == "byte") {
236  *token_type = TokenType::BYTE;
237  } else if (str == "utf8") {
238  *token_type = TokenType::UTF8;
239  } else if (str == "symbol") {
240  *token_type = TokenType::SYMBOL;
241  } else {
242  return false;
243  }
244  return true;
245 }
246 
247 } // namespace script
248 } // namespace fst
bool GetTokenType(std::string_view str, TokenType *token_type)
Definition: getters.cc:234
bool GetMapType(std::string_view str, MapType *map_type)
Definition: getters.cc:123
constexpr uint64_t kDefaultSeed
Definition: getters.h:49
bool GetArcSortType(std::string_view str, ArcSortType *sort_type)
Definition: getters.cc:55
QueueType
Definition: queue.h:76
MapType
Definition: map.h:56
ReplaceLabelType
Definition: replace-util.h:48
ReweightType
Definition: reweight.h:35
bool GetRandArcSelection(std::string_view str, RandArcSelection *ras)
Definition: getters.cc:171
bool GetDeterminizeType(std::string_view str, DeterminizeType *det_type)
Definition: getters.cc:98
ProjectType
Definition: project.h:37
bool GetReplaceLabelType(std::string_view str, bool epsilon_on_replace, ReplaceLabelType *rlt)
Definition: getters.cc:203
bool GetComposeFilter(std::string_view str, ComposeFilter *compose_filter)
Definition: getters.cc:77
bool GetClosureType(std::string_view str, ClosureType *closure_type)
Definition: getters.cc:66
EpsNormalizeType
Definition: epsnormalize.h:38
uint64_t GetSeed(uint64_t seed)
Definition: getters.cc:230
bool GetQueueType(std::string_view str, QueueType *queue_type)
Definition: getters.cc:184
TokenType
Definition: string.h:49
ClosureType
Definition: rational.h:45
bool GetEpsNormalizeType(std::string_view str, EpsNormalizeType *eps_norm_type)
Definition: getters.cc:111
ComposeFilter
Definition: compose.h:953
bool GetReweightType(std::string_view str, ReweightType *reweight_type)
Definition: getters.cc:219
bool GetArcFilterType(std::string_view str, ArcFilterType *arc_filter_type)
Definition: getters.cc:40
DeterminizeType
Definition: determinize.h:394
bool GetProjectType(std::string_view str, ProjectType *project_type)
Definition: getters.cc:160