FST  openfst-1.8.3
OpenFst Library
weight-class.h
Go to the documentation of this file.
// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Represents a generic weight in an FST; that is, represents a specific type
// of weight underneath while hiding that type from a client.
20 
21 #ifndef FST_SCRIPT_WEIGHT_CLASS_H_
22 #define FST_SCRIPT_WEIGHT_CLASS_H_
23 
24 #include <cstddef>
25 #include <memory>
26 #include <ostream>
27 #include <string>
28 
29 #include <fst/arc.h>
30 #include <fst/generic-register.h>
31 #include <fst/util.h>
32 #include <fst/weight.h>
33 #include <string_view>
34 
35 namespace fst {
36 namespace script {
37 
// Abstract interface over a concrete weight whose static type is hidden from
// the caller. All operations are pure virtual; the binary operations take the
// other operand through this same interface.
class WeightImplBase {
 public:
  // Returns a copy of this object as a raw pointer. Callers in this file
  // immediately hand the result to a std::unique_ptr.
  virtual WeightImplBase *Copy() const = 0;
  // Writes the weight to *o.
  virtual void Print(std::ostream *o) const = 0;
  // Name of the underlying weight type.
  virtual const std::string &Type() const = 0;
  // String form of the weight.
  virtual std::string ToString() const = 0;
  // Membership test of the underlying weight.
  virtual bool Member() const = 0;
  virtual bool operator==(const WeightImplBase &other) const = 0;
  virtual bool operator!=(const WeightImplBase &other) const = 0;
  // In-place arithmetic; each returns *this so calls can be chained.
  virtual WeightImplBase &PlusEq(const WeightImplBase &other) = 0;
  virtual WeightImplBase &TimesEq(const WeightImplBase &other) = 0;
  virtual WeightImplBase &DivideEq(const WeightImplBase &other) = 0;
  virtual WeightImplBase &PowerEq(size_t n) = 0;
  // Virtual so that objects can be destroyed through a base pointer.
  virtual ~WeightImplBase() = default;
};
53 
54 template <class W>
56  public:
57  explicit WeightClassImpl(const W &weight) : weight_(weight) {}
58 
59  WeightClassImpl<W> *Copy() const final {
60  return new WeightClassImpl<W>(weight_);
61  }
62 
63  const std::string &Type() const final { return W::Type(); }
64 
65  void Print(std::ostream *ostrm) const final { *ostrm << weight_; }
66 
67  std::string ToString() const final {
68  return WeightToStr(weight_);
69  }
70 
71  bool Member() const final { return weight_.Member(); }
72 
73  bool operator==(const WeightImplBase &other) const final {
74  const auto *typed_other = down_cast<const WeightClassImpl<W> *>(&other);
75  return weight_ == typed_other->weight_;
76  }
77 
78  bool operator!=(const WeightImplBase &other) const final {
79  return !(*this == other);
80  }
81 
82  WeightClassImpl<W> &PlusEq(const WeightImplBase &other) final {
83  const auto *typed_other = down_cast<const WeightClassImpl<W> *>(&other);
84  weight_ = Plus(weight_, typed_other->weight_);
85  return *this;
86  }
87 
88  WeightClassImpl<W> &TimesEq(const WeightImplBase &other) final {
89  const auto *typed_other = down_cast<const WeightClassImpl<W> *>(&other);
90  weight_ = Times(weight_, typed_other->weight_);
91  return *this;
92  }
93 
95  const auto *typed_other = down_cast<const WeightClassImpl<W> *>(&other);
96  weight_ = Divide(weight_, typed_other->weight_);
97  return *this;
98  }
99 
100  WeightClassImpl<W> &PowerEq(size_t n) final {
101  weight_ = Power(weight_, n);
102  return *this;
103  }
104 
105  W *GetImpl() { return &weight_; }
106 
107  private:
108  W weight_;
109 };
110 
111 class WeightClass {
112  public:
113  WeightClass() = default;
114 
115  template <class W>
116  explicit WeightClass(const W &weight)
117  : impl_(std::make_unique<WeightClassImpl<W>>(weight)) {}
118 
119  template <class W>
120  explicit WeightClass(const WeightClassImpl<W> &impl)
121  : impl_(std::make_unique<WeightClassImpl<W>>(impl)) {}
122 
123  WeightClass(std::string_view weight_type, std::string_view weight_str);
124 
125  WeightClass(const WeightClass &other)
126  : impl_(other.impl_ ? other.impl_->Copy() : nullptr) {}
127 
129  impl_.reset(other.impl_ ? other.impl_->Copy() : nullptr);
130  return *this;
131  }
132 
133  static constexpr std::string_view __ZERO__ = "__ZERO__"; // NOLINT
134  static constexpr std::string_view __ONE__ = "__ONE__"; // NOLINT
135  static constexpr std::string_view __NOWEIGHT__ = "__NOWEIGHT__"; // NOLINT
136 
137  static WeightClass Zero(std::string_view weight_type);
138 
139  static WeightClass One(std::string_view weight_type);
140 
141  static WeightClass NoWeight(std::string_view weight_type);
142 
143  template <class W>
144  const W *GetWeight() const {
145  if (W::Type() != impl_->Type()) {
146  return nullptr;
147  } else {
148  auto *typed_impl = static_cast<WeightClassImpl<W> *>(impl_.get());
149  return typed_impl->GetImpl();
150  }
151  }
152 
153  std::string ToString() const { return (impl_) ? impl_->ToString() : "none"; }
154 
155  const std::string &Type() const {
156  if (impl_) return impl_->Type();
157  static const std::string *const no_type = new std::string("none");
158  return *no_type;
159  }
160 
161  bool Member() const { return impl_ && impl_->Member(); }
162 
163  static bool WeightTypesMatch(const WeightClass &lhs, const WeightClass &rhs,
164  std::string_view op_name);
165 
166  friend bool operator==(const WeightClass &lhs, const WeightClass &rhs);
167 
168  friend WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs);
169 
170  friend WeightClass Times(const WeightClass &lhs, const WeightClass &rhs);
171 
172  friend WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs);
173 
174  friend WeightClass Power(const WeightClass &w, size_t n);
175 
176  private:
177  const WeightImplBase *GetImpl() const { return impl_.get(); }
178 
179  WeightImplBase *GetImpl() { return impl_.get(); }
180 
181  std::unique_ptr<WeightImplBase> impl_;
182 
183  friend std::ostream &operator<<(std::ostream &o, const WeightClass &c);
184 };
185 
186 bool operator==(const WeightClass &lhs, const WeightClass &rhs);
187 
188 bool operator!=(const WeightClass &lhs, const WeightClass &rhs);
189 
190 WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs);
191 
192 WeightClass Times(const WeightClass &lhs, const WeightClass &rhs);
193 
194 WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs);
195 
196 WeightClass Power(const WeightClass &w, size_t n);
197 
198 std::ostream &operator<<(std::ostream &o, const WeightClass &c);
199 
200 // Registration for generic weight types.
201 
202 using StrToWeightImplBaseT =
203  std::unique_ptr<WeightImplBase> (*)(std::string_view str);
204 
205 template <class W>
206 std::unique_ptr<WeightImplBase> StrToWeightImplBase(std::string_view str) {
207  if (str == WeightClass::__ZERO__) {
208  return std::make_unique<WeightClassImpl<W>>(W::Zero());
209  } else if (str == WeightClass::__ONE__) {
210  return std::make_unique<WeightClassImpl<W>>(W::One());
211  } else if (str == WeightClass::__NOWEIGHT__) {
212  return std::make_unique<WeightClassImpl<W>>(W::NoWeight());
213  }
214  return std::make_unique<WeightClassImpl<W>>(StrToWeight<W>(str));
215 }
216 
218  : public GenericRegister<std::string, StrToWeightImplBaseT,
219  WeightClassRegister> {
220  protected:
221  std::string ConvertKeyToSoFilename(std::string_view key) const final {
222  std::string legal_type(key);
223  ConvertToLegalCSymbol(&legal_type);
224  legal_type.append(".so");
225  return legal_type;
226  }
227 };
228 
230 
231 // Internal version; needs to be called by wrapper in order for macro args to
232 // expand.
233 #define REGISTER_FST_WEIGHT__(Weight, line) \
234  static WeightClassRegisterer weight_registerer##_##line( \
235  Weight::Type(), StrToWeightImplBase<Weight>)
236 
237 // This layer is where __FILE__ and __LINE__ are expanded.
238 #define REGISTER_FST_WEIGHT_EXPANDER(Weight, line) \
239  REGISTER_FST_WEIGHT__(Weight, line)
240 
241 // Macro for registering new weight types; clients call this.
242 #define REGISTER_FST_WEIGHT(Weight) \
243  REGISTER_FST_WEIGHT_EXPANDER(Weight, __LINE__)
244 
245 } // namespace script
246 } // namespace fst
247 
248 #endif // FST_SCRIPT_WEIGHT_CLASS_H_
WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs)
WeightClassImpl< W > * Copy() const final
Definition: weight-class.h:59
virtual bool operator!=(const WeightImplBase &other) const =0
void ConvertToLegalCSymbol(std::string *s)
Definition: util.cc:75
WeightClassImpl< W > & DivideEq(const WeightImplBase &other) final
Definition: weight-class.h:94
WeightClass(const WeightClass &other)
Definition: weight-class.h:125
bool operator!=(const WeightImplBase &other) const final
Definition: weight-class.h:78
WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs)
Definition: weight-class.cc:86
static constexpr std::string_view __ONE__
Definition: weight-class.h:134
virtual WeightImplBase & PowerEq(size_t n)=0
WeightClass & operator=(const WeightClass &other)
Definition: weight-class.h:128
const std::string & Type() const
Definition: weight-class.h:155
virtual bool Member() const =0
virtual WeightImplBase * Copy() const =0
WeightClass(const W &weight)
Definition: weight-class.h:116
virtual ~WeightImplBase()=default
To down_cast(From *f)
Definition: compat.h:50
static constexpr std::string_view __NOWEIGHT__
Definition: weight-class.h:135
std::string ConvertKeyToSoFilename(std::string_view key) const final
Definition: weight-class.h:221
const std::string & Type() const final
Definition: weight-class.h:63
virtual const std::string & Type() const =0
virtual WeightImplBase & DivideEq(const WeightImplBase &other)=0
std::unique_ptr< WeightImplBase > StrToWeightImplBase(std::string_view str)
Definition: weight-class.h:206
WeightClassImpl< W > & TimesEq(const WeightImplBase &other) final
Definition: weight-class.h:88
virtual std::string ToString() const =0
std::ostream & operator<<(std::ostream &o, const WeightClass &c)
virtual void Print(std::ostream *o) const =0
std::unique_ptr< WeightImplBase >(*)(std::string_view str) StrToWeightImplBaseT
Definition: weight-class.h:203
std::string WeightToStr(Weight w)
Definition: util.h:359
virtual WeightImplBase & PlusEq(const WeightImplBase &other)=0
const W * GetWeight() const
Definition: weight-class.h:144
bool Member() const final
Definition: weight-class.h:71
WeightClassImpl< W > & PlusEq(const WeightImplBase &other) final
Definition: weight-class.h:82
std::string ToString() const final
Definition: weight-class.h:67
std::string ToString() const
Definition: weight-class.h:153
WeightClassImpl(const W &weight)
Definition: weight-class.h:57
virtual WeightImplBase & TimesEq(const WeightImplBase &other)=0
virtual bool operator==(const WeightImplBase &other) const =0
WeightClass(const WeightClassImpl< W > &impl)
Definition: weight-class.h:120
void Print(std::ostream *ostrm) const final
Definition: weight-class.h:65
WeightClassImpl< W > & PowerEq(size_t n) final
Definition: weight-class.h:100
WeightClass Times(const WeightClass &lhs, const WeightClass &rhs)
Definition: weight-class.cc:97
bool operator==(const WeightImplBase &other) const final
Definition: weight-class.h:73
WeightClass Power(const WeightClass &w, size_t n)
static constexpr std::string_view __ZERO__
Definition: weight-class.h:133