FST  openfst-1.8.2
OpenFst Library
signed-log-weight.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // LogWeight along with sign information that represents the value X in the
19 // linear domain as <sign(X), -ln(|X|)>
20 //
21 // The sign is a TropicalWeight:
22 // positive, TropicalWeight.Value() > 0.0, recommended value 1.0
23 // negative, TropicalWeight.Value() <= 0.0, recommended value -1.0
24 
25 #ifndef FST_SIGNED_LOG_WEIGHT_H_
26 #define FST_SIGNED_LOG_WEIGHT_H_
27 
28 #include <cstdint>
29 #include <random>
30 
31 
32 #include <fst/float-weight.h>
33 #include <fst/pair-weight.h>
34 #include <fst/product-weight.h>
35 
36 
37 namespace fst {
38 template <class T>
39 class SignedLogWeightTpl : public PairWeight<TropicalWeight, LogWeightTpl<T>> {
40  public:
41  using W1 = TropicalWeight;
44 
47 
48  SignedLogWeightTpl() noexcept : PairWeight<W1, W2>() {}
49 
50  // Conversion from plain LogWeightTpl.
51  // NOLINTNEXTLINE(google-explicit-constructor)
52  SignedLogWeightTpl(const W2 &w2) : PairWeight<W1, W2>(W1(1.0), w2) {}
53 
54  explicit SignedLogWeightTpl(const PairWeight<W1, W2> &weight)
55  : PairWeight<W1, W2>(weight) {}
56 
57  SignedLogWeightTpl(const W1 &w1, const W2 &w2) : PairWeight<W1, W2>(w1, w2) {}
58 
59  static const SignedLogWeightTpl &Zero() {
60  static const SignedLogWeightTpl zero(W1(1.0), W2::Zero());
61  return zero;
62  }
63 
64  static const SignedLogWeightTpl &One() {
65  static const SignedLogWeightTpl one(W1(1.0), W2::One());
66  return one;
67  }
68 
69  static const SignedLogWeightTpl &NoWeight() {
70  static const SignedLogWeightTpl no_weight(W1(1.0), W2::NoWeight());
71  return no_weight;
72  }
73 
74  static const std::string &Type() {
75  static const std::string *const type =
76  new std::string("signed_log_" + W1::Type() + "_" + W2::Type());
77  return *type;
78  }
79 
80  bool IsPositive() const { return Value1().Value() > 0; }
81 
82  SignedLogWeightTpl Quantize(float delta = kDelta) const {
84  }
85 
88  }
89 
90  bool Member() const { return PairWeight<W1, W2>::Member(); }
91 
92  // Neither idempotent nor path.
93  static constexpr uint64_t Properties() {
95  }
96 
97  size_t Hash() const {
98  size_t h1;
99  if (Value2() == W2::Zero() || IsPositive()) {
100  h1 = TropicalWeight(1.0).Hash();
101  } else {
102  h1 = TropicalWeight(-1.0).Hash();
103  }
104  size_t h2 = Value2().Hash();
105  static constexpr int lshift = 5;
106  static constexpr int rshift = CHAR_BIT * sizeof(size_t) - 5;
107  return h1 << lshift ^ h1 >> rshift ^ h2;
108  }
109 };
110 
111 template <class T>
113  const SignedLogWeightTpl<T> &w2) {
114  using W1 = TropicalWeight;
115  using W2 = LogWeightTpl<T>;
116  if (!w1.Member() || !w2.Member()) return SignedLogWeightTpl<T>::NoWeight();
117  const auto s1 = w1.IsPositive();
118  const auto s2 = w2.IsPositive();
119  const bool equal = (s1 == s2);
120  const auto f1 = w1.Value2().Value();
121  const auto f2 = w2.Value2().Value();
122  if (f1 == FloatLimits<T>::PosInfinity()) {
123  return w2;
124  } else if (f2 == FloatLimits<T>::PosInfinity()) {
125  return w1;
126  } else if (f1 == f2) {
127  if (equal) {
128  return SignedLogWeightTpl<T>(W1(w1.Value1()), W2(f2 - M_LN2));
129  } else {
131  }
132  } else if (f1 > f2) {
133  if (equal) {
134  return SignedLogWeightTpl<T>(W1(w1.Value1()),
135  W2(f2 - internal::LogPosExp(f1 - f2)));
136  } else {
137  return SignedLogWeightTpl<T>(W1(w2.Value1()),
138  W2((f2 - internal::LogNegExp(f1 - f2))));
139  }
140  } else {
141  if (equal) {
142  return SignedLogWeightTpl<T>(W1(w2.Value1()),
143  W2((f1 - internal::LogPosExp(f2 - f1))));
144  } else {
145  return SignedLogWeightTpl<T>(W1(w1.Value1()),
146  W2((f1 - internal::LogNegExp(f2 - f1))));
147  }
148  }
149 }
150 
151 template <class T>
153  const SignedLogWeightTpl<T> &w2) {
154  SignedLogWeightTpl<T> minus_w2(-w2.Value1().Value(), w2.Value2());
155  return Plus(w1, minus_w2);
156 }
157 
158 template <class T>
160  const SignedLogWeightTpl<T> &w2) {
161  using W2 = LogWeightTpl<T>;
162  if (!w1.Member() || !w2.Member()) return SignedLogWeightTpl<T>::NoWeight();
163  const auto s1 = w1.IsPositive();
164  const auto s2 = w2.IsPositive();
165  const auto f1 = w1.Value2().Value();
166  const auto f2 = w2.Value2().Value();
167  if (s1 == s2) {
168  return SignedLogWeightTpl<T>(TropicalWeight(1.0), W2(f1 + f2));
169  } else {
170  return SignedLogWeightTpl<T>(TropicalWeight(-1.0), W2(f1 + f2));
171  }
172 }
173 
174 template <class T>
176  const SignedLogWeightTpl<T> &w2,
177  DivideType typ = DIVIDE_ANY) {
178  using W2 = LogWeightTpl<T>;
179  if (!w1.Member() || !w2.Member()) return SignedLogWeightTpl<T>::NoWeight();
180  const auto s1 = w1.IsPositive();
181  const auto s2 = w2.IsPositive();
182  const auto f1 = w1.Value2().Value();
183  const auto f2 = w2.Value2().Value();
184  if (f2 == FloatLimits<T>::PosInfinity()) {
187  } else if (f1 == FloatLimits<T>::PosInfinity()) {
190  } else if (s1 == s2) {
191  return SignedLogWeightTpl<T>(TropicalWeight(1.0), W2(f1 - f2));
192  } else {
193  return SignedLogWeightTpl<T>(TropicalWeight(-1.0), W2(f1 - f2));
194  }
195 }
196 
197 template <class T>
198 inline bool ApproxEqual(const SignedLogWeightTpl<T> &w1,
199  const SignedLogWeightTpl<T> &w2, float delta = kDelta) {
200  using W2 = LogWeightTpl<T>;
201  if (w1.IsPositive() == w2.IsPositive()) {
202  return ApproxEqual(w1.Value2(), w2.Value2(), delta);
203  } else {
204  return ApproxEqual(w1.Value2(), W2::Zero(), delta) &&
205  ApproxEqual(w2.Value2(), W2::Zero(), delta);
206  }
207 }
208 
209 template <class T>
210 inline bool operator==(const SignedLogWeightTpl<T> &w1,
211  const SignedLogWeightTpl<T> &w2) {
212  using W2 = LogWeightTpl<T>;
213  if (w1.IsPositive() == w2.IsPositive()) {
214  return w1.Value2() == w2.Value2();
215  } else {
216  return w1.Value2() == W2::Zero() && w2.Value2() == W2::Zero();
217  }
218 }
219 
220 template <class T>
221 inline bool operator!=(const SignedLogWeightTpl<T> &w1,
222  const SignedLogWeightTpl<T> &w2) {
223  return !(w1 == w2);
224 }
225 
// All functions and operators taking a LogWeightTpl argument must be
// overloaded explicitly: template argument deduction does not consider the
// implicit LogWeightTpl -> SignedLogWeightTpl conversion constructor.
229 
230 template <class T>
232  const SignedLogWeightTpl<T> &w2) {
233  return Plus(SignedLogWeightTpl<T>(w1), w2);
234 }
235 
236 template <class T>
238  const LogWeightTpl<T> &w2) {
239  return Plus(w1, SignedLogWeightTpl<T>(w2));
240 }
241 
242 template <class T>
244  const SignedLogWeightTpl<T> &w2) {
245  return Minus(SignedLogWeightTpl<T>(w1), w2);
246 }
247 
248 template <class T>
250  const LogWeightTpl<T> &w2) {
251  return Minus(w1, SignedLogWeightTpl<T>(w2));
252 }
253 
254 template <class T>
256  const SignedLogWeightTpl<T> &w2) {
257  return Times(SignedLogWeightTpl<T>(w1), w2);
258 }
259 
260 template <class T>
262  const LogWeightTpl<T> &w2) {
263  return Times(w1, SignedLogWeightTpl<T>(w2));
264 }
265 
266 template <class T>
268  const SignedLogWeightTpl<T> &w2,
269  DivideType typ = DIVIDE_ANY) {
270  return Divide(SignedLogWeightTpl<T>(w1), w2, typ);
271 }
272 
273 template <class T>
275  const LogWeightTpl<T> &w2,
276  DivideType typ = DIVIDE_ANY) {
277  return Divide(w1, SignedLogWeightTpl<T>(w2), typ);
278 }
279 
280 template <class T>
281 inline bool ApproxEqual(const LogWeightTpl<T> &w1,
282  const SignedLogWeightTpl<T> &w2, float delta = kDelta) {
283  return ApproxEqual(LogWeightTpl<T>(w1), w2, delta);
284 }
285 
286 template <class T>
287 inline bool ApproxEqual(const SignedLogWeightTpl<T> &w1,
288  const LogWeightTpl<T> &w2, float delta = kDelta) {
289  return ApproxEqual(w1, LogWeightTpl<T>(w2), delta);
290 }
291 
292 template <class T>
293 inline bool operator==(const LogWeightTpl<T> &w1,
294  const SignedLogWeightTpl<T> &w2) {
295  return SignedLogWeightTpl<T>(w1) == w2;
296 }
297 
298 template <class T>
299 inline bool operator==(const SignedLogWeightTpl<T> &w1,
300  const LogWeightTpl<T> &w2) {
301  return w1 == SignedLogWeightTpl<T>(w2);
302 }
303 
304 template <class T>
305 inline bool operator!=(const LogWeightTpl<T> &w1,
306  const SignedLogWeightTpl<T> &w2) {
307  return SignedLogWeightTpl<T>(w1) != w2;
308 }
309 
310 template <class T>
311 inline bool operator!=(const SignedLogWeightTpl<T> &w1,
312  const LogWeightTpl<T> &w2) {
313  return w1 != SignedLogWeightTpl<T>(w2);
314 }
315 
316 // Single-precision signed-log weight.
318 
319 // Double-precision signed-log weight.
321 
322 template <class W1, class W2>
323 bool SignedLogConvertCheck(W1 weight) {
324  if (weight.Value1().Value() < 0.0) {
325  FSTERROR() << "WeightConvert: Can't convert weight " << weight << " from "
326  << W1::Type() << " to " << W2::Type();
327  return false;
328  }
329  return true;
330 }
331 
// Adder specialization that uses Kahan compensated summation.
333 template <class T>
335  public:
339 
340  explicit Adder(Weight w = Weight::Zero())
341  : ssum_(w.IsPositive()), sum_(w.Value2().Value()), c_(0.0) {}
342 
343  Weight Add(const Weight &w) {
344  const auto sw = w.IsPositive();
345  const auto f = w.Value2().Value();
346  const bool equal = (ssum_ == sw);
347 
348  if (!Sum().Member() || f == FloatLimits<T>::PosInfinity()) {
349  return Sum();
350  } else if (!w.Member() || sum_ == FloatLimits<T>::PosInfinity()) {
351  sum_ = f;
352  ssum_ = sw;
353  c_ = 0.0;
354  } else if (f == sum_) {
355  if (equal) {
356  sum_ = internal::KahanLogSum(sum_, f, &c_);
357  } else {
359  ssum_ = true;
360  c_ = 0.0;
361  }
362  } else if (f > sum_) {
363  if (equal) {
364  sum_ = internal::KahanLogSum(sum_, f, &c_);
365  } else {
366  sum_ = internal::KahanLogDiff(sum_, f, &c_);
367  }
368  } else {
369  if (equal) {
370  sum_ = internal::KahanLogSum(f, sum_, &c_);
371  } else {
372  sum_ = internal::KahanLogDiff(f, sum_, &c_);
373  ssum_ = sw;
374  }
375  }
376  return Sum();
377  }
378 
379  Weight Sum() const { return Weight(W1(ssum_ ? 1.0 : -1.0), W2(sum_)); }
380 
381  void Reset(Weight w = Weight::Zero()) {
382  ssum_ = w.IsPositive();
383  sum_ = w.Value2().Value();
384  c_ = 0.0;
385  }
386 
387  private:
388  bool ssum_; // true iff sign of sum is positive
389  double sum_; // unsigned sum
390  double c_; // Kahan compensation
391 };
392 
393 // Converts to tropical.
394 template <>
397  if (!SignedLogConvertCheck<SignedLogWeight, TropicalWeight>(weight)) {
398  return TropicalWeight::NoWeight();
399  }
400  return TropicalWeight(weight.Value2().Value());
401  }
402 };
403 
404 template <>
407  if (!SignedLogConvertCheck<SignedLog64Weight, TropicalWeight>(weight)) {
408  return TropicalWeight::NoWeight();
409  }
410  return TropicalWeight(weight.Value2().Value());
411  }
412 };
413 
414 // Converts to log.
415 template <>
417  LogWeight operator()(const SignedLogWeight &weight) const {
418  if (!SignedLogConvertCheck<SignedLogWeight, LogWeight>(weight)) {
419  return LogWeight::NoWeight();
420  }
421  return LogWeight(weight.Value2().Value());
422  }
423 };
424 
425 template <>
427  LogWeight operator()(const SignedLog64Weight &weight) const {
428  if (!SignedLogConvertCheck<SignedLog64Weight, LogWeight>(weight)) {
429  return LogWeight::NoWeight();
430  }
431  return LogWeight(weight.Value2().Value());
432  }
433 };
434 
435 // Converts to log64.
436 template <>
438  Log64Weight operator()(const SignedLogWeight &weight) const {
439  if (!SignedLogConvertCheck<SignedLogWeight, Log64Weight>(weight)) {
440  return Log64Weight::NoWeight();
441  }
442  return Log64Weight(weight.Value2().Value());
443  }
444 };
445 
446 template <>
448  Log64Weight operator()(const SignedLog64Weight &weight) const {
449  if (!SignedLogConvertCheck<SignedLog64Weight, Log64Weight>(weight)) {
450  return Log64Weight::NoWeight();
451  }
452  return Log64Weight(weight.Value2().Value());
453  }
454 };
455 
456 // Converts to real.
457 template <>
459  RealWeight operator()(const SignedLogWeight &weight) const {
460  return RealWeight(weight.Value1().Value() * exp(-weight.Value2().Value()));
461  }
462 };
463 
464 template <>
466  RealWeight operator()(const SignedLog64Weight &weight) const {
467  return RealWeight(weight.Value1().Value() * exp(-weight.Value2().Value()));
468  }
469 };
470 
471 // Converts to real64.
472 template <>
474  Real64Weight operator()(const SignedLogWeight &weight) const {
475  return Real64Weight(weight.Value1().Value() *
476  exp(-weight.Value2().Value()));
477  }
478 };
479 
480 template <>
483  return Real64Weight(weight.Value1().Value() *
484  exp(-weight.Value2().Value()));
485  }
486 };
487 
488 // Converts to signed log.
489 template <>
492  return SignedLogWeight(1.0, weight.Value());
493  }
494 };
495 
496 template <>
498  SignedLogWeight operator()(const LogWeight &weight) const {
499  return SignedLogWeight(1.0, weight.Value());
500  }
501 };
502 
503 template <>
505  SignedLogWeight operator()(const Log64Weight &weight) const {
506  return SignedLogWeight(1.0, weight.Value());
507  }
508 };
509 
510 template <>
512  SignedLogWeight operator()(const RealWeight &weight) const {
513  return SignedLogWeight(weight.Value() >= 0 ? 1.0 : -1.0,
514  -log(std::abs(weight.Value())));
515  }
516 };
517 
518 template <>
520  SignedLogWeight operator()(const Real64Weight &weight) const {
521  return SignedLogWeight(weight.Value() >= 0 ? 1.0 : -1.0,
522  -log(std::abs(weight.Value())));
523  }
524 };
525 
526 template <>
529  return SignedLogWeight(weight.Value1(), weight.Value2().Value());
530  }
531 };
532 
533 // Converts to signed log64.
534 template <>
537  return SignedLog64Weight(1.0, weight.Value());
538  }
539 };
540 
541 template <>
543  SignedLog64Weight operator()(const LogWeight &weight) const {
544  return SignedLog64Weight(1.0, weight.Value());
545  }
546 };
547 
548 template <>
550  SignedLog64Weight operator()(const Log64Weight &weight) const {
551  return SignedLog64Weight(1.0, weight.Value());
552  }
553 };
554 
555 template <>
557  SignedLog64Weight operator()(const RealWeight &weight) const {
558  return SignedLog64Weight(weight.Value() >= 0 ? 1.0 : -1.0,
559  -log(std::abs(weight.Value())));
560  }
561 };
562 
563 template <>
566  return SignedLog64Weight(weight.Value() >= 0 ? 1.0 : -1.0,
567  -log(std::abs(weight.Value())));
568  }
569 };
570 
571 template <>
574  return SignedLog64Weight(weight.Value1(), weight.Value2().Value());
575  }
576 };
577 
578 // This function object returns SignedLogWeightTpl<T>'s that are random integers
579 // chosen from [0, num_random_weights) times a random sign. This is intended
580 // primarily for testing.
581 template <class T>
583  public:
585  using W1 = typename Weight::W1;
586  using W2 = typename Weight::W2;
587 
588  explicit WeightGenerate(uint64_t seed = std::random_device()(),
589  bool allow_zero = true,
590  size_t num_random_weights = kNumRandomWeights)
591  : rand_(seed),
592  allow_zero_(allow_zero),
593  num_random_weights_(num_random_weights) {}
594 
595  Weight operator()() const {
596  static constexpr W1 negative(-1.0);
597  static constexpr W1 positive(+1.0);
598  const bool sign = std::bernoulli_distribution(.5)(rand_);
599  const int sample = std::uniform_int_distribution<>(
600  0, num_random_weights_ + allow_zero_ - 1)(rand_);
601  if (allow_zero_ && sample == num_random_weights_) {
602  return Weight(sign ? positive : negative, W2::Zero());
603  }
604  return Weight(sign ? positive : negative, W2(sample));
605  }
606 
607  private:
608  mutable std::mt19937_64 rand_;
609  const bool allow_zero_;
610  const size_t num_random_weights_;
611 };
612 
613 } // namespace fst
614 
615 #endif // FST_SIGNED_LOG_WEIGHT_H_
static const SignedLogWeightTpl & NoWeight()
Real64Weight operator()(const SignedLogWeight &weight) const
TropicalWeight operator()(const SignedLog64Weight &weight) const
Real64Weight operator()(const SignedLog64Weight &weight) const
RealWeightTpl< float > RealWeight
Definition: float-weight.h:708
static const SignedLogWeightTpl & Zero()
LogWeight operator()(const SignedLogWeight &weight) const
SignedLog64Weight operator()(const RealWeight &weight) const
RealWeight operator()(const SignedLogWeight &weight) const
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:60
SignedLogWeightTpl(const PairWeight< W1, W2 > &weight)
void Reset(Weight w=Weight::Zero())
SignedLog64Weight operator()(const TropicalWeight &weight) const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:63
bool Member() const
Definition: pair-weight.h:73
SignedLogWeight operator()(const TropicalWeight &weight) const
const LogWeightTpl< T > & Value2() const
Definition: pair-weight.h:93
static constexpr uint64_t Properties()
static constexpr LogWeightTpl Zero()
Definition: float-weight.h:432
SignedLogWeightTpl Quantize(float delta=kDelta) const
SignedLogWeight operator()(const Log64Weight &weight) const
static constexpr T PosInfinity()
Definition: float-weight.h:56
constexpr uint64_t kRightSemiring
Definition: weight.h:136
static constexpr LogWeightTpl One()
Definition: float-weight.h:434
#define FSTERROR()
Definition: util.h:53
LogWeightTpl< double > Log64Weight
Definition: float-weight.h:468
Log64Weight operator()(const SignedLog64Weight &weight) const
LogWeightTpl< float > LogWeight
Definition: float-weight.h:465
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t num_random_weights=kNumRandomWeights)
TropicalWeightTpl< float > TropicalWeight
Definition: float-weight.h:264
bool operator!=(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:53
static const SignedLogWeightTpl & One()
static constexpr LogWeightTpl NoWeight()
Definition: float-weight.h:436
constexpr uint64_t kCommutative
Definition: weight.h:141
SignedLogWeightTpl(const W1 &w1, const W2 &w2)
SignedLog64Weight operator()(const SignedLogWeight &weight) const
static constexpr TropicalWeightTpl< T > NoWeight()
Definition: float-weight.h:217
double LogPosExp(double x)
Definition: float-weight.h:473
double KahanLogSum(double a, double b, double *c)
Definition: float-weight.h:488
RealWeight operator()(const SignedLog64Weight &weight) const
SignedLogWeightTpl< double > SignedLog64Weight
static const std::string & Type()
Definition: float-weight.h:221
SignedLog64Weight operator()(const Real64Weight &weight) const
SignedLogWeight operator()(const RealWeight &weight) const
double LogNegExp(double x)
Definition: float-weight.h:479
LogWeightTpl< T > Minus(const LogWeightTpl< T > &w1, const LogWeightTpl< T > &w2)
Definition: float-weight.h:539
SignedLogWeightTpl(const W2 &w2)
static const std::string & Type()
Definition: float-weight.h:438
ErrorWeight Divide(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:66
bool operator==(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:50
constexpr size_t kNumRandomWeights
Definition: weight.h:151
Log64Weight operator()(const SignedLogWeight &weight) const
bool SignedLogConvertCheck(W1 weight)
RealWeightTpl< double > Real64Weight
Definition: float-weight.h:711
SignedLog64Weight operator()(const Log64Weight &weight) const
SignedLogWeight operator()(const SignedLog64Weight &weight) const
ReverseWeight Reverse() const
SignedLogWeight operator()(const Real64Weight &weight) const
SignedLog64Weight operator()(const LogWeight &weight) const
SignedLogWeight operator()(const LogWeight &weight) const
DivideType
Definition: weight.h:162
LogWeight operator()(const SignedLog64Weight &weight) const
constexpr uint64_t kLeftSemiring
Definition: weight.h:133
constexpr float kDelta
Definition: weight.h:130
SignedLogWeightTpl< float > SignedLogWeight
static const std::string & Type()
constexpr const T & Value() const
Definition: float-weight.h:91
bool ApproxEqual(const ErrorWeight &, const ErrorWeight &, float)
Definition: error-weight.h:57
TropicalWeight operator()(const SignedLogWeight &weight) const
double KahanLogDiff(double a, double b, double *c)
Definition: float-weight.h:500