FST  openfst-1.8.3
OpenFst Library
encodemapper-class.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 
18 #ifndef FST_SCRIPT_ENCODEMAPPER_CLASS_H_
19 #define FST_SCRIPT_ENCODEMAPPER_CLASS_H_
20 
21 #include <cstdint>
22 #include <iostream>
23 #include <istream>
24 #include <memory>
25 #include <ostream>
26 #include <string>
27 #include <utility>
28 
29 #include <fst/encode.h>
30 #include <fst/generic-register.h>
31 #include <fst/symbol-table.h>
32 #include <fst/util.h>
33 #include <fst/script/arc-class.h>
34 #include <fst/script/fst-class.h>
35 #include <string_view>
36 
37 // Scripting API support for EncodeMapper.
38 
39 namespace fst {
40 namespace script {
41 
42 // Virtual interface implemented by each concrete EncodeMapperClassImpl<Arc>.
44  public:
45  // Returns an encoded ArcClass.
46  virtual ArcClass operator()(const ArcClass &) = 0;
47  virtual const std::string &ArcType() const = 0;
48  virtual const std::string &WeightType() const = 0;
49  virtual EncodeMapperImplBase *Copy() const = 0;
50  virtual uint8_t Flags() const = 0;
51  virtual uint64_t Properties(uint64_t) = 0;
52  virtual EncodeType Type() const = 0;
53  virtual bool Write(const std::string &) const = 0;
54  virtual bool Write(std::ostream &, const std::string &) const = 0;
55  virtual const SymbolTable *InputSymbols() const = 0;
56  virtual const SymbolTable *OutputSymbols() const = 0;
57  virtual void SetInputSymbols(const SymbolTable *) = 0;
58  virtual void SetOutputSymbols(const SymbolTable *) = 0;
59  virtual ~EncodeMapperImplBase() = default;
60 };
61 
62 // Templated implementation.
63 template <class Arc>
65  public:
66  explicit EncodeMapperClassImpl(const EncodeMapper<Arc> &mapper)
67  : mapper_(mapper) {}
68 
69  ArcClass operator()(const ArcClass &a) final;
70 
71  const std::string &ArcType() const final { return Arc::Type(); }
72 
73  const std::string &WeightType() const final { return Arc::Weight::Type(); }
74 
76  return new EncodeMapperClassImpl<Arc>(mapper_);
77  }
78 
79  uint8_t Flags() const final { return mapper_.Flags(); }
80 
81  uint64_t Properties(uint64_t inprops) final {
82  return mapper_.Properties(inprops);
83  }
84 
85  EncodeType Type() const final { return mapper_.Type(); }
86 
87  bool Write(const std::string &source) const final {
88  return mapper_.Write(source);
89  }
90 
91  bool Write(std::ostream &strm, const std::string &source) const final {
92  return mapper_.Write(strm, source);
93  }
94 
95  const SymbolTable *InputSymbols() const final {
96  return mapper_.InputSymbols();
97  }
98 
99  const SymbolTable *OutputSymbols() const final {
100  return mapper_.OutputSymbols();
101  }
102 
103  void SetInputSymbols(const SymbolTable *syms) final {
104  mapper_.SetInputSymbols(syms);
105  }
106 
107  void SetOutputSymbols(const SymbolTable *syms) final {
108  mapper_.SetOutputSymbols(syms);
109  }
110 
111  ~EncodeMapperClassImpl() override = default;
112 
113  const EncodeMapper<Arc> *GetImpl() const { return &mapper_; }
114 
115  EncodeMapper<Arc> *GetImpl() { return &mapper_; }
116 
117  private:
118  EncodeMapper<Arc> mapper_;
119 };
120 
121 template <class Arc>
123  const Arc arc(a.ilabel, a.olabel,
124  *(a.weight.GetWeight<typename Arc::Weight>()), a.nextstate);
125  return ArcClass(mapper_(arc));
126 }
127 
129  public:
130  EncodeMapperClass() : impl_(nullptr) {}
131 
132  EncodeMapperClass(std::string_view arc_type, uint8_t flags,
133  EncodeType type = ENCODE);
134 
135  template <class Arc>
136  explicit EncodeMapperClass(const EncodeMapper<Arc> &mapper)
137  : impl_(std::make_unique<EncodeMapperClassImpl<Arc>>(mapper)) {}
138 
140  : impl_(other.impl_ == nullptr ? nullptr : other.impl_->Copy()) {}
141 
143  impl_.reset(other.impl_ == nullptr ? nullptr : other.impl_->Copy());
144  return *this;
145  }
146 
147  ArcClass operator()(const ArcClass &arc) { return (*impl_)(arc); }
148 
149  const std::string &ArcType() const { return impl_->ArcType(); }
150 
151  const std::string &WeightType() const { return impl_->WeightType(); }
152 
153  uint8_t Flags() const { return impl_->Flags(); }
154 
155  uint64_t Properties(uint64_t inprops) { return impl_->Properties(inprops); }
156 
157  EncodeType Type() const { return impl_->Type(); }
158 
159  static std::unique_ptr<EncodeMapperClass> Read(
160  const std::string &source);
161 
162  static std::unique_ptr<EncodeMapperClass> Read(
163  std::istream &strm, const std::string &source);
164 
165  bool Write(const std::string &source) const { return impl_->Write(source); }
166 
167  bool Write(std::ostream &strm, const std::string &source) const {
168  return impl_->Write(strm, source);
169  }
170 
171  const SymbolTable *InputSymbols() const { return impl_->InputSymbols(); }
172 
173  const SymbolTable *OutputSymbols() const { return impl_->OutputSymbols(); }
174 
175  void SetInputSymbols(const SymbolTable *syms) {
176  impl_->SetInputSymbols(syms);
177  }
178 
179  void SetOutputSymbols(const SymbolTable *syms) {
180  impl_->SetOutputSymbols(syms);
181  }
182 
183  // Implementation stuff.
184 
185  template <class Arc>
187  if (Arc::Type() != ArcType()) {
188  return nullptr;
189  } else {
190  auto *typed_impl = down_cast<EncodeMapperClassImpl<Arc> *>(impl_.get());
191  return typed_impl->GetImpl();
192  }
193  }
194 
195  template <class Arc>
197  if (Arc::Type() != ArcType()) {
198  return nullptr;
199  } else {
200  auto *typed_impl = down_cast<EncodeMapperClassImpl<Arc> *>(impl_.get());
201  return typed_impl->GetImpl();
202  }
203  }
204 
205  // Required for registration.
206 
207  template <class Arc>
208  static std::unique_ptr<EncodeMapperClass> Read(std::istream &strm,
209  std::string_view source) {
210  std::unique_ptr<EncodeMapper<Arc>> mapper(
211  EncodeMapper<Arc>::Read(strm, source));
212  return mapper ? std::make_unique<EncodeMapperClass>(*mapper) : nullptr;
213  }
214 
215  template <class Arc>
216  static std::unique_ptr<EncodeMapperImplBase> Create(
217  uint8_t flags, EncodeType type = ENCODE) {
218  return std::make_unique<EncodeMapperClassImpl<Arc>>(
219  EncodeMapper<Arc>(flags, type));
220  }
221 
222  private:
223  explicit EncodeMapperClass(std::unique_ptr<EncodeMapperImplBase> impl)
224  : impl_(std::move(impl)) {}
225 
226  const EncodeMapperImplBase *GetImpl() const { return impl_.get(); }
227 
228  EncodeMapperImplBase *GetImpl() { return impl_.get(); }
229 
230  std::unique_ptr<EncodeMapperImplBase> impl_;
231 };
232 
233 // Registration for EncodeMapper types.
234 
// Registry entry pairing a reader and a creator function pointer. This class
// definition is to avoid a nested class definition inside the
// EncodeMapperClassIORegistration struct.

template <class Reader, class Creator>
struct EncodeMapperClassRegEntry {
  // Both pointers are null for a default-constructed entry.
  Reader reader = nullptr;
  Creator creator = nullptr;

  EncodeMapperClassRegEntry(Reader reader, Creator creator)
      : reader(reader), creator(creator) {}

  EncodeMapperClassRegEntry() = default;
};
248 
249 template <class Reader, class Creator>
251  : public GenericRegister<std::string,
252  EncodeMapperClassRegEntry<Reader, Creator>,
253  EncodeMapperClassIORegister<Reader, Creator>> {
254  public:
255  Reader GetReader(std::string_view arc_type) const {
256  return this->GetEntry(arc_type).reader;
257  }
258 
259  Creator GetCreator(std::string_view arc_type) const {
260  return this->GetEntry(arc_type).creator;
261  }
262 
263  protected:
264  std::string ConvertKeyToSoFilename(std::string_view key) const final {
265  std::string legal_type(key);
266  ConvertToLegalCSymbol(&legal_type);
267  legal_type.append("-arc.so");
268  return legal_type;
269  }
270 };
271 
272 // Struct containing everything needed to register a particular type
274  using Reader = std::unique_ptr<EncodeMapperClass> (*)(
275  std::istream &stream, std::string_view source);
276 
277  using Creator = std::unique_ptr<EncodeMapperImplBase> (*)(uint8_t flags,
278  EncodeType type);
279 
281 
282  // EncodeMapper register.
284 
285  // EncodeMapper register-er.
286  using Registerer =
288 };
289 
// Registers EncodeMapperClass::Read<Arc> and EncodeMapperClass::Create<Arc>
// under the key Arc::Type() by defining a static Registerer object. Library
// names are fully qualified so the macro also expands correctly outside
// namespace fst::script; `Arc` must be a single identifier (it is pasted into
// the registerer's variable name).
#define REGISTER_ENCODEMAPPER_CLASS(Arc)                             \
  static ::fst::script::EncodeMapperClassIORegistration::Registerer  \
      EncodeMapperClass_##Arc##_registerer(                          \
          Arc::Type(),                                               \
          ::fst::script::EncodeMapperClassIORegistration::Entry(     \
              ::fst::script::EncodeMapperClass::Read<Arc>,           \
              ::fst::script::EncodeMapperClass::Create<Arc>));
296 
297 } // namespace script
298 } // namespace fst
299 
300 #endif // FST_SCRIPT_ENCODEMAPPER_CLASS_H_
ArcClass operator()(const ArcClass &a) final
std::unique_ptr< EncodeMapperImplBase >(*)(uint8_t flags, EncodeType type) Creator
void ConvertToLegalCSymbol(std::string *s)
Definition: util.cc:75
bool Write(const std::string &source) const final
void SetOutputSymbols(const SymbolTable *syms)
virtual const SymbolTable * OutputSymbols() const =0
bool Write(std::ostream &strm, const std::string &source) const
virtual const std::string & ArcType() const =0
EncodeMapperClass(const EncodeMapperClass &other)
const std::string & WeightType() const final
uint64_t Properties(uint64_t inprops) final
virtual const SymbolTable * InputSymbols() const =0
bool Write(std::ostream &strm, const std::string &source) const final
ArcClass operator()(const ArcClass &arc)
const std::string & ArcType() const
void SetOutputSymbols(const SymbolTable *syms) final
const SymbolTable * OutputSymbols() const final
To down_cast(From *f)
Definition: compat.h:50
Reader GetReader(std::string_view arc_type) const
Creator GetCreator(std::string_view arc_type) const
const SymbolTable * InputSymbols() const
const EncodeMapper< Arc > * GetEncodeMapper() const
virtual const std::string & WeightType() const =0
EncodeType
Definition: encode.h:53
EncodeMapperClass(const EncodeMapper< Arc > &mapper)
std::string ConvertKeyToSoFilename(std::string_view key) const final
virtual void SetInputSymbols(const SymbolTable *)=0
EncodeMapperClass & operator=(const EncodeMapperClass &other)
virtual uint8_t Flags() const =0
void SetInputSymbols(const SymbolTable *syms) final
virtual EncodeType Type() const =0
virtual EncodeMapperImplBase * Copy() const =0
const SymbolTable * InputSymbols() const final
virtual ArcClass operator()(const ArcClass &)=0
EncodeMapperClassRegEntry(Reader reader, Creator creator)
const W * GetWeight() const
Definition: weight-class.h:144
static std::unique_ptr< EncodeMapperImplBase > Create(uint8_t flags, EncodeType type=ENCODE)
uint64_t Properties(uint64_t inprops)
EncodeMapperClassImpl< Arc > * Copy() const final
std::unique_ptr< EncodeMapperClass >(*)(std::istream &stream, std::string_view source) Reader
virtual ~EncodeMapperImplBase()=default
virtual bool Write(const std::string &) const =0
static std::unique_ptr< EncodeMapperClass > Read(std::istream &strm, std::string_view source)
const SymbolTable * OutputSymbols() const
WeightClass weight
Definition: arc-class.h:51
virtual uint64_t Properties(uint64_t)=0
void SetInputSymbols(const SymbolTable *syms)
const std::string & ArcType() const final
bool Write(const std::string &source) const
EncodeMapper< Arc > * GetEncodeMapper()
EncodeMapperClassImpl(const EncodeMapper< Arc > &mapper)
virtual void SetOutputSymbols(const SymbolTable *)=0
const std::string & WeightType() const
const EncodeMapper< Arc > * GetImpl() const