FST  openfst-1.8.2
OpenFst Library
float-weight.h
Go to the documentation of this file.
1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Float weight set and associated semiring operation definitions.
19 
20 #ifndef FST_FLOAT_WEIGHT_H_
21 #define FST_FLOAT_WEIGHT_H_
22 
23 #include <algorithm>
24 #include <climits>
25 #include <cmath>
26 #include <cstdint>
27 #include <cstring>
28 #include <limits>
29 #include <random>
30 #include <sstream>
31 #include <string>
32 #include <type_traits>
33 
34 #include <fst/util.h>
35 #include <fst/weight.h>
36 
37 #include <fst/compat.h>
38 #include <string_view>
39 
40 namespace fst {
41 
namespace internal {

// Returns true iff `x` is a NaN. This relies on the IEEE 754 rule that NaN is
// the only value that does not compare equal to itself, which (unlike
// `std::isnan` in C++17/20) can be evaluated in a constant expression. For
// non-floating-point types (e.g. integers) it is always false.
//
// TODO(wolfsonkin): Replace with `std::isnan` if and when that ends up
// constexpr. For context, see
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0533r6.pdf.
template <class T>
inline constexpr bool IsNan(T x) {
  const bool equals_itself = (x == x);
  return !equals_itself;
}

}  // namespace internal
51 
52 // Numeric limits class.
// Special values of the floating-point type T used by the weight classes:
// the two infinities and a quiet NaN that represents an invalid ("bad") number.
template <class T>
class FloatLimits {
 public:
  // Positive infinity of type T.
  static constexpr T PosInfinity() { return NumLimits::infinity(); }

  // Negative infinity of type T.
  static constexpr T NegInfinity() { return -NumLimits::infinity(); }

  // A quiet NaN of type T.
  static constexpr T NumberBad() { return NumLimits::quiet_NaN(); }

 private:
  using NumLimits = std::numeric_limits<T>;
};
64 
65 // Weight class to be templated on floating-point types.
66 template <class T = float>
68  public:
69  using ValueType = T;
70 
71  FloatWeightTpl() noexcept {}
72 
73  constexpr FloatWeightTpl(T f) : value_(f) {} // NOLINT
74 
75  std::istream &Read(std::istream &strm) { return ReadType(strm, &value_); }
76 
77  std::ostream &Write(std::ostream &strm) const {
78  return WriteType(strm, value_);
79  }
80 
81  size_t Hash() const {
82  size_t hash = 0;
83  // Avoid using union, which would be undefined behavior.
84  // Use memcpy, similar to bit_cast, but sizes may be different.
85  // This should be optimized into a single move instruction by
86  // any reasonable compiler.
87  std::memcpy(&hash, &value_, std::min(sizeof(hash), sizeof(value_)));
88  return hash;
89  }
90 
91  constexpr const T &Value() const { return value_; }
92 
93  protected:
94  void SetValue(const T &f) { value_ = f; }
95 
96  static constexpr std::string_view GetPrecisionString() {
97  return sizeof(T) == 4
98  ? ""
99  : sizeof(T) == 1
100  ? "8"
101  : sizeof(T) == 2 ? "16"
102  : sizeof(T) == 8 ? "64" : "unknown";
103  }
104 
105  private:
106  T value_;
107 };
108 
109 // Single-precision float weight.
111 
112 template <class T>
113 constexpr bool operator==(const FloatWeightTpl<T> &w1,
114  const FloatWeightTpl<T> &w2) {
115 #if (defined(__i386__) || defined(__x86_64__)) && !defined(__SSE2_MATH__)
116 // With i387 instructions, excess precision on a weight in an 80-bit
117 // register may cause it to compare unequal to that same weight when
118 // stored to memory. This breaks =='s reflexivity, in turn breaking
119 // NaturalLess.
120 #error "Please compile with -msse -mfpmath=sse, or equivalent."
121 #endif
122  return w1.Value() == w2.Value();
123 }
124 
125 // These seemingly unnecessary overloads are actually needed to make
126 // comparisons like FloatWeightTpl<float> == float compile. If only the
127 // templated version exists, the FloatWeightTpl<float>(float) conversion
128 // won't be found.
129 constexpr bool operator==(const FloatWeightTpl<float> &w1,
130  const FloatWeightTpl<float> &w2) {
131  return operator==<float>(w1, w2);
132 }
133 
134 constexpr bool operator==(const FloatWeightTpl<double> &w1,
135  const FloatWeightTpl<double> &w2) {
136  return operator==<double>(w1, w2);
137 }
138 
139 template <class T>
140 constexpr bool operator!=(const FloatWeightTpl<T> &w1,
141  const FloatWeightTpl<T> &w2) {
142  return !(w1 == w2);
143 }
144 
145 constexpr bool operator!=(const FloatWeightTpl<float> &w1,
146  const FloatWeightTpl<float> &w2) {
147  return operator!=<float>(w1, w2);
148 }
149 
150 constexpr bool operator!=(const FloatWeightTpl<double> &w1,
151  const FloatWeightTpl<double> &w2) {
152  return operator!=<double>(w1, w2);
153 }
154 
155 template <class T>
156 constexpr bool FloatApproxEqual(T w1, T w2, float delta = kDelta) {
157  return w1 <= w2 + delta && w2 <= w1 + delta;
158 }
159 
160 template <class T>
161 constexpr bool ApproxEqual(const FloatWeightTpl<T> &w1,
162  const FloatWeightTpl<T> &w2, float delta = kDelta) {
163  return FloatApproxEqual(w1.Value(), w2.Value(), delta);
164 }
165 
166 template <class T>
167 inline std::ostream &operator<<(std::ostream &strm,
168  const FloatWeightTpl<T> &w) {
169  if (w.Value() == FloatLimits<T>::PosInfinity()) {
170  return strm << "Infinity";
171  } else if (w.Value() == FloatLimits<T>::NegInfinity()) {
172  return strm << "-Infinity";
173  } else if (internal::IsNan(w.Value())) {
174  return strm << "BadNumber";
175  } else {
176  return strm << w.Value();
177  }
178 }
179 
180 template <class T>
181 inline std::istream &operator>>(std::istream &strm, FloatWeightTpl<T> &w) {
182  std::string s;
183  strm >> s;
184  if (s == "Infinity") {
186  } else if (s == "-Infinity") {
188  } else {
189  char *p;
190  T f = strtod(s.c_str(), &p);
191  if (p < s.c_str() + s.size()) {
192  strm.clear(std::ios::badbit);
193  } else {
194  w = FloatWeightTpl<T>(f);
195  }
196  }
197  return strm;
198 }
199 
200 // Tropical semiring: (min, +, inf, 0).
201 template <class T>
203  public:
204  using typename FloatWeightTpl<T>::ValueType;
208 
209  TropicalWeightTpl() noexcept : FloatWeightTpl<T>() {}
210 
211  constexpr TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {}
212 
213  static constexpr TropicalWeightTpl<T> Zero() { return Limits::PosInfinity(); }
214 
215  static constexpr TropicalWeightTpl<T> One() { return 0; }
216 
217  static constexpr TropicalWeightTpl<T> NoWeight() {
218  return Limits::NumberBad();
219  }
220 
221  static const std::string &Type() {
222  static const std::string *const type = new std::string(
224  return *type;
225  }
226 
227  constexpr bool Member() const {
228  // All floating point values except for NaNs and negative infinity are valid
229  // tropical weights.
230  //
231  // Testing membership of a given value can be done by simply checking that
232  // it is strictly greater than negative infinity, which fails for negative
233  // infinity itself but also for NaNs. This can usually be accomplished in a
234  // single instruction (such as *UCOMI* on x86) without branching logic.
235  //
236  // An additional wrinkle involves constexpr correctness of floating point
237  // comparisons against NaN. GCC is uneven when it comes to which expressions
238  // it considers compile-time constants. In particular, current versions of
239  // GCC do not always consider (nan < inf) to be a constant expression, but
240  // do consider (inf < nan) to be a constant expression. (See
241  // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=88173 and
242  // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=88683 for details.) In order
243  // to allow Member() to be a constexpr function accepted by GCC, we write
244  // the comparison here as (-inf < v).
245  return Limits::NegInfinity() < Value();
246  }
247 
248  TropicalWeightTpl<T> Quantize(float delta = kDelta) const {
249  if (!Member() || Value() == Limits::PosInfinity()) {
250  return *this;
251  } else {
252  return TropicalWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
253  }
254  }
255 
256  constexpr TropicalWeightTpl<T> Reverse() const { return *this; }
257 
258  static constexpr uint64_t Properties() {
260  }
261 };
262 
263 // Single-precision tropical weight.
265 
266 template <class T>
268  const TropicalWeightTpl<T> &w2) {
269  return (!w1.Member() || !w2.Member()) ? TropicalWeightTpl<T>::NoWeight()
270  : w1.Value() < w2.Value() ? w1 : w2;
271 }
272 
273 // See comment at operator==(FloatWeightTpl<float>, FloatWeightTpl<float>)
274 // for why these overloads are present.
276  const TropicalWeightTpl<float> &w2) {
277  return Plus<float>(w1, w2);
278 }
279 
281  const TropicalWeightTpl<double> &w2) {
282  return Plus<double>(w1, w2);
283 }
284 
285 template <class T>
287  const TropicalWeightTpl<T> &w2) {
288  // The following is safe in the context of the Tropical (and Log) semiring
289  // for all IEEE floating point values, including infinities and NaNs,
290  // because:
291  //
292  // * If one or both of the floating point Values is NaN and hence not a
293  // Member, the result of addition below is NaN, so the result is not a
294  // Member. This supersedes all other cases, so we only consider non-NaN
295  // values next.
296  //
297  // * If both Values are finite, there is no issue.
298  //
299  // * If one of the Values is infinite, or if both are infinities with the
300  // same sign, the result of floating point addition is the same infinity,
301  // so there is no issue.
302  //
303  // * If both of the Values are infinities with opposite signs, the result of
304  // adding IEEE floating point -inf + inf is NaN and hence not a Member. But
305  // since -inf was not a Member to begin with, returning a non-Member result
306  // is fine as well.
307  return TropicalWeightTpl<T>(w1.Value() + w2.Value());
308 }
309 
311  const TropicalWeightTpl<float> &w2) {
312  return Times<float>(w1, w2);
313 }
314 
316  const TropicalWeightTpl<double> &w2) {
317  return Times<double>(w1, w2);
318 }
319 
320 template <class T>
322  const TropicalWeightTpl<T> &w2,
323  DivideType typ = DIVIDE_ANY) {
324  // The following is safe in the context of the Tropical (and Log) semiring
325  // for all IEEE floating point values, including infinities and NaNs,
326  // because:
327  //
328  // * If one or both of the floating point Values is NaN and hence not a
329  // Member, the result of subtraction below is NaN, so the result is not a
330  // Member. This supersedes all other cases, so we only consider non-NaN
331  // values next.
332  //
333  // * If both Values are finite, there is no issue.
334  //
335  // * If w2.Value() is -inf (and hence w2 is not a Member), the result of ?:
336  // below is NoWeight, which is not a Member.
337  //
338  // Whereas in IEEE floating point semantics 0/inf == 0, this does not carry
339  // over to this semiring (since TropicalWeight(-inf) would be the analogue
340  // of floating point inf) and instead Divide(Zero(), TropicalWeight(-inf))
341  // is NoWeight().
342  //
343  // * If w2.Value() is inf (and hence w2 is Zero), the resulting floating
344  // point value is either NaN (if w1 is Zero or if w1.Value() is NaN) and
345  // hence not a Member, or it is -inf and hence not a Member; either way,
346  // division by Zero results in a non-Member result.
347  using Weight = TropicalWeightTpl<T>;
348  return w2.Member() ? Weight(w1.Value() - w2.Value()) : Weight::NoWeight();
349 }
350 
352  const TropicalWeightTpl<float> &w2,
353  DivideType typ = DIVIDE_ANY) {
354  return Divide<float>(w1, w2, typ);
355 }
356 
358  const TropicalWeightTpl<double> &w2,
359  DivideType typ = DIVIDE_ANY) {
360  return Divide<double>(w1, w2, typ);
361 }
362 
363 // Power(w, n) calculates the n-th power of w with respect to semiring Times.
364 //
365 // In the case of the Tropical (and Log) semiring, the exponent n is not
366 // restricted to be an integer. It can be a floating point value, for example.
367 //
368 // In weight.h, a narrower and hence more broadly applicable version of
369 // Power(w, n) is defined for arbitrary weight types and non-negative integer
370 // exponents n (of type size_t) and implemented in terms of repeated
371 // multiplication using Times.
372 //
373 // Without further provisions this means that, when an expression such as
374 //
375 // Power(TropicalWeightTpl<float>::One(), static_cast<size_t>(2))
376 //
377 // is specified, the overload of Power() is ambiguous. The template function
378 // below could be instantiated as
379 //
380 // Power<float, size_t>(const TropicalWeightTpl<float> &, size_t)
381 //
382 // and the template function defined in weight.h (further specialized below)
383 // could be instantiated as
384 //
385 // Power<TropicalWeightTpl<float>>(const TropicalWeightTpl<float> &, size_t)
386 //
387 // That would lead to two definitions with identical signatures, which results
388 // in a compilation error. To avoid that, we hide the definition of Power<T, V>
389 // when V is size_t, so only Power<W> is visible. Power<W> is further
390 // specialized to Power<TropicalWeightTpl<...>>, and the overloaded definition
391 // of Power<T, V> is made conditionally available only to that template
392 // specialization.
393 
394 template <class T, class V, bool Enable = !std::is_same_v<V, size_t>,
395  typename std::enable_if_t<Enable> * = nullptr>
397  using Weight = TropicalWeightTpl<T>;
398  return (!w.Member() || internal::IsNan(n))
399  ? Weight::NoWeight()
400  : (n == 0 || w == Weight::One()) ? Weight::One()
401  : Weight(w.Value() * n);
402 }
403 
404 // Specializes the library-wide template to use the above implementation; rules
405 // of function template instantiation require this be a full instantiation.
406 
407 template <>
408 constexpr TropicalWeightTpl<float> Power<TropicalWeightTpl<float>>(
409  const TropicalWeightTpl<float> &weight, size_t n) {
410  return Power<float, size_t, true>(weight, n);
411 }
412 
413 template <>
414 constexpr TropicalWeightTpl<double> Power<TropicalWeightTpl<double>>(
415  const TropicalWeightTpl<double> &weight, size_t n) {
416  return Power<double, size_t, true>(weight, n);
417 }
418 
419 // Log semiring: (-log(e^-x + e^-y), +, inf, 0).
420 template <class T>
421 class LogWeightTpl : public FloatWeightTpl<T> {
422  public:
423  using typename FloatWeightTpl<T>::ValueType;
427 
428  LogWeightTpl() noexcept : FloatWeightTpl<T>() {}
429 
430  constexpr LogWeightTpl(T f) : FloatWeightTpl<T>(f) {}
431 
432  static constexpr LogWeightTpl Zero() { return Limits::PosInfinity(); }
433 
434  static constexpr LogWeightTpl One() { return 0; }
435 
436  static constexpr LogWeightTpl NoWeight() { return Limits::NumberBad(); }
437 
438  static const std::string &Type() {
439  static const std::string *const type = new std::string(
441  return *type;
442  }
443 
444  constexpr bool Member() const {
445  // The comments for TropicalWeightTpl<>::Member() apply here unchanged.
446  return Limits::NegInfinity() < Value();
447  }
448 
449  LogWeightTpl<T> Quantize(float delta = kDelta) const {
450  if (!Member() || Value() == Limits::PosInfinity()) {
451  return *this;
452  } else {
453  return LogWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
454  }
455  }
456 
457  constexpr LogWeightTpl<T> Reverse() const { return *this; }
458 
459  static constexpr uint64_t Properties() {
461  }
462 };
463 
464 // Single-precision log weight.
466 
467 // Double-precision log weight.
469 
470 namespace internal {
471 
472 // -log(e^-x + e^-y) = x - LogPosExp(y - x), assuming y >= x.
473 inline double LogPosExp(double x) {
474  DCHECK(!(x < 0)); // NB: NaN values are allowed.
475  return log1p(exp(-x));
476 }
477 
478 // -log(e^-x - e^-y) = x - LogNegExp(y - x), assuming y >= x.
479 inline double LogNegExp(double x) {
480  DCHECK(!(x < 0)); // NB: NaN values are allowed.
481  return log1p(-exp(-x));
482 }
483 
484 // a +_log b = -log(e^-a + e^-b) = KahanLogSum(a, b, ...).
485 // Kahan compensated summation provides an error bound that is
486 // independent of the number of addends. Assumes b >= a;
487 // c is the compensation.
488 inline double KahanLogSum(double a, double b, double *c) {
489  DCHECK_GE(b, a);
490  double y = -LogPosExp(b - a) - *c;
491  double t = a + y;
492  *c = (t - a) - y;
493  return t;
494 }
495 
496 // a -_log b = -log(e^-a - e^-b) = KahanLogDiff(a, b, ...).
497 // Kahan compensated summation provides an error bound that is
498 // independent of the number of addends. Assumes b > a;
499 // c is the compensation.
500 inline double KahanLogDiff(double a, double b, double *c) {
501  DCHECK_GT(b, a);
502  double y = -LogNegExp(b - a) - *c;
503  double t = a + y;
504  *c = (t - a) - y;
505  return t;
506 }
507 
508 } // namespace internal
509 
510 template <class T>
512  const LogWeightTpl<T> &w2) {
513  using Limits = FloatLimits<T>;
514  const T f1 = w1.Value();
515  const T f2 = w2.Value();
516  if (f1 == Limits::PosInfinity()) {
517  return w2;
518  } else if (f2 == Limits::PosInfinity()) {
519  return w1;
520  } else if (f1 > f2) {
521  return LogWeightTpl<T>(f2 - internal::LogPosExp(f1 - f2));
522  } else {
523  return LogWeightTpl<T>(f1 - internal::LogPosExp(f2 - f1));
524  }
525 }
526 
528  const LogWeightTpl<float> &w2) {
529  return Plus<float>(w1, w2);
530 }
531 
533  const LogWeightTpl<double> &w2) {
534  return Plus<double>(w1, w2);
535 }
536 
537 // Returns NoWeight if w1 < w2 (w1.Value() > w2.Value()).
538 template <class T>
540  const LogWeightTpl<T> &w2) {
541  using Limits = FloatLimits<T>;
542  const T f1 = w1.Value();
543  const T f2 = w2.Value();
544  if (f1 > f2) return LogWeightTpl<T>::NoWeight();
545  if (f2 == Limits::PosInfinity()) return f1;
546  const T d = f2 - f1;
547  if (d == Limits::PosInfinity()) return f1;
548  return f1 - internal::LogNegExp(d);
549 }
550 
552  const LogWeightTpl<float> &w2) {
553  return Minus<float>(w1, w2);
554 }
555 
557  const LogWeightTpl<double> &w2) {
558  return Minus<double>(w1, w2);
559 }
560 
561 template <class T>
563  const LogWeightTpl<T> &w2) {
564  // The comments for Times(Tropical...) above apply here unchanged.
565  return LogWeightTpl<T>(w1.Value() + w2.Value());
566 }
567 
569  const LogWeightTpl<float> &w2) {
570  return Times<float>(w1, w2);
571 }
572 
574  const LogWeightTpl<double> &w2) {
575  return Times<double>(w1, w2);
576 }
577 
578 template <class T>
580  const LogWeightTpl<T> &w2,
581  DivideType typ = DIVIDE_ANY) {
582  // The comments for Divide(Tropical...) above apply here unchanged.
583  using Weight = LogWeightTpl<T>;
584  return w2.Member() ? Weight(w1.Value() - w2.Value()) : Weight::NoWeight();
585 }
586 
588  const LogWeightTpl<float> &w2,
589  DivideType typ = DIVIDE_ANY) {
590  return Divide<float>(w1, w2, typ);
591 }
592 
594  const LogWeightTpl<double> &w2,
595  DivideType typ = DIVIDE_ANY) {
596  return Divide<double>(w1, w2, typ);
597 }
598 
599 // The comments for Power<>(Tropical...) above apply here unchanged.
600 
601 template <class T, class V, bool Enable = !std::is_same_v<V, size_t>,
602  typename std::enable_if_t<Enable> * = nullptr>
603 constexpr LogWeightTpl<T> Power(const LogWeightTpl<T> &w, V n) {
604  using Weight = LogWeightTpl<T>;
605  return (!w.Member() || internal::IsNan(n))
606  ? Weight::NoWeight()
607  : (n == 0 || w == Weight::One()) ? Weight::One()
608  : Weight(w.Value() * n);
609 }
610 
611 // Specializes the library-wide template to use the above implementation; rules
612 // of function template instantiation require this be a full instantiation.
613 
614 template <>
615 constexpr LogWeightTpl<float> Power<LogWeightTpl<float>>(
616  const LogWeightTpl<float> &weight, size_t n) {
617  return Power<float, size_t, true>(weight, n);
618 }
619 
620 template <>
621 constexpr LogWeightTpl<double> Power<LogWeightTpl<double>>(
622  const LogWeightTpl<double> &weight, size_t n) {
623  return Power<double, size_t, true>(weight, n);
624 }
625 
626 // Specialization using the Kahan compensated summation.
627 template <class T>
628 class Adder<LogWeightTpl<T>> {
629  public:
631 
632  explicit Adder(Weight w = Weight::Zero()) : sum_(w.Value()), c_(0.0) {}
633 
634  Weight Add(const Weight &w) {
635  using Limits = FloatLimits<T>;
636  const T f = w.Value();
637  if (f == Limits::PosInfinity()) {
638  return Sum();
639  } else if (sum_ == Limits::PosInfinity()) {
640  sum_ = f;
641  c_ = 0.0;
642  } else if (f > sum_) {
643  sum_ = internal::KahanLogSum(sum_, f, &c_);
644  } else {
645  sum_ = internal::KahanLogSum(f, sum_, &c_);
646  }
647  return Sum();
648  }
649 
650  Weight Sum() const { return Weight(sum_); }
651 
652  void Reset(Weight w = Weight::Zero()) {
653  sum_ = w.Value();
654  c_ = 0.0;
655  }
656 
657  private:
658  double sum_;
659  double c_; // Kahan compensation.
660 };
661 
662 // Real semiring: (+, *, 0, 1).
663 template <class T>
664 class RealWeightTpl : public FloatWeightTpl<T> {
665  public:
666  using typename FloatWeightTpl<T>::ValueType;
670 
671  RealWeightTpl() noexcept : FloatWeightTpl<T>() {}
672 
673  constexpr RealWeightTpl(T f) : FloatWeightTpl<T>(f) {}
674 
675  static constexpr RealWeightTpl Zero() { return 0; }
676 
677  static constexpr RealWeightTpl One() { return 1; }
678 
679  static constexpr RealWeightTpl NoWeight() { return Limits::NumberBad(); }
680 
681  static const std::string &Type() {
682  static const std::string *const type = new std::string(
684  return *type;
685  }
686 
687  constexpr bool Member() const {
688  // The comments for TropicalWeightTpl<>::Member() apply here unchanged.
689  return Limits::NegInfinity() < Value();
690  }
691 
692  RealWeightTpl<T> Quantize(float delta = kDelta) const {
693  if (!Member() || Value() == Limits::PosInfinity()) {
694  return *this;
695  } else {
696  return RealWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
697  }
698  }
699 
700  constexpr RealWeightTpl<T> Reverse() const { return *this; }
701 
702  static constexpr uint64_t Properties() {
704  }
705 };
706 
707 // Single-precision real weight.
709 
710 // Double-precision real weight.
712 
namespace internal {

// Returns a + b using Kahan compensated summation; the error bound is
// independent of the number of addends. `*c` holds the running compensation
// (the low-order bits lost by previous additions). It is read before and
// updated after the addition, so a caller summing a sequence should start with
// *c == 0 and pass the same `c` on every call.
inline double KahanRealSum(double a, double b, double *c) {
  const double y = b - *c;  // Correct the addend by the previously lost bits.
  const double t = a + y;   // Low-order bits of y may be lost here...
  *c = (t - a) - y;         // ...and are recovered algebraically.
  return t;
}

}  // namespace internal
726 
727 // The comments for Times(Tropical...) above apply here unchanged.
728 template <class T>
730  const RealWeightTpl<T> &w2) {
731  const T f1 = w1.Value();
732  const T f2 = w2.Value();
733  return RealWeightTpl<T>(f1 + f2);
734 }
735 
737  const RealWeightTpl<float> &w2) {
738  return Plus<float>(w1, w2);
739 }
740 
742  const RealWeightTpl<double> &w2) {
743  return Plus<double>(w1, w2);
744 }
745 
746 template <class T>
748  const RealWeightTpl<T> &w2) {
749  // The comments for Divide(Tropical...) above apply here unchanged.
750  const T f1 = w1.Value();
751  const T f2 = w2.Value();
752  return RealWeightTpl<T>(f1 - f2);
753 }
754 
756  const RealWeightTpl<float> &w2) {
757  return Minus<float>(w1, w2);
758 }
759 
761  const RealWeightTpl<double> &w2) {
762  return Minus<double>(w1, w2);
763 }
764 
765 // The comments for Times(Tropical...) above apply here similarly.
766 template <class T>
768  const RealWeightTpl<T> &w2) {
769  return RealWeightTpl<T>(w1.Value() * w2.Value());
770 }
771 
773  const RealWeightTpl<float> &w2) {
774  return Times<float>(w1, w2);
775 }
776 
778  const RealWeightTpl<double> &w2) {
779  return Times<double>(w1, w2);
780 }
781 
782 template <class T>
784  const RealWeightTpl<T> &w2,
785  DivideType typ = DIVIDE_ANY) {
786  using Weight = RealWeightTpl<T>;
787  return w2.Member() ? Weight(w1.Value() / w2.Value()) : Weight::NoWeight();
788 }
789 
791  const RealWeightTpl<float> &w2,
792  DivideType typ = DIVIDE_ANY) {
793  return Divide<float>(w1, w2, typ);
794 }
795 
797  const RealWeightTpl<double> &w2,
798  DivideType typ = DIVIDE_ANY) {
799  return Divide<double>(w1, w2, typ);
800 }
801 
802 // The comments for Power<>(Tropical...) above apply here unchanged.
803 
804 template <class T, class V, bool Enable = !std::is_same_v<V, size_t>,
805  typename std::enable_if_t<Enable> * = nullptr>
806 constexpr RealWeightTpl<T> Power(const RealWeightTpl<T> &w, V n) {
807  using Weight = RealWeightTpl<T>;
808  return (!w.Member() || internal::IsNan(n))
809  ? Weight::NoWeight()
810  : (n == 0 || w == Weight::One()) ? Weight::One()
811  : Weight(pow(w.Value(), n));
812 }
813 
814 // Specializes the library-wide template to use the above implementation; rules
815 // of function template instantiation require this be a full instantiation.
816 
817 template <>
818 constexpr RealWeightTpl<float> Power<RealWeightTpl<float>>(
819  const RealWeightTpl<float> &weight, size_t n) {
820  return Power<float, size_t, true>(weight, n);
821 }
822 
823 template <>
824 constexpr RealWeightTpl<double> Power<RealWeightTpl<double>>(
825  const RealWeightTpl<double> &weight, size_t n) {
826  return Power<double, size_t, true>(weight, n);
827 }
828 
829 // Specialization using the Kahan compensated summation.
830 template <class T>
831 class Adder<RealWeightTpl<T>> {
832  public:
834 
835  explicit Adder(Weight w = Weight::Zero()) : sum_(w.Value()), c_(0.0) {}
836 
837  Weight Add(const Weight &w) {
838  using Limits = FloatLimits<T>;
839  const T f = w.Value();
840  if (f == Limits::PosInfinity()) {
841  sum_ = f;
842  } else if (sum_ == Limits::PosInfinity()) {
843  return sum_;
844  } else {
845  sum_ = internal::KahanRealSum(sum_, f, &c_);
846  }
847  return Sum();
848  }
849 
850  Weight Sum() const { return Weight(sum_); }
851 
852  void Reset(Weight w = Weight::Zero()) {
853  sum_ = w.Value();
854  c_ = 0.0;
855  }
856 
857  private:
858  double sum_;
859  double c_; // Kahan compensation.
860 };
861 
862 // MinMax semiring: (min, max, inf, -inf).
863 template <class T>
864 class MinMaxWeightTpl : public FloatWeightTpl<T> {
865  public:
866  using typename FloatWeightTpl<T>::ValueType;
870 
871  MinMaxWeightTpl() noexcept : FloatWeightTpl<T>() {}
872 
873  constexpr MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {} // NOLINT
874 
875  static constexpr MinMaxWeightTpl Zero() { return Limits::PosInfinity(); }
876 
877  static constexpr MinMaxWeightTpl One() { return Limits::NegInfinity(); }
878 
879  static constexpr MinMaxWeightTpl NoWeight() { return Limits::NumberBad(); }
880 
881  static const std::string &Type() {
882  static const std::string *const type = new std::string(
884  return *type;
885  }
886 
887  // Fails for IEEE NaN.
888  constexpr bool Member() const { return !internal::IsNan(Value()); }
889 
890  MinMaxWeightTpl<T> Quantize(float delta = kDelta) const {
891  // If one of infinities, or a NaN.
892  if (!Member() || Value() == Limits::NegInfinity() ||
893  Value() == Limits::PosInfinity()) {
894  return *this;
895  } else {
896  return MinMaxWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
897  }
898  }
899 
900  constexpr MinMaxWeightTpl<T> Reverse() const { return *this; }
901 
902  static constexpr uint64_t Properties() {
904  }
905 };
906 
907 // Single-precision min-max weight.
909 
910 // Min.
911 template <class T>
913  const MinMaxWeightTpl<T> &w2) {
914  return (!w1.Member() || !w2.Member()) ? MinMaxWeightTpl<T>::NoWeight()
915  : w1.Value() < w2.Value() ? w1 : w2;
916 }
917 
919  const MinMaxWeightTpl<float> &w2) {
920  return Plus<float>(w1, w2);
921 }
922 
924  const MinMaxWeightTpl<double> &w2) {
925  return Plus<double>(w1, w2);
926 }
927 
928 // Max.
929 template <class T>
931  const MinMaxWeightTpl<T> &w2) {
932  return (!w1.Member() || !w2.Member()) ? MinMaxWeightTpl<T>::NoWeight()
933  : w1.Value() >= w2.Value() ? w1 : w2;
934 }
935 
937  const MinMaxWeightTpl<float> &w2) {
938  return Times<float>(w1, w2);
939 }
940 
942  const MinMaxWeightTpl<double> &w2) {
943  return Times<double>(w1, w2);
944 }
945 
946 // Defined only for special cases.
947 template <class T>
949  const MinMaxWeightTpl<T> &w2,
950  DivideType typ = DIVIDE_ANY) {
951  return w1.Value() >= w2.Value() ? w1 : MinMaxWeightTpl<T>::NoWeight();
952 }
953 
955  const MinMaxWeightTpl<float> &w2,
956  DivideType typ = DIVIDE_ANY) {
957  return Divide<float>(w1, w2, typ);
958 }
959 
961  const MinMaxWeightTpl<double> &w2,
962  DivideType typ = DIVIDE_ANY) {
963  return Divide<double>(w1, w2, typ);
964 }
965 
966 // Converts to tropical.
967 template <>
969  constexpr TropicalWeight operator()(const LogWeight &w) const {
970  return w.Value();
971  }
972 };
973 
974 template <>
976  constexpr TropicalWeight operator()(const Log64Weight &w) const {
977  return w.Value();
978  }
979 };
980 
981 // Converts to log.
982 template <>
984  constexpr LogWeight operator()(const TropicalWeight &w) const {
985  return w.Value();
986  }
987 };
988 
989 template <>
991  LogWeight operator()(const RealWeight &w) const { return -log(w.Value()); }
992 };
993 
994 template <>
996  LogWeight operator()(const Real64Weight &w) const { return -log(w.Value()); }
997 };
998 
999 template <>
1001  constexpr LogWeight operator()(const Log64Weight &w) const {
1002  return w.Value();
1003  }
1004 };
1005 
1006 // Converts to log64.
1007 template <>
1009  constexpr Log64Weight operator()(const TropicalWeight &w) const {
1010  return w.Value();
1011  }
1012 };
1013 
1014 template <>
1016  Log64Weight operator()(const RealWeight &w) const { return -log(w.Value()); }
1017 };
1018 
1019 template <>
1022  return -log(w.Value());
1023  }
1024 };
1025 
1026 template <>
1028  constexpr Log64Weight operator()(const LogWeight &w) const {
1029  return w.Value();
1030  }
1031 };
1032 
1033 // Converts to real.
1034 template <>
1036  RealWeight operator()(const LogWeight &w) const { return exp(-w.Value()); }
1037 };
1038 
1039 template <>
1041  RealWeight operator()(const Log64Weight &w) const { return exp(-w.Value()); }
1042 };
1043 
1044 template <>
1046  constexpr RealWeight operator()(const Real64Weight &w) const {
1047  return w.Value();
1048  }
1049 };
1050 
1051 // Converts to real64
1052 template <>
1054  Real64Weight operator()(const LogWeight &w) const { return exp(-w.Value()); }
1055 };
1056 
1057 template <>
1060  return exp(-w.Value());
1061  }
1062 };
1063 
1064 template <>
1066  constexpr Real64Weight operator()(const RealWeight &w) const {
1067  return w.Value();
1068  }
1069 };
1070 
1071 // This function object returns random integers chosen from [0,
1072 // num_random_weights). The allow_zero argument determines whether Zero() and
1073 // zero divisors should be returned in the random weight generation. This is
1074 // intended primarily for testing.
1075 template <class Weight>
1077  public:
1079  uint64_t seed = std::random_device()(), bool allow_zero = true,
1080  const size_t num_random_weights = kNumRandomWeights)
1081  : rand_(seed),
1082  allow_zero_(allow_zero),
1083  num_random_weights_(num_random_weights) {}
1084 
1085  Weight operator()() const {
1086  const int sample = std::uniform_int_distribution<>(
1087  0, num_random_weights_ + allow_zero_ - 1)(rand_);
1088  if (allow_zero_ && sample == num_random_weights_) return Weight::Zero();
1089  return Weight(sample);
1090  }
1091 
1092  private:
1093  mutable std::mt19937_64 rand_;
1094  const bool allow_zero_;
1095  const size_t num_random_weights_;
1096 };
1097 
1098 template <class T>
1100  : public FloatWeightGenerate<TropicalWeightTpl<T>> {
1101  public:
1104 
1105  explicit WeightGenerate(uint64_t seed = std::random_device()(),
1106  bool allow_zero = true,
1107  size_t num_random_weights = kNumRandomWeights)
1108  : Generate(seed, allow_zero, num_random_weights) {}
1109 
1110  Weight operator()() const { return Weight(Generate::operator()()); }
1111 };
1112 
1113 template <class T>
1115  : public FloatWeightGenerate<LogWeightTpl<T>> {
1116  public:
1119 
1120  explicit WeightGenerate(uint64_t seed = std::random_device()(),
1121  bool allow_zero = true,
1122  size_t num_random_weights = kNumRandomWeights)
1123  : Generate(seed, allow_zero, num_random_weights) {}
1124 
1125  Weight operator()() const { return Weight(Generate::operator()()); }
1126 };
1127 
1128 template <class T>
1130  : public FloatWeightGenerate<RealWeightTpl<T>> {
1131  public:
1134 
1135  explicit WeightGenerate(uint64_t seed = std::random_device()(),
1136  bool allow_zero = true,
1137  size_t num_random_weights = kNumRandomWeights)
1138  : Generate(seed, allow_zero, num_random_weights) {}
1139 
1140  Weight operator()() const { return Weight(Generate::operator()()); }
1141 };
1142 
1143 // This function object returns random weights with integer values drawn from
1144 // [-num_random_weights, num_random_weights + allow_zero]; a draw of
1145 // -num_random_weights yields One(), and, when 'allow_zero' is true (allowing
1146 // Zero() to be returned), a draw of 0 yields Zero(). Intended primarily for testing.
1147 template <class T>
1149  public:
1151 
1152  explicit WeightGenerate(uint64_t seed = std::random_device()(),
1153  bool allow_zero = true,
1154  size_t num_random_weights = kNumRandomWeights)
1155  : rand_(seed),
1156  allow_zero_(allow_zero),
1157  num_random_weights_(num_random_weights) {}
1158 
1159  Weight operator()() const {
1160  const int sample = std::uniform_int_distribution<>(
1161  -num_random_weights_, num_random_weights_ + allow_zero_)(rand_);
1162  if (allow_zero_ && sample == 0) {
1163  return Weight::Zero();
1164  } else if (sample == -num_random_weights_) {
1165  return Weight::One();
1166  } else {
1167  return Weight(sample);
1168  }
1169  }
1170 
1171  private:
1172  mutable std::mt19937_64 rand_;
1173  const bool allow_zero_;
1174  const size_t num_random_weights_;
1175 };
1176 
1177 } // namespace fst
1178 
1179 #endif // FST_FLOAT_WEIGHT_H_
RealWeight operator()(const Log64Weight &w) const
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t num_random_weights=kNumRandomWeights)
Weight Add(const Weight &w)
Definition: float-weight.h:837
static constexpr RealWeightTpl Zero()
Definition: float-weight.h:675
void Reset(Weight w=Weight::Zero())
Definition: float-weight.h:652
static constexpr TropicalWeightTpl< T > Zero()
Definition: float-weight.h:213
constexpr TropicalWeight operator()(const LogWeight &w) const
Definition: float-weight.h:969
FloatWeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, const size_t num_random_weights=kNumRandomWeights)
constexpr RealWeight operator()(const Real64Weight &w) const
constexpr Log64Weight operator()(const LogWeight &w) const
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:60
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t num_random_weights=kNumRandomWeights)
constexpr bool Member() const
Definition: float-weight.h:687
Adder(Weight w=Weight::Zero())
Definition: float-weight.h:835
RealWeight operator()(const LogWeight &w) const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:63
static constexpr std::string_view GetPrecisionString()
Definition: float-weight.h:96
RealWeightTpl< T > Quantize(float delta=kDelta) const
Definition: float-weight.h:692
static constexpr LogWeightTpl Zero()
Definition: float-weight.h:432
static constexpr MinMaxWeightTpl NoWeight()
Definition: float-weight.h:879
static constexpr MinMaxWeightTpl One()
Definition: float-weight.h:877
constexpr uint64_t kIdempotent
Definition: weight.h:144
RealWeightTpl() noexcept
Definition: float-weight.h:671
#define DCHECK_GT(x, y)
Definition: log.h:73
std::ostream & Write(std::ostream &strm) const
Definition: float-weight.h:77
Weight operator()() const
constexpr Log64Weight operator()(const TropicalWeight &w) const
void SetValue(const T &f)
Definition: float-weight.h:94
static constexpr uint64_t Properties()
Definition: float-weight.h:258
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t num_random_weights=kNumRandomWeights)
static constexpr T PosInfinity()
Definition: float-weight.h:56
constexpr RealWeightTpl< T > Reverse() const
Definition: float-weight.h:700
static constexpr uint64_t Properties()
Definition: float-weight.h:902
TropicalWeightTpl< T > Quantize(float delta=kDelta) const
Definition: float-weight.h:248
constexpr uint64_t kRightSemiring
Definition: weight.h:136
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:211
static constexpr LogWeightTpl One()
Definition: float-weight.h:434
static constexpr TropicalWeightTpl< T > One()
Definition: float-weight.h:215
constexpr bool FloatApproxEqual(T w1, T w2, float delta=kDelta)
Definition: float-weight.h:156
constexpr FloatWeightTpl(T f)
Definition: float-weight.h:73
static constexpr T NumberBad()
Definition: float-weight.h:62
std::istream & operator>>(std::istream &strm, FloatWeightTpl< T > &w)
Definition: float-weight.h:181
Log64Weight operator()(const Real64Weight &w) const
Real64Weight operator()(const LogWeight &w) const
constexpr bool IsNan(T value)
Definition: float-weight.h:47
constexpr TropicalWeightTpl< T > Reverse() const
Definition: float-weight.h:256
bool operator!=(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:53
constexpr LogWeight operator()(const TropicalWeight &w) const
Definition: float-weight.h:984
MinMaxWeightTpl< T > Quantize(float delta=kDelta) const
Definition: float-weight.h:890
static constexpr LogWeightTpl NoWeight()
Definition: float-weight.h:436
constexpr uint64_t kCommutative
Definition: weight.h:141
std::ostream & operator<<(std::ostream &strm, const ErrorWeight &)
Definition: error-weight.h:70
LogWeight operator()(const RealWeight &w) const
Definition: float-weight.h:991
Adder(Weight w=Weight::Zero())
Definition: float-weight.h:632
void Reset(Weight w=Weight::Zero())
Definition: float-weight.h:852
constexpr MinMaxWeightTpl(T f)
Definition: float-weight.h:873
static constexpr RealWeightTpl NoWeight()
Definition: float-weight.h:679
static const std::string & Type()
Definition: float-weight.h:681
LogWeightTpl< T > Quantize(float delta=kDelta) const
Definition: float-weight.h:449
static constexpr TropicalWeightTpl< T > NoWeight()
Definition: float-weight.h:217
static constexpr uint64_t Properties()
Definition: float-weight.h:702
static constexpr RealWeightTpl One()
Definition: float-weight.h:677
constexpr LogWeightTpl< T > Reverse() const
Definition: float-weight.h:457
double LogPosExp(double x)
Definition: float-weight.h:473
double KahanLogSum(double a, double b, double *c)
Definition: float-weight.h:488
TropicalWeightTpl() noexcept
Definition: float-weight.h:209
LogWeight operator()(const Real64Weight &w) const
Definition: float-weight.h:996
std::istream & Read(std::istream &strm)
Definition: float-weight.h:75
constexpr uint64_t kPath
Definition: weight.h:147
static const std::string & Type()
Definition: float-weight.h:221
Log64Weight operator()(const RealWeight &w) const
Weight Add(const Weight &w)
Definition: float-weight.h:634
LogWeightTpl() noexcept
Definition: float-weight.h:428
constexpr bool Member() const
Definition: float-weight.h:227
FloatWeightTpl() noexcept
Definition: float-weight.h:71
double LogNegExp(double x)
Definition: float-weight.h:479
constexpr RealWeightTpl(T f)
Definition: float-weight.h:673
MinMaxWeightTpl() noexcept
Definition: float-weight.h:871
LogWeightTpl< T > Minus(const LogWeightTpl< T > &w1, const LogWeightTpl< T > &w2)
Definition: float-weight.h:539
constexpr bool Member() const
Definition: float-weight.h:888
std::string StrCat(const StringOrInt &s1, const StringOrInt &s2)
Definition: compat.h:279
static const std::string & Type()
Definition: float-weight.h:438
constexpr bool Member() const
Definition: float-weight.h:444
size_t Hash() const
Definition: float-weight.h:81
ErrorWeight Divide(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:66
bool operator==(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:50
constexpr size_t kNumRandomWeights
Definition: weight.h:151
static constexpr MinMaxWeightTpl Zero()
Definition: float-weight.h:875
constexpr TropicalWeight operator()(const Log64Weight &w) const
Definition: float-weight.h:976
static const std::string & Type()
Definition: float-weight.h:881
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:65
constexpr MinMaxWeightTpl< T > Reverse() const
Definition: float-weight.h:900
#define DCHECK(x)
Definition: log.h:70
static constexpr uint64_t Properties()
Definition: float-weight.h:459
constexpr TropicalWeightTpl< T > Power(const TropicalWeightTpl< T > &w, V n)
Definition: float-weight.h:396
#define DCHECK_GE(x, y)
Definition: log.h:75
constexpr LogWeight operator()(const Log64Weight &w) const
DivideType
Definition: weight.h:162
constexpr uint64_t kLeftSemiring
Definition: weight.h:133
constexpr float kDelta
Definition: weight.h:130
constexpr Real64Weight operator()(const RealWeight &w) const
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t num_random_weights=kNumRandomWeights)
Real64Weight operator()(const Log64Weight &w) const
constexpr const T & Value() const
Definition: float-weight.h:91
bool ApproxEqual(const ErrorWeight &, const ErrorWeight &, float)
Definition: error-weight.h:57
constexpr LogWeightTpl(T f)
Definition: float-weight.h:430
constexpr TropicalWeightTpl(T f)
Definition: float-weight.h:211
double KahanLogDiff(double a, double b, double *c)
Definition: float-weight.h:500
static constexpr T NegInfinity()
Definition: float-weight.h:60
double KahanRealSum(double a, double b, double *c)
Definition: float-weight.h:718