FST  openfst-1.7.1
OpenFst Library
signed-log-weight.h
Go to the documentation of this file.
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // LogWeight along with sign information that represents the value X in the
5 // linear domain as <sign(X), -ln(|X|)>
6 //
7 // The sign is a TropicalWeight:
8 // positive, TropicalWeight.Value() > 0.0, recommended value 1.0
9 // negative, TropicalWeight.Value() <= 0.0, recommended value -1.0
10 
11 #ifndef FST_SIGNED_LOG_WEIGHT_H_
12 #define FST_SIGNED_LOG_WEIGHT_H_
13 
14 #include <cstdlib>
15 
16 #include <fst/float-weight.h>
17 #include <fst/pair-weight.h>
18 #include <fst/product-weight.h>
19 
20 
21 namespace fst {
22 template <class T>
23 class SignedLogWeightTpl : public PairWeight<TropicalWeight, LogWeightTpl<T>> {
24  public:
25  using X1 = TropicalWeight;
28 
31 
32  SignedLogWeightTpl() noexcept : PairWeight<X1, X2>() {}
33 
35  : PairWeight<X1, X2>(w) {}
36 
37  SignedLogWeightTpl(const X1 &x1, const X2 &x2) : PairWeight<X1, X2>(x1, x2) {}
38 
39  static const SignedLogWeightTpl &Zero() {
40  static const SignedLogWeightTpl zero(X1(1.0), X2::Zero());
41  return zero;
42  }
43 
44  static const SignedLogWeightTpl &One() {
45  static const SignedLogWeightTpl one(X1(1.0), X2::One());
46  return one;
47  }
48 
49  static const SignedLogWeightTpl &NoWeight() {
50  static const SignedLogWeightTpl no_weight(X1(1.0), X2::NoWeight());
51  return no_weight;
52  }
53 
54  static const string &Type() {
55  static const string *const type =
56  new string("signed_log_" + X1::Type() + "_" + X2::Type());
57  return *type;
58  }
59 
60  bool IsPositive() const { return Value1().Value() > 0; }
61 
62  SignedLogWeightTpl Quantize(float delta = kDelta) const {
64  }
65 
68  }
69 
70  bool Member() const { return PairWeight<X1, X2>::Member(); }
71 
72  // Neither idempotent nor path.
73  static constexpr uint64 Properties() {
75  }
76 
77  size_t Hash() const {
78  size_t h1;
79  if (Value2() == X2::Zero() || IsPositive()) {
80  h1 = TropicalWeight(1.0).Hash();
81  } else {
82  h1 = TropicalWeight(-1.0).Hash();
83  }
84  size_t h2 = Value2().Hash();
85  static constexpr int lshift = 5;
86  static constexpr int rshift = CHAR_BIT * sizeof(size_t) - 5;
87  return h1 << lshift ^ h1 >> rshift ^ h2;
88  }
89 };
90 
91 template <class T>
93  const SignedLogWeightTpl<T> &w2) {
94  using X1 = TropicalWeight;
95  using X2 = LogWeightTpl<T>;
96  if (!w1.Member() || !w2.Member()) return SignedLogWeightTpl<T>::NoWeight();
97  const auto s1 = w1.IsPositive();
98  const auto s2 = w2.IsPositive();
99  const bool equal = (s1 == s2);
100  const auto f1 = w1.Value2().Value();
101  const auto f2 = w2.Value2().Value();
102  if (f1 == FloatLimits<T>::PosInfinity()) {
103  return w2;
104  } else if (f2 == FloatLimits<T>::PosInfinity()) {
105  return w1;
106  } else if (f1 == f2) {
107  if (equal) {
108  return SignedLogWeightTpl<T>(X1(w1.Value1()), X2(f2 - log(2.0F)));
109  } else {
111  }
112  } else if (f1 > f2) {
113  if (equal) {
114  return SignedLogWeightTpl<T>(X1(w1.Value1()),
115  X2(f2 - internal::LogPosExp(f1 - f2)));
116  } else {
117  return SignedLogWeightTpl<T>(X1(w2.Value1()),
118  X2((f2 - internal::LogNegExp(f1 - f2))));
119  }
120  } else {
121  if (equal) {
122  return SignedLogWeightTpl<T>(X1(w2.Value1()),
123  X2((f1 - internal::LogPosExp(f2 - f1))));
124  } else {
125  return SignedLogWeightTpl<T>(X1(w1.Value1()),
126  X2((f1 - internal::LogNegExp(f2 - f1))));
127  }
128  }
129 }
130 
131 template <class T>
133  const SignedLogWeightTpl<T> &w2) {
134  SignedLogWeightTpl<T> minus_w2(-w2.Value1().Value(), w2.Value2());
135  return Plus(w1, minus_w2);
136 }
137 
138 template <class T>
140  const SignedLogWeightTpl<T> &w2) {
141  using X2 = LogWeightTpl<T>;
142  if (!w1.Member() || !w2.Member()) return SignedLogWeightTpl<T>::NoWeight();
143  const auto s1 = w1.IsPositive();
144  const auto s2 = w2.IsPositive();
145  const auto f1 = w1.Value2().Value();
146  const auto f2 = w2.Value2().Value();
147  if (s1 == s2) {
148  return SignedLogWeightTpl<T>(TropicalWeight(1.0), X2(f1 + f2));
149  } else {
150  return SignedLogWeightTpl<T>(TropicalWeight(-1.0), X2(f1 + f2));
151  }
152 }
153 
154 template <class T>
156  const SignedLogWeightTpl<T> &w2,
157  DivideType typ = DIVIDE_ANY) {
158  using X2 = LogWeightTpl<T>;
159  if (!w1.Member() || !w2.Member()) return SignedLogWeightTpl<T>::NoWeight();
160  const auto s1 = w1.IsPositive();
161  const auto s2 = w2.IsPositive();
162  const auto f1 = w1.Value2().Value();
163  const auto f2 = w2.Value2().Value();
164  if (f2 == FloatLimits<T>::PosInfinity()) {
167  } else if (f1 == FloatLimits<T>::PosInfinity()) {
170  } else if (s1 == s2) {
171  return SignedLogWeightTpl<T>(TropicalWeight(1.0), X2(f1 - f2));
172  } else {
173  return SignedLogWeightTpl<T>(TropicalWeight(-1.0), X2(f1 - f2));
174  }
175 }
176 
177 template <class T>
178 inline bool ApproxEqual(const SignedLogWeightTpl<T> &w1,
179  const SignedLogWeightTpl<T> &w2, float delta = kDelta) {
180  using X2 = LogWeightTpl<T>;
181  if (w1.IsPositive() == w2.IsPositive()) {
182  return ApproxEqual(w1.Value2(), w2.Value2(), delta);
183  } else {
184  return ApproxEqual(w1.Value2(), X2::Zero(), delta)
185  && ApproxEqual(w2.Value2(), X2::Zero(), delta);
186  }
187 }
188 
189 template <class T>
190 inline bool operator==(const SignedLogWeightTpl<T> &w1,
191  const SignedLogWeightTpl<T> &w2) {
192  using X2 = LogWeightTpl<T>;
193  if (w1.IsPositive() == w2.IsPositive()) {
194  return w1.Value2() == w2.Value2();
195  } else {
196  return w1.Value2() == X2::Zero()
197  && w2.Value2() == X2::Zero();
198  }
199 }
200 
201 template <class T>
202 inline bool operator!=(const SignedLogWeightTpl<T> &w1,
203  const SignedLogWeightTpl<T> &w2) {
204  return !(w1 == w2);
205 }
206 
207 // Single-precision signed-log weight.
209 
210 // Double-precision signed-log weight.
212 
213 template <class W1, class W2>
214 bool SignedLogConvertCheck(W1 weight) {
215  if (weight.Value1().Value() < 0.0) {
216  FSTERROR() << "WeightConvert: Can't convert weight " << weight
217  << " from " << W1::Type() << " to " << W2::Type();
218  return false;
219  }
220  return true;
221 }
222 
223 // Specialization using the Kahan compensated summation
224 template <class T>
226  public:
230 
231  explicit Adder(Weight w = Weight::Zero())
232  : ssum_(w.IsPositive()),
233  sum_(w.Value2().Value()),
234  c_(0.0) { }
235 
236  Weight Add(const Weight &w) {
237  const auto sw = w.IsPositive();
238  const auto f = w.Value2().Value();
239  const bool equal = (ssum_ == sw);
240 
241  if (!Sum().Member() || f == FloatLimits<T>::PosInfinity()) {
242  return Sum();
243  } else if (!w.Member() || sum_ == FloatLimits<T>::PosInfinity()) {
244  sum_ = f;
245  ssum_ = sw;
246  c_ = 0.0;
247  } else if (f == sum_) {
248  if (equal) {
249  sum_ = internal::KahanLogSum(sum_, f, &c_);
250  } else {
252  ssum_ = true;
253  c_ = 0.0;
254  }
255  } else if (f > sum_) {
256  if (equal) {
257  sum_ = internal::KahanLogSum(sum_, f, &c_);
258  } else {
259  sum_ = internal::KahanLogDiff(sum_, f, &c_);
260  }
261  } else {
262  if (equal) {
263  sum_ = internal::KahanLogSum(f, sum_, &c_);
264  } else {
265  sum_ = internal::KahanLogDiff(f, sum_, &c_);
266  ssum_ = sw;
267  }
268  }
269  return Sum();
270  }
271 
272  Weight Sum() { return Weight(X1(ssum_ ? 1.0 : -1.0), X2(sum_)); }
273 
274  void Reset(Weight w = Weight::Zero()) {
275  ssum_ = w.IsPositive();
276  sum_ = w.Value2().Value();
277  c_ = 0.0;
278  }
279 
280  private:
281  bool ssum_; // true iff sign of sum is positive
282  double sum_; // unsigned sum
283  double c_; // Kahan compensation
284 };
285 
286 // Converts to tropical.
287 template <>
290  if (!SignedLogConvertCheck<SignedLogWeight, TropicalWeight>(weight)) {
291  return TropicalWeight::NoWeight();
292  }
293  return TropicalWeight(weight.Value2().Value());
294  }
295 };
296 
297 template <>
300  if (!SignedLogConvertCheck<SignedLog64Weight, TropicalWeight>(weight)) {
301  return TropicalWeight::NoWeight();
302  }
303  return TropicalWeight(weight.Value2().Value());
304  }
305 };
306 
307 // Converts to log.
308 template <>
310  LogWeight operator()(const SignedLogWeight &weight) const {
311  if (!SignedLogConvertCheck<SignedLogWeight, LogWeight>(weight)) {
312  return LogWeight::NoWeight();
313  }
314  return LogWeight(weight.Value2().Value());
315  }
316 };
317 
318 template <>
320  LogWeight operator()(const SignedLog64Weight &weight) const {
321  if (!SignedLogConvertCheck<SignedLog64Weight, LogWeight>(weight)) {
322  return LogWeight::NoWeight();
323  }
324  return LogWeight(weight.Value2().Value());
325  }
326 };
327 
328 // Converts to log64.
329 template <>
331  Log64Weight operator()(const SignedLogWeight &weight) const {
332  if (!SignedLogConvertCheck<SignedLogWeight, Log64Weight>(weight)) {
333  return Log64Weight::NoWeight();
334  }
335  return Log64Weight(weight.Value2().Value());
336  }
337 };
338 
339 template <>
341  Log64Weight operator()(const SignedLog64Weight &weight) const {
342  if (!SignedLogConvertCheck<SignedLog64Weight, Log64Weight>(weight)) {
343  return Log64Weight::NoWeight();
344  }
345  return Log64Weight(weight.Value2().Value());
346  }
347 };
348 
349 // Converts to signed log.
350 template <>
353  return SignedLogWeight(1.0, weight.Value());
354  }
355 };
356 
357 template <>
359  SignedLogWeight operator()(const LogWeight &weight) const {
360  return SignedLogWeight(1.0, weight.Value());
361  }
362 };
363 
364 template <>
366  SignedLogWeight operator()(const Log64Weight &weight) const {
367  return SignedLogWeight(1.0, weight.Value());
368  }
369 };
370 
371 template <>
374  return SignedLogWeight(weight.Value1(), weight.Value2().Value());
375  }
376 };
377 
378 // Converts to signed log64.
379 template <>
382  return SignedLog64Weight(1.0, weight.Value());
383  }
384 };
385 
386 template <>
388  SignedLog64Weight operator()(const LogWeight &weight) const {
389  return SignedLog64Weight(1.0, weight.Value());
390  }
391 };
392 
393 template <>
395  SignedLog64Weight operator()(const Log64Weight &weight) const {
396  return SignedLog64Weight(1.0, weight.Value());
397  }
398 };
399 
400 template <>
403  return SignedLog64Weight(weight.Value1(), weight.Value2().Value());
404  }
405 };
406 
407 // This function object returns SignedLogWeightTpl<T>'s that are random integers
408 // chosen from [0, num_random_weights) times a random sign. This is intended
409 // primarily for testing.
410 template <class T>
412  public:
414  using X1 = typename Weight::X1;
415  using X2 = typename Weight::X2;
416 
417  explicit WeightGenerate(bool allow_zero = true,
418  size_t num_random_weights = kNumRandomWeights)
419  : allow_zero_(allow_zero), num_random_weights_(num_random_weights) {}
420 
421  Weight operator()() const {
422  static const X1 negative_one(-1.0);
423  static const X1 positive_one(+1.0);
424  const int m = rand() % 2; // NOLINT
425  const int n = rand() % (num_random_weights_ + allow_zero_); // NOLINT
426  return Weight((m == 0) ? negative_one : positive_one,
427  (allow_zero_ && n == num_random_weights_) ?
428  X2::Zero() : X2(n));
429  }
430 
431  private:
432  // Permits Zero() and zero divisors.
433  const bool allow_zero_;
434  // Number of alternative random weights.
435  const size_t num_random_weights_;
436 };
437 
438 } // namespace fst
439 
440 #endif // FST_SIGNED_LOG_WEIGHT_H_
static const SignedLogWeightTpl & NoWeight()
WeightGenerate(bool allow_zero=true, size_t num_random_weights=kNumRandomWeights)
TropicalWeight operator()(const SignedLog64Weight &weight) const
static const SignedLogWeightTpl & Zero()
ExpectationWeight< X1, X2 > Divide(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2, DivideType typ=DIVIDE_ANY)
LogWeight operator()(const SignedLogWeight &weight) const
uint64_t uint64
Definition: types.h:32
constexpr uint64 kRightSemiring
Definition: weight.h:115
void Reset(Weight w=Weight::Zero())
SignedLog64Weight operator()(const TropicalWeight &weight) const
bool Member() const
Definition: pair-weight.h:58
SignedLogWeight operator()(const TropicalWeight &weight) const
ExpectationWeight< X1, X2 > Times(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
const LogWeightTpl< T > & Value2() const
Definition: pair-weight.h:78
static constexpr LogWeightTpl Zero()
Definition: float-weight.h:413
constexpr uint64 kCommutative
Definition: weight.h:120
SignedLogWeightTpl Quantize(float delta=kDelta) const
constexpr uint64 kLeftSemiring
Definition: weight.h:112
SignedLogWeightTpl< T > Minus(const SignedLogWeightTpl< T > &w1, const SignedLogWeightTpl< T > &w2)
SignedLogWeightTpl(const PairWeight< X1, X2 > &w)
SignedLogWeight operator()(const Log64Weight &weight) const
static constexpr T PosInfinity()
Definition: float-weight.h:29
static constexpr LogWeightTpl One()
Definition: float-weight.h:415
#define FSTERROR()
Definition: util.h:35
SignedLogWeightTpl(const X1 &x1, const X2 &x2)
LogWeightTpl< double > Log64Weight
Definition: float-weight.h:449
Log64Weight operator()(const SignedLog64Weight &weight) const
LogWeightTpl< float > LogWeight
Definition: float-weight.h:446
TropicalWeightTpl< float > TropicalWeight
Definition: float-weight.h:244
static const SignedLogWeightTpl & One()
bool operator==(const PdtStateTuple< S, K > &x, const PdtStateTuple< S, K > &y)
Definition: pdt.h:133
static constexpr LogWeightTpl NoWeight()
Definition: float-weight.h:417
ExpectationWeight< X1, X2 > Plus(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
SignedLog64Weight operator()(const SignedLogWeight &weight) const
constexpr bool operator!=(const FloatWeightTpl< T > &w1, const FloatWeightTpl< T > &w2)
Definition: float-weight.h:119
static constexpr TropicalWeightTpl< T > NoWeight()
Definition: float-weight.h:196
double LogPosExp(double x)
Definition: float-weight.h:454
double KahanLogSum(double a, double b, double *c)
Definition: float-weight.h:469
SignedLogWeightTpl< double > SignedLog64Weight
double LogNegExp(double x)
Definition: float-weight.h:460
constexpr bool ApproxEqual(const FloatWeightTpl< T > &w1, const FloatWeightTpl< T > &w2, float delta=kDelta)
Definition: float-weight.h:140
static const string & Type()
constexpr size_t kNumRandomWeights
Definition: weight.h:130
static const string & Type()
Definition: float-weight.h:200
static constexpr uint64 Properties()
Log64Weight operator()(const SignedLogWeight &weight) const
bool SignedLogConvertCheck(W1 weight)
SignedLog64Weight operator()(const Log64Weight &weight) const
SignedLogWeight operator()(const SignedLog64Weight &weight) const
ReverseWeight Reverse() const
static const string & Type()
Definition: float-weight.h:419
SignedLog64Weight operator()(const LogWeight &weight) const
SignedLogWeight operator()(const LogWeight &weight) const
DivideType
Definition: weight.h:142
LogWeight operator()(const SignedLog64Weight &weight) const
constexpr float kDelta
Definition: weight.h:109
SignedLogWeightTpl< float > SignedLogWeight
constexpr const T & Value() const
Definition: float-weight.h:71
TropicalWeight operator()(const SignedLogWeight &weight) const
double KahanLogDiff(double a, double b, double *c)
Definition: float-weight.h:481