FST  openfst-1.8.3
OpenFst Library
float-weight.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Float weight set and associated semiring operation definitions.
19 
20 #ifndef FST_FLOAT_WEIGHT_H_
21 #define FST_FLOAT_WEIGHT_H_
22 
23 #include <algorithm>
24 #include <climits>
25 #include <cmath>
26 #include <cstddef>
27 #include <cstdint>
28 #include <cstdlib>
29 #include <cstring>
30 #include <ios>
31 #include <istream>
32 #include <limits>
33 #include <ostream>
34 #include <random>
35 #include <sstream>
36 #include <string>
37 #include <type_traits>
38 
39 #include <fst/log.h>
40 #include <fst/util.h>
41 #include <fst/weight.h>
42 #include <fst/compat.h>
43 #include <string_view>
44 
45 namespace fst {
46 
namespace internal {

// Returns true iff `value` is a NaN. Under IEEE 754 a NaN is the only value
// that does not compare equal to itself, which gives a constexpr-friendly test.
// TODO(wolfsonkin): Replace with `std::isnan` if and when that ends up
// constexpr. For context, see
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0533r6.pdf.
template <class T>
inline constexpr bool IsNan(T value) {
  const bool self_equal = (value == value);
  return !self_equal;
}

}  // namespace internal
56 
// Numeric limits for the floating-point type T underlying a float weight: the
// two infinities and the quiet NaN used to represent an invalid ("bad") value.
template <class T>
class FloatLimits {
  using Traits = std::numeric_limits<T>;

 public:
  static constexpr T PosInfinity() { return Traits::infinity(); }

  static constexpr T NegInfinity() { return -Traits::infinity(); }

  static constexpr T NumberBad() { return Traits::quiet_NaN(); }
};
69 
// Weight class to be templated on floating-point types.
71 template <class T = float>
73  public:
74  using ValueType = T;
75 
76  FloatWeightTpl() noexcept = default;
77 
78  constexpr FloatWeightTpl(T f) : value_(f) {} // NOLINT
79 
80  std::istream &Read(std::istream &strm) { return ReadType(strm, &value_); }
81 
82  std::ostream &Write(std::ostream &strm) const {
83  return WriteType(strm, value_);
84  }
85 
86  size_t Hash() const {
87  size_t hash = 0;
88  // Avoid using union, which would be undefined behavior.
89  // Use memcpy, similar to bit_cast, but sizes may be different.
90  // This should be optimized into a single move instruction by
91  // any reasonable compiler.
92  std::memcpy(&hash, &value_, std::min(sizeof(hash), sizeof(value_)));
93  return hash;
94  }
95 
96  constexpr const T &Value() const { return value_; }
97 
98  protected:
99  void SetValue(const T &f) { value_ = f; }
100 
101  static constexpr std::string_view GetPrecisionString() {
102  return sizeof(T) == 4 ? ""
103  : sizeof(T) == 1 ? "8"
104  : sizeof(T) == 2 ? "16"
105  : sizeof(T) == 8 ? "64"
106  : "unknown";
107  }
108 
109  private:
110  T value_;
111 };
112 
113 // Single-precision float weight.
115 
116 template <class T>
117 constexpr bool operator==(const FloatWeightTpl<T> &w1,
118  const FloatWeightTpl<T> &w2) {
119 #if (defined(__i386__) || defined(__x86_64__)) && !defined(__SSE2_MATH__)
120 // With i387 instructions, excess precision on a weight in an 80-bit
121 // register may cause it to compare unequal to that same weight when
122 // stored to memory. This breaks =='s reflexivity, in turn breaking
123 // NaturalLess.
124 #error "Please compile with -msse -mfpmath=sse, or equivalent."
125 #endif
126  return w1.Value() == w2.Value();
127 }
128 
129 // These seemingly unnecessary overloads are actually needed to make
130 // comparisons like FloatWeightTpl<float> == float compile. If only the
131 // templated version exists, the FloatWeightTpl<float>(float) conversion
132 // won't be found.
133 constexpr bool operator==(const FloatWeightTpl<float> &w1,
134  const FloatWeightTpl<float> &w2) {
135  return operator==<float>(w1, w2);
136 }
137 
138 constexpr bool operator==(const FloatWeightTpl<double> &w1,
139  const FloatWeightTpl<double> &w2) {
140  return operator==<double>(w1, w2);
141 }
142 
143 template <class T>
144 constexpr bool operator!=(const FloatWeightTpl<T> &w1,
145  const FloatWeightTpl<T> &w2) {
146  return !(w1 == w2);
147 }
148 
149 constexpr bool operator!=(const FloatWeightTpl<float> &w1,
150  const FloatWeightTpl<float> &w2) {
151  return operator!=<float>(w1, w2);
152 }
153 
154 constexpr bool operator!=(const FloatWeightTpl<double> &w1,
155  const FloatWeightTpl<double> &w2) {
156  return operator!=<double>(w1, w2);
157 }
158 
159 template <class T>
160 constexpr bool FloatApproxEqual(T w1, T w2, float delta = kDelta) {
161  return w1 <= w2 + delta && w2 <= w1 + delta;
162 }
163 
164 template <class T>
165 constexpr bool ApproxEqual(const FloatWeightTpl<T> &w1,
166  const FloatWeightTpl<T> &w2, float delta = kDelta) {
167  return FloatApproxEqual(w1.Value(), w2.Value(), delta);
168 }
169 
170 template <class T>
171 inline std::ostream &operator<<(std::ostream &strm,
172  const FloatWeightTpl<T> &w) {
173  if (w.Value() == FloatLimits<T>::PosInfinity()) {
174  return strm << "Infinity";
175  } else if (w.Value() == FloatLimits<T>::NegInfinity()) {
176  return strm << "-Infinity";
177  } else if (internal::IsNan(w.Value())) {
178  return strm << "BadNumber";
179  } else {
180  return strm << w.Value();
181  }
182 }
183 
184 template <class T>
185 inline std::istream &operator>>(std::istream &strm, FloatWeightTpl<T> &w) {
186  std::string s;
187  strm >> s;
188  if (s == "Infinity") {
190  } else if (s == "-Infinity") {
192  } else {
193  char *p;
194  T f = strtod(s.c_str(), &p);
195  if (p < s.c_str() + s.size()) {
196  strm.clear(std::ios::badbit);
197  } else {
198  w = FloatWeightTpl<T>(f);
199  }
200  }
201  return strm;
202 }
203 
204 // Tropical semiring: (min, +, inf, 0).
205 template <class T>
207  public:
208  using typename FloatWeightTpl<T>::ValueType;
212 
213  TropicalWeightTpl() noexcept : FloatWeightTpl<T>() {}
214 
215  constexpr TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {}
216 
217  static constexpr TropicalWeightTpl<T> Zero() { return Limits::PosInfinity(); }
218 
219  static constexpr TropicalWeightTpl<T> One() { return 0; }
220 
221  static constexpr TropicalWeightTpl<T> NoWeight() {
222  return Limits::NumberBad();
223  }
224 
225  static const std::string &Type() {
226  static const std::string *const type = new std::string(
228  return *type;
229  }
230 
231  constexpr bool Member() const {
232  // All floating point values except for NaNs and negative infinity are valid
233  // tropical weights.
234  //
235  // Testing membership of a given value can be done by simply checking that
236  // it is strictly greater than negative infinity, which fails for negative
237  // infinity itself but also for NaNs. This can usually be accomplished in a
238  // single instruction (such as *UCOMI* on x86) without branching logic.
239  //
240  // An additional wrinkle involves constexpr correctness of floating point
241  // comparisons against NaN. GCC is uneven when it comes to which expressions
242  // it considers compile-time constants. In particular, current versions of
243  // GCC do not always consider (nan < inf) to be a constant expression, but
244  // do consider (inf < nan) to be a constant expression. (See
245  // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=88173 and
246  // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=88683 for details.) In order
247  // to allow Member() to be a constexpr function accepted by GCC, we write
248  // the comparison here as (-inf < v).
249  return Limits::NegInfinity() < Value();
250  }
251 
252  TropicalWeightTpl<T> Quantize(float delta = kDelta) const {
253  if (!Member() || Value() == Limits::PosInfinity()) {
254  return *this;
255  } else {
256  return TropicalWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
257  }
258  }
259 
260  constexpr TropicalWeightTpl<T> Reverse() const { return *this; }
261 
262  static constexpr uint64_t Properties() {
264  }
265 };
266 
267 // Single precision tropical weight.
269 
270 template <class T>
272  const TropicalWeightTpl<T> &w2) {
273  return (!w1.Member() || !w2.Member()) ? TropicalWeightTpl<T>::NoWeight()
274  : w1.Value() < w2.Value() ? w1
275  : w2;
276 }
277 
278 // See comment at operator==(FloatWeightTpl<float>, FloatWeightTpl<float>)
279 // for why these overloads are present.
281  const TropicalWeightTpl<float> &w2) {
282  return Plus<float>(w1, w2);
283 }
284 
286  const TropicalWeightTpl<double> &w2) {
287  return Plus<double>(w1, w2);
288 }
289 
290 template <class T>
292  const TropicalWeightTpl<T> &w2) {
293  // The following is safe in the context of the Tropical (and Log) semiring
294  // for all IEEE floating point values, including infinities and NaNs,
295  // because:
296  //
297  // * If one or both of the floating point Values is NaN and hence not a
298  // Member, the result of addition below is NaN, so the result is not a
299  // Member. This supersedes all other cases, so we only consider non-NaN
300  // values next.
301  //
302  // * If both Values are finite, there is no issue.
303  //
304  // * If one of the Values is infinite, or if both are infinities with the
305  // same sign, the result of floating point addition is the same infinity,
306  // so there is no issue.
307  //
308  // * If both of the Values are infinities with opposite signs, the result of
309  // adding IEEE floating point -inf + inf is NaN and hence not a Member. But
310  // since -inf was not a Member to begin with, returning a non-Member result
311  // is fine as well.
312  return TropicalWeightTpl<T>(w1.Value() + w2.Value());
313 }
314 
316  const TropicalWeightTpl<float> &w2) {
317  return Times<float>(w1, w2);
318 }
319 
321  const TropicalWeightTpl<double> &w2) {
322  return Times<double>(w1, w2);
323 }
324 
325 template <class T>
327  const TropicalWeightTpl<T> &w2,
328  DivideType typ = DIVIDE_ANY) {
329  // The following is safe in the context of the Tropical (and Log) semiring
330  // for all IEEE floating point values, including infinities and NaNs,
331  // because:
332  //
333  // * If one or both of the floating point Values is NaN and hence not a
334  // Member, the result of subtraction below is NaN, so the result is not a
335  // Member. This supersedes all other cases, so we only consider non-NaN
336  // values next.
337  //
338  // * If both Values are finite, there is no issue.
339  //
340  // * If w2.Value() is -inf (and hence w2 is not a Member), the result of ?:
341  // below is NoWeight, which is not a Member.
342  //
343  // Whereas in IEEE floating point semantics 0/inf == 0, this does not carry
344  // over to this semiring (since TropicalWeight(-inf) would be the analogue
345  // of floating point inf) and instead Divide(Zero(), TropicalWeight(-inf))
346  // is NoWeight().
347  //
348  // * If w2.Value() is inf (and hence w2 is Zero), the resulting floating
349  // point value is either NaN (if w1 is Zero or if w1.Value() is NaN) and
350  // hence not a Member, or it is -inf and hence not a Member; either way,
351  // division by Zero results in a non-Member result.
352  using Weight = TropicalWeightTpl<T>;
353  return w2.Member() ? Weight(w1.Value() - w2.Value()) : Weight::NoWeight();
354 }
355 
357  const TropicalWeightTpl<float> &w2,
358  DivideType typ = DIVIDE_ANY) {
359  return Divide<float>(w1, w2, typ);
360 }
361 
363  const TropicalWeightTpl<double> &w2,
364  DivideType typ = DIVIDE_ANY) {
365  return Divide<double>(w1, w2, typ);
366 }
367 
368 // Power(w, n) calculates the n-th power of w with respect to semiring Times.
369 //
370 // In the case of the Tropical (and Log) semiring, the exponent n is not
371 // restricted to be an integer. It can be a floating point value, for example.
372 //
373 // In weight.h, a narrower and hence more broadly applicable version of
374 // Power(w, n) is defined for arbitrary weight types and non-negative integer
375 // exponents n (of type size_t) and implemented in terms of repeated
376 // multiplication using Times.
377 //
378 // Without further provisions this means that, when an expression such as
379 //
380 // Power(TropicalWeightTpl<float>::One(), static_cast<size_t>(2))
381 //
382 // is specified, the overload of Power() is ambiguous. The template function
383 // below could be instantiated as
384 //
385 // Power<float, size_t>(const TropicalWeightTpl<float> &, size_t)
386 //
387 // and the template function defined in weight.h (further specialized below)
388 // could be instantiated as
389 //
390 // Power<TropicalWeightTpl<float>>(const TropicalWeightTpl<float> &, size_t)
391 //
392 // That would lead to two definitions with identical signatures, which results
393 // in a compilation error. To avoid that, we hide the definition of Power<T, V>
394 // when V is size_t, so only Power<W> is visible. Power<W> is further
395 // specialized to Power<TropicalWeightTpl<...>>, and the overloaded definition
396 // of Power<T, V> is made conditionally available only to that template
397 // specialization.
398 
399 template <class T, class V, bool Enable = !std::is_same_v<V, size_t>,
400  typename std::enable_if_t<Enable> * = nullptr>
402  using Weight = TropicalWeightTpl<T>;
403  return (!w.Member() || internal::IsNan(n)) ? Weight::NoWeight()
404  : (n == 0 || w == Weight::One()) ? Weight::One()
405  : Weight(w.Value() * n);
406 }
407 
408 // Specializes the library-wide template to use the above implementation; rules
409 // of function template instantiation require this be a full instantiation.
410 
411 template <>
412 constexpr TropicalWeightTpl<float> Power<TropicalWeightTpl<float>>(
413  const TropicalWeightTpl<float> &weight, size_t n) {
414  return Power<float, size_t, true>(weight, n);
415 }
416 
417 template <>
418 constexpr TropicalWeightTpl<double> Power<TropicalWeightTpl<double>>(
419  const TropicalWeightTpl<double> &weight, size_t n) {
420  return Power<double, size_t, true>(weight, n);
421 }
422 
423 // Log semiring: (log(e^-x + e^-y), +, inf, 0).
424 template <class T>
425 class LogWeightTpl : public FloatWeightTpl<T> {
426  public:
427  using typename FloatWeightTpl<T>::ValueType;
431 
432  LogWeightTpl() noexcept : FloatWeightTpl<T>() {}
433 
434  constexpr LogWeightTpl(T f) : FloatWeightTpl<T>(f) {}
435 
436  static constexpr LogWeightTpl Zero() { return Limits::PosInfinity(); }
437 
438  static constexpr LogWeightTpl One() { return 0; }
439 
440  static constexpr LogWeightTpl NoWeight() { return Limits::NumberBad(); }
441 
442  static const std::string &Type() {
443  static const std::string *const type = new std::string(
445  return *type;
446  }
447 
448  constexpr bool Member() const {
449  // The comments for TropicalWeightTpl<>::Member() apply here unchanged.
450  return Limits::NegInfinity() < Value();
451  }
452 
453  LogWeightTpl<T> Quantize(float delta = kDelta) const {
454  if (!Member() || Value() == Limits::PosInfinity()) {
455  return *this;
456  } else {
457  return LogWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
458  }
459  }
460 
461  constexpr LogWeightTpl<T> Reverse() const { return *this; }
462 
463  static constexpr uint64_t Properties() {
465  }
466 };
467 
468 // Single-precision log weight.
470 
471 // Double-precision log weight.
473 
474 namespace internal {
475 
476 // -log(e^-x + e^-y) = x - LogPosExp(y - x), assuming y >= x.
477 inline double LogPosExp(double x) {
478  DCHECK(!(x < 0)); // NB: NaN values are allowed.
479  return log1p(exp(-x));
480 }
481 
482 // -log(e^-x - e^-y) = x - LogNegExp(y - x), assuming y >= x.
483 inline double LogNegExp(double x) {
484  DCHECK(!(x < 0)); // NB: NaN values are allowed.
485  return log1p(-exp(-x));
486 }
487 
488 // a +_log b = -log(e^-a + e^-b) = KahanLogSum(a, b, ...).
489 // Kahan compensated summation provides an error bound that is
490 // independent of the number of addends. Assumes b >= a;
491 // c is the compensation.
492 inline double KahanLogSum(double a, double b, double *c) {
493  DCHECK_GE(b, a);
494  double y = -LogPosExp(b - a) - *c;
495  double t = a + y;
496  *c = (t - a) - y;
497  return t;
498 }
499 
500 // a -_log b = -log(e^-a - e^-b) = KahanLogDiff(a, b, ...).
501 // Kahan compensated summation provides an error bound that is
502 // independent of the number of addends. Assumes b > a;
503 // c is the compensation.
504 inline double KahanLogDiff(double a, double b, double *c) {
505  DCHECK_GT(b, a);
506  double y = -LogNegExp(b - a) - *c;
507  double t = a + y;
508  *c = (t - a) - y;
509  return t;
510 }
511 
512 } // namespace internal
513 
514 template <class T>
516  const LogWeightTpl<T> &w2) {
517  using Limits = FloatLimits<T>;
518  const T f1 = w1.Value();
519  const T f2 = w2.Value();
520  if (f1 == Limits::PosInfinity()) {
521  return w2;
522  } else if (f2 == Limits::PosInfinity()) {
523  return w1;
524  } else if (f1 > f2) {
525  return LogWeightTpl<T>(f2 - internal::LogPosExp(f1 - f2));
526  } else {
527  return LogWeightTpl<T>(f1 - internal::LogPosExp(f2 - f1));
528  }
529 }
530 
532  const LogWeightTpl<float> &w2) {
533  return Plus<float>(w1, w2);
534 }
535 
537  const LogWeightTpl<double> &w2) {
538  return Plus<double>(w1, w2);
539 }
540 
541 // Returns NoWeight if w1 < w2 (w1.Value() > w2.Value()).
542 template <class T>
544  const LogWeightTpl<T> &w2) {
545  using Limits = FloatLimits<T>;
546  const T f1 = w1.Value();
547  const T f2 = w2.Value();
548  if (f1 > f2) return LogWeightTpl<T>::NoWeight();
549  if (f2 == Limits::PosInfinity()) return f1;
550  const T d = f2 - f1;
551  if (d == Limits::PosInfinity()) return f1;
552  return f1 - internal::LogNegExp(d);
553 }
554 
556  const LogWeightTpl<float> &w2) {
557  return Minus<float>(w1, w2);
558 }
559 
561  const LogWeightTpl<double> &w2) {
562  return Minus<double>(w1, w2);
563 }
564 
565 template <class T>
567  const LogWeightTpl<T> &w2) {
568  // The comments for Times(Tropical...) above apply here unchanged.
569  return LogWeightTpl<T>(w1.Value() + w2.Value());
570 }
571 
573  const LogWeightTpl<float> &w2) {
574  return Times<float>(w1, w2);
575 }
576 
578  const LogWeightTpl<double> &w2) {
579  return Times<double>(w1, w2);
580 }
581 
582 template <class T>
584  const LogWeightTpl<T> &w2,
585  DivideType typ = DIVIDE_ANY) {
586  // The comments for Divide(Tropical...) above apply here unchanged.
587  using Weight = LogWeightTpl<T>;
588  return w2.Member() ? Weight(w1.Value() - w2.Value()) : Weight::NoWeight();
589 }
590 
592  const LogWeightTpl<float> &w2,
593  DivideType typ = DIVIDE_ANY) {
594  return Divide<float>(w1, w2, typ);
595 }
596 
598  const LogWeightTpl<double> &w2,
599  DivideType typ = DIVIDE_ANY) {
600  return Divide<double>(w1, w2, typ);
601 }
602 
603 // The comments for Power<>(Tropical...) above apply here unchanged.
604 
605 template <class T, class V, bool Enable = !std::is_same_v<V, size_t>,
606  typename std::enable_if_t<Enable> * = nullptr>
607 constexpr LogWeightTpl<T> Power(const LogWeightTpl<T> &w, V n) {
608  using Weight = LogWeightTpl<T>;
609  return (!w.Member() || internal::IsNan(n)) ? Weight::NoWeight()
610  : (n == 0 || w == Weight::One()) ? Weight::One()
611  : Weight(w.Value() * n);
612 }
613 
614 // Specializes the library-wide template to use the above implementation; rules
615 // of function template instantiation require this be a full instantiation.
616 
617 template <>
618 constexpr LogWeightTpl<float> Power<LogWeightTpl<float>>(
619  const LogWeightTpl<float> &weight, size_t n) {
620  return Power<float, size_t, true>(weight, n);
621 }
622 
623 template <>
624 constexpr LogWeightTpl<double> Power<LogWeightTpl<double>>(
625  const LogWeightTpl<double> &weight, size_t n) {
626  return Power<double, size_t, true>(weight, n);
627 }
628 
629 // Specialization using the Kahan compensated summation.
630 template <class T>
631 class Adder<LogWeightTpl<T>> {
632  public:
634 
635  explicit Adder(Weight w = Weight::Zero()) : sum_(w.Value()), c_(0.0) {}
636 
637  Weight Add(const Weight &w) {
638  using Limits = FloatLimits<T>;
639  const T f = w.Value();
640  if (f == Limits::PosInfinity()) {
641  return Sum();
642  } else if (sum_ == Limits::PosInfinity()) {
643  sum_ = f;
644  c_ = 0.0;
645  } else if (f > sum_) {
646  sum_ = internal::KahanLogSum(sum_, f, &c_);
647  } else {
648  sum_ = internal::KahanLogSum(f, sum_, &c_);
649  }
650  return Sum();
651  }
652 
653  Weight Sum() const { return Weight(sum_); }
654 
655  void Reset(Weight w = Weight::Zero()) {
656  sum_ = w.Value();
657  c_ = 0.0;
658  }
659 
660  private:
661  double sum_;
662  double c_; // Kahan compensation.
663 };
664 
665 // Real semiring: (+, *, 0, 1).
666 template <class T>
667 class RealWeightTpl : public FloatWeightTpl<T> {
668  public:
669  using typename FloatWeightTpl<T>::ValueType;
673 
674  RealWeightTpl() noexcept : FloatWeightTpl<T>() {}
675 
676  constexpr RealWeightTpl(T f) : FloatWeightTpl<T>(f) {}
677 
678  static constexpr RealWeightTpl Zero() { return 0; }
679 
680  static constexpr RealWeightTpl One() { return 1; }
681 
682  static constexpr RealWeightTpl NoWeight() { return Limits::NumberBad(); }
683 
684  static const std::string &Type() {
685  static const std::string *const type = new std::string(
687  return *type;
688  }
689 
690  constexpr bool Member() const {
691  // The comments for TropicalWeightTpl<>::Member() apply here unchanged.
692  return Limits::NegInfinity() < Value();
693  }
694 
695  RealWeightTpl<T> Quantize(float delta = kDelta) const {
696  if (!Member() || Value() == Limits::PosInfinity()) {
697  return *this;
698  } else {
699  return RealWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
700  }
701  }
702 
703  constexpr RealWeightTpl<T> Reverse() const { return *this; }
704 
705  static constexpr uint64_t Properties() {
707  }
708 };
709 
// Single-precision real weight.
712 
// Double-precision real weight.
715 
namespace internal {

// Returns a + b computed with Kahan compensated summation, which gives an error
// bound independent of the number of addends. `c` points at the running
// compensation term carried between calls; it is read and updated.
inline double KahanRealSum(double a, double b, double *c) {
  double y = b - *c;  // Correct the addend by the error carried so far.
  double t = a + y;   // Low-order bits of y may be lost in this addition...
  *c = (t - a) - y;   // ...and are recovered here for the next call.
  return t;
}

}  // namespace internal
729 
730 // The comments for Times(Tropical...) above apply here unchanged.
731 template <class T>
733  const RealWeightTpl<T> &w2) {
734  const T f1 = w1.Value();
735  const T f2 = w2.Value();
736  return RealWeightTpl<T>(f1 + f2);
737 }
738 
740  const RealWeightTpl<float> &w2) {
741  return Plus<float>(w1, w2);
742 }
743 
745  const RealWeightTpl<double> &w2) {
746  return Plus<double>(w1, w2);
747 }
748 
749 template <class T>
751  const RealWeightTpl<T> &w2) {
752  // The comments for Divide(Tropical...) above apply here unchanged.
753  const T f1 = w1.Value();
754  const T f2 = w2.Value();
755  return RealWeightTpl<T>(f1 - f2);
756 }
757 
759  const RealWeightTpl<float> &w2) {
760  return Minus<float>(w1, w2);
761 }
762 
764  const RealWeightTpl<double> &w2) {
765  return Minus<double>(w1, w2);
766 }
767 
768 // The comments for Times(Tropical...) above apply here similarly.
769 template <class T>
771  const RealWeightTpl<T> &w2) {
772  return RealWeightTpl<T>(w1.Value() * w2.Value());
773 }
774 
776  const RealWeightTpl<float> &w2) {
777  return Times<float>(w1, w2);
778 }
779 
781  const RealWeightTpl<double> &w2) {
782  return Times<double>(w1, w2);
783 }
784 
785 template <class T>
787  const RealWeightTpl<T> &w2,
788  DivideType typ = DIVIDE_ANY) {
789  using Weight = RealWeightTpl<T>;
790  return w2.Member() ? Weight(w1.Value() / w2.Value()) : Weight::NoWeight();
791 }
792 
794  const RealWeightTpl<float> &w2,
795  DivideType typ = DIVIDE_ANY) {
796  return Divide<float>(w1, w2, typ);
797 }
798 
800  const RealWeightTpl<double> &w2,
801  DivideType typ = DIVIDE_ANY) {
802  return Divide<double>(w1, w2, typ);
803 }
804 
805 // The comments for Power<>(Tropical...) above apply here unchanged.
806 
807 template <class T, class V, bool Enable = !std::is_same_v<V, size_t>,
808  typename std::enable_if_t<Enable> * = nullptr>
809 constexpr RealWeightTpl<T> Power(const RealWeightTpl<T> &w, V n) {
810  using Weight = RealWeightTpl<T>;
811  return (!w.Member() || internal::IsNan(n)) ? Weight::NoWeight()
812  : (n == 0 || w == Weight::One()) ? Weight::One()
813  : Weight(pow(w.Value(), n));
814 }
815 
816 // Specializes the library-wide template to use the above implementation; rules
817 // of function template instantiation require this be a full instantiation.
818 
819 template <>
820 constexpr RealWeightTpl<float> Power<RealWeightTpl<float>>(
821  const RealWeightTpl<float> &weight, size_t n) {
822  return Power<float, size_t, true>(weight, n);
823 }
824 
825 template <>
826 constexpr RealWeightTpl<double> Power<RealWeightTpl<double>>(
827  const RealWeightTpl<double> &weight, size_t n) {
828  return Power<double, size_t, true>(weight, n);
829 }
830 
831 // Specialization using the Kahan compensated summation.
832 template <class T>
833 class Adder<RealWeightTpl<T>> {
834  public:
836 
837  explicit Adder(Weight w = Weight::Zero()) : sum_(w.Value()), c_(0.0) {}
838 
839  Weight Add(const Weight &w) {
840  using Limits = FloatLimits<T>;
841  const T f = w.Value();
842  if (f == Limits::PosInfinity()) {
843  sum_ = f;
844  } else if (sum_ == Limits::PosInfinity()) {
845  return sum_;
846  } else {
847  sum_ = internal::KahanRealSum(sum_, f, &c_);
848  }
849  return Sum();
850  }
851 
852  Weight Sum() const { return Weight(sum_); }
853 
854  void Reset(Weight w = Weight::Zero()) {
855  sum_ = w.Value();
856  c_ = 0.0;
857  }
858 
859  private:
860  double sum_;
861  double c_; // Kahan compensation.
862 };
863 
864 // MinMax semiring: (min, max, inf, -inf).
865 template <class T>
866 class MinMaxWeightTpl : public FloatWeightTpl<T> {
867  public:
868  using typename FloatWeightTpl<T>::ValueType;
872 
873  MinMaxWeightTpl() noexcept : FloatWeightTpl<T>() {}
874 
875  constexpr MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {} // NOLINT
876 
877  static constexpr MinMaxWeightTpl Zero() { return Limits::PosInfinity(); }
878 
879  static constexpr MinMaxWeightTpl One() { return Limits::NegInfinity(); }
880 
881  static constexpr MinMaxWeightTpl NoWeight() { return Limits::NumberBad(); }
882 
883  static const std::string &Type() {
884  static const std::string *const type = new std::string(
886  return *type;
887  }
888 
889  // Fails for IEEE NaN.
890  constexpr bool Member() const { return !internal::IsNan(Value()); }
891 
892  MinMaxWeightTpl<T> Quantize(float delta = kDelta) const {
893  // If one of infinities, or a NaN.
894  if (!Member() || Value() == Limits::NegInfinity() ||
895  Value() == Limits::PosInfinity()) {
896  return *this;
897  } else {
898  return MinMaxWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
899  }
900  }
901 
902  constexpr MinMaxWeightTpl<T> Reverse() const { return *this; }
903 
904  static constexpr uint64_t Properties() {
906  }
907 };
908 
909 // Single-precision min-max weight.
911 
912 // Min.
913 template <class T>
915  const MinMaxWeightTpl<T> &w2) {
916  return (!w1.Member() || !w2.Member()) ? MinMaxWeightTpl<T>::NoWeight()
917  : w1.Value() < w2.Value() ? w1
918  : w2;
919 }
920 
922  const MinMaxWeightTpl<float> &w2) {
923  return Plus<float>(w1, w2);
924 }
925 
927  const MinMaxWeightTpl<double> &w2) {
928  return Plus<double>(w1, w2);
929 }
930 
931 // Max.
932 template <class T>
934  const MinMaxWeightTpl<T> &w2) {
935  return (!w1.Member() || !w2.Member()) ? MinMaxWeightTpl<T>::NoWeight()
936  : w1.Value() >= w2.Value() ? w1
937  : w2;
938 }
939 
941  const MinMaxWeightTpl<float> &w2) {
942  return Times<float>(w1, w2);
943 }
944 
946  const MinMaxWeightTpl<double> &w2) {
947  return Times<double>(w1, w2);
948 }
949 
950 // Defined only for special cases.
951 template <class T>
953  const MinMaxWeightTpl<T> &w2,
954  DivideType typ = DIVIDE_ANY) {
955  return w1.Value() >= w2.Value() ? w1 : MinMaxWeightTpl<T>::NoWeight();
956 }
957 
959  const MinMaxWeightTpl<float> &w2,
960  DivideType typ = DIVIDE_ANY) {
961  return Divide<float>(w1, w2, typ);
962 }
963 
965  const MinMaxWeightTpl<double> &w2,
966  DivideType typ = DIVIDE_ANY) {
967  return Divide<double>(w1, w2, typ);
968 }
969 
970 // Converts to tropical.
971 template <>
973  constexpr TropicalWeight operator()(const LogWeight &w) const {
974  return w.Value();
975  }
976 };
977 
978 template <>
980  constexpr TropicalWeight operator()(const Log64Weight &w) const {
981  return w.Value();
982  }
983 };
984 
985 // Converts to log.
986 template <>
988  constexpr LogWeight operator()(const TropicalWeight &w) const {
989  return w.Value();
990  }
991 };
992 
993 template <>
995  LogWeight operator()(const RealWeight &w) const { return -log(w.Value()); }
996 };
997 
998 template <>
1000  LogWeight operator()(const Real64Weight &w) const { return -log(w.Value()); }
1001 };
1002 
1003 template <>
1005  constexpr LogWeight operator()(const Log64Weight &w) const {
1006  return w.Value();
1007  }
1008 };
1009 
1010 // Converts to log64.
1011 template <>
1013  constexpr Log64Weight operator()(const TropicalWeight &w) const {
1014  return w.Value();
1015  }
1016 };
1017 
1018 template <>
1020  Log64Weight operator()(const RealWeight &w) const { return -log(w.Value()); }
1021 };
1022 
1023 template <>
1026  return -log(w.Value());
1027  }
1028 };
1029 
1030 template <>
1032  constexpr Log64Weight operator()(const LogWeight &w) const {
1033  return w.Value();
1034  }
1035 };
1036 
1037 // Converts to real.
1038 template <>
1040  RealWeight operator()(const LogWeight &w) const { return exp(-w.Value()); }
1041 };
1042 
1043 template <>
1045  RealWeight operator()(const Log64Weight &w) const { return exp(-w.Value()); }
1046 };
1047 
1048 template <>
1050  constexpr RealWeight operator()(const Real64Weight &w) const {
1051  return w.Value();
1052  }
1053 };
1054 
1055 // Converts to real64
1056 template <>
1058  Real64Weight operator()(const LogWeight &w) const { return exp(-w.Value()); }
1059 };
1060 
1061 template <>
1064  return exp(-w.Value());
1065  }
1066 };
1067 
1068 template <>
1070  constexpr Real64Weight operator()(const RealWeight &w) const {
1071  return w.Value();
1072  }
1073 };
1074 
// This function object returns weights constructed from random integers chosen
// from [0, num_random_weights). The allow_zero argument determines whether
// Zero() and zero divisors may also be returned in the random weight
// generation. This is intended primarily for testing.
1079 template <class Weight>
1081  public:
1083  uint64_t seed = std::random_device()(), bool allow_zero = true,
1084  const size_t num_random_weights = kNumRandomWeights)
1085  : rand_(seed),
1086  allow_zero_(allow_zero),
1087  num_random_weights_(num_random_weights) {}
1088 
1089  Weight operator()() const {
1090  const int sample = std::uniform_int_distribution<>(
1091  0, num_random_weights_ + allow_zero_ - 1)(rand_);
1092  if (allow_zero_ && sample == num_random_weights_) return Weight::Zero();
1093  return Weight(sample);
1094  }
1095 
1096  private:
1097  mutable std::mt19937_64 rand_;
1098  const bool allow_zero_;
1099  const size_t num_random_weights_;
1100 };
1101 
1102 template <class T>
1104  : public FloatWeightGenerate<TropicalWeightTpl<T>> {
1105  public:
1108 
1109  explicit WeightGenerate(uint64_t seed = std::random_device()(),
1110  bool allow_zero = true,
1111  size_t num_random_weights = kNumRandomWeights)
1112  : Generate(seed, allow_zero, num_random_weights) {}
1113 
1114  Weight operator()() const { return Weight(Generate::operator()()); }
1115 };
1116 
1117 template <class T>
1119  : public FloatWeightGenerate<LogWeightTpl<T>> {
1120  public:
1123 
1124  explicit WeightGenerate(uint64_t seed = std::random_device()(),
1125  bool allow_zero = true,
1126  size_t num_random_weights = kNumRandomWeights)
1127  : Generate(seed, allow_zero, num_random_weights) {}
1128 
1129  Weight operator()() const { return Weight(Generate::operator()()); }
1130 };
1131 
1132 template <class T>
1134  : public FloatWeightGenerate<RealWeightTpl<T>> {
1135  public:
1138 
1139  explicit WeightGenerate(uint64_t seed = std::random_device()(),
1140  bool allow_zero = true,
1141  size_t num_random_weights = kNumRandomWeights)
1142  : Generate(seed, allow_zero, num_random_weights) {}
1143 
1144  Weight operator()() const { return Weight(Generate::operator()()); }
1145 };
1146 
// This function object returns random weights. It samples an integer uniformly
// from [-num_random_weights, num_random_weights + allow_zero]. The lowest value
// maps to One(), and when 'allow_zero' is true, 0 maps to Zero(). Every other
// sample is returned as a weight with that value. This is intended primarily
// for testing.
1151 template <class T>
1153  public:
1155 
1156  explicit WeightGenerate(uint64_t seed = std::random_device()(),
1157  bool allow_zero = true,
1158  size_t num_random_weights = kNumRandomWeights)
1159  : rand_(seed),
1160  allow_zero_(allow_zero),
1161  num_random_weights_(num_random_weights) {}
1162 
1163  Weight operator()() const {
1164  const int sample = std::uniform_int_distribution<>(
1165  -num_random_weights_, num_random_weights_ + allow_zero_)(rand_);
1166  if (allow_zero_ && sample == 0) {
1167  return Weight::Zero();
1168  } else if (sample == -num_random_weights_) {
1169  return Weight::One();
1170  } else {
1171  return Weight(sample);
1172  }
1173  }
1174 
1175  private:
1176  mutable std::mt19937_64 rand_;
1177  const bool allow_zero_;
1178  const size_t num_random_weights_;
1179 };
1180 
1181 } // namespace fst
1182 
1183 #endif // FST_FLOAT_WEIGHT_H_
RealWeight operator()(const Log64Weight &w) const
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t num_random_weights=kNumRandomWeights)
Weight Add(const Weight &w)
Definition: float-weight.h:839
static constexpr RealWeightTpl Zero()
Definition: float-weight.h:678
void Reset(Weight w=Weight::Zero())
Definition: float-weight.h:655
static constexpr TropicalWeightTpl< T > Zero()
Definition: float-weight.h:217
constexpr TropicalWeight operator()(const LogWeight &w) const
Definition: float-weight.h:973
FloatWeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, const size_t num_random_weights=kNumRandomWeights)
constexpr RealWeight operator()(const Real64Weight &w) const
constexpr Log64Weight operator()(const LogWeight &w) const
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:61
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t num_random_weights=kNumRandomWeights)
constexpr bool Member() const
Definition: float-weight.h:690
Adder(Weight w=Weight::Zero())
Definition: float-weight.h:837
RealWeight operator()(const LogWeight &w) const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:64
static constexpr std::string_view GetPrecisionString()
Definition: float-weight.h:101
RealWeightTpl< T > Quantize(float delta=kDelta) const
Definition: float-weight.h:695
static constexpr LogWeightTpl Zero()
Definition: float-weight.h:436
static constexpr MinMaxWeightTpl NoWeight()
Definition: float-weight.h:881
static constexpr MinMaxWeightTpl One()
Definition: float-weight.h:879
constexpr uint64_t kIdempotent
Definition: weight.h:147
RealWeightTpl() noexcept
Definition: float-weight.h:674
#define DCHECK_GT(x, y)
Definition: log.h:77
std::ostream & Write(std::ostream &strm) const
Definition: float-weight.h:82
Weight operator()() const
constexpr Log64Weight operator()(const TropicalWeight &w) const
void SetValue(const T &f)
Definition: float-weight.h:99
static constexpr uint64_t Properties()
Definition: float-weight.h:262
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t num_random_weights=kNumRandomWeights)
static constexpr T PosInfinity()
Definition: float-weight.h:61
constexpr RealWeightTpl< T > Reverse() const
Definition: float-weight.h:703
static constexpr uint64_t Properties()
Definition: float-weight.h:904
TropicalWeightTpl< T > Quantize(float delta=kDelta) const
Definition: float-weight.h:252
constexpr uint64_t kRightSemiring
Definition: weight.h:139
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:228
static constexpr LogWeightTpl One()
Definition: float-weight.h:438
static constexpr TropicalWeightTpl< T > One()
Definition: float-weight.h:219
constexpr bool FloatApproxEqual(T w1, T w2, float delta=kDelta)
Definition: float-weight.h:160
static constexpr T NumberBad()
Definition: float-weight.h:67
std::istream & operator>>(std::istream &strm, FloatWeightTpl< T > &w)
Definition: float-weight.h:185
Log64Weight operator()(const Real64Weight &w) const
Real64Weight operator()(const LogWeight &w) const
constexpr bool IsNan(T value)
Definition: float-weight.h:52
constexpr TropicalWeightTpl< T > Reverse() const
Definition: float-weight.h:260
bool operator!=(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:54
constexpr LogWeight operator()(const TropicalWeight &w) const
Definition: float-weight.h:988
MinMaxWeightTpl< T > Quantize(float delta=kDelta) const
Definition: float-weight.h:892
static constexpr LogWeightTpl NoWeight()
Definition: float-weight.h:440
constexpr uint64_t kCommutative
Definition: weight.h:144
std::ostream & operator<<(std::ostream &strm, const ErrorWeight &)
Definition: error-weight.h:71
LogWeight operator()(const RealWeight &w) const
Definition: float-weight.h:995
Adder(Weight w=Weight::Zero())
Definition: float-weight.h:635
void Reset(Weight w=Weight::Zero())
Definition: float-weight.h:854
constexpr MinMaxWeightTpl(T f)
Definition: float-weight.h:875
static constexpr RealWeightTpl NoWeight()
Definition: float-weight.h:682
static const std::string & Type()
Definition: float-weight.h:684
LogWeightTpl< T > Quantize(float delta=kDelta) const
Definition: float-weight.h:453
static constexpr TropicalWeightTpl< T > NoWeight()
Definition: float-weight.h:221
static constexpr uint64_t Properties()
Definition: float-weight.h:705
static constexpr RealWeightTpl One()
Definition: float-weight.h:680
constexpr LogWeightTpl< T > Reverse() const
Definition: float-weight.h:461
double LogPosExp(double x)
Definition: float-weight.h:477
double KahanLogSum(double a, double b, double *c)
Definition: float-weight.h:492
TropicalWeightTpl() noexcept
Definition: float-weight.h:213
LogWeight operator()(const Real64Weight &w) const
std::istream & Read(std::istream &strm)
Definition: float-weight.h:80
constexpr uint64_t kPath
Definition: weight.h:150
static const std::string & Type()
Definition: float-weight.h:225
Log64Weight operator()(const RealWeight &w) const
Weight Add(const Weight &w)
Definition: float-weight.h:637
LogWeightTpl() noexcept
Definition: float-weight.h:432
constexpr bool Member() const
Definition: float-weight.h:231
double LogNegExp(double x)
Definition: float-weight.h:483
constexpr RealWeightTpl(T f)
Definition: float-weight.h:676
MinMaxWeightTpl() noexcept
Definition: float-weight.h:873
LogWeightTpl< T > Minus(const LogWeightTpl< T > &w1, const LogWeightTpl< T > &w2)
Definition: float-weight.h:543
constexpr bool Member() const
Definition: float-weight.h:890
std::string StrCat(const StringOrInt &s1, const StringOrInt &s2)
Definition: compat.h:286
static const std::string & Type()
Definition: float-weight.h:442
constexpr bool Member() const
Definition: float-weight.h:448
size_t Hash() const
Definition: float-weight.h:86
ErrorWeight Divide(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:67
bool operator==(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:51
constexpr size_t kNumRandomWeights
Definition: weight.h:154
static constexpr MinMaxWeightTpl Zero()
Definition: float-weight.h:877
constexpr TropicalWeight operator()(const Log64Weight &w) const
Definition: float-weight.h:980
static const std::string & Type()
Definition: float-weight.h:883
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:80
constexpr MinMaxWeightTpl< T > Reverse() const
Definition: float-weight.h:902
#define DCHECK(x)
Definition: log.h:74
static constexpr uint64_t Properties()
Definition: float-weight.h:463
constexpr TropicalWeightTpl< T > Power(const TropicalWeightTpl< T > &w, V n)
Definition: float-weight.h:401
#define DCHECK_GE(x, y)
Definition: log.h:79
constexpr LogWeight operator()(const Log64Weight &w) const
DivideType
Definition: weight.h:165
constexpr uint64_t kLeftSemiring
Definition: weight.h:136
constexpr float kDelta
Definition: weight.h:133
constexpr Real64Weight operator()(const RealWeight &w) const
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t num_random_weights=kNumRandomWeights)
Real64Weight operator()(const Log64Weight &w) const
constexpr const T & Value() const
Definition: float-weight.h:96
bool ApproxEqual(const ErrorWeight &, const ErrorWeight &, float)
Definition: error-weight.h:58
constexpr LogWeightTpl(T f)
Definition: float-weight.h:434
constexpr TropicalWeightTpl(T f)
Definition: float-weight.h:215
double KahanLogDiff(double a, double b, double *c)
Definition: float-weight.h:504
static constexpr T NegInfinity()
Definition: float-weight.h:65
double KahanRealSum(double a, double b, double *c)
Definition: float-weight.h:721